rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
/// NOTE(review): not referenced in this part of the file; presumably how long a disconnected
/// client may take to reconnect before its state is released — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay `Server::start` waits before running its stale-resource cleanup.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for notification queries. NOTE(review): not used in this part of the file —
/// presumably consumed by `get_notifications`; confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on the number of `ConnectionGuard`s that `ConnectionGuard::try_acquire`
/// will hand out at once.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Counter of live `ConnectionGuard`s; updated by `ConnectionGuard::try_acquire` and by the
/// guard's `Drop` impl.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names under which per-message timings are recorded, in milliseconds
// (see `Server::add_handler` and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler stored per message `TypeId`: takes the boxed envelope, the sender's
/// `Session`, and the message's tracing span, and returns a boxed `'static` future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
/// State for one client connection. `add_handler` wraps it in a `MessageContext` for every
/// message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection (currently always a user).
    principal: Principal,
    /// This connection's id, as used with `peer`.
    connection_id: ConnectionId,
    /// Database access; go through `Session::db`, which adds extra yield points under the
    /// `test-support` feature.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; go through `Session::connection_pool` to get a `!Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` suggests this field is not read anywhere yet — confirm.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Client-reported system id, if any. NOTE(review): presumably taken from the
    /// `SystemIdHeader` request header — confirm where sessions are constructed.
    // NOTE(review): `allow(unused)` suggests this field is not read anywhere yet — confirm.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    /// This server's id. Kept behind a mutex so `reset` (test-support) can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of client connections; also cloned into the cleanup task spawned by `start`.
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by `reset`.
    teardown: watch::Sender<bool>,
}
 262
/// Lock guard over the shared `ConnectionPool`, returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Handler futures must be `Send`
/// (see `add_handler`), so holding this synchronous `parking_lot` lock across an `.await`
/// inside a handler fails to compile.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 267
 268impl Server {
    /// Creates a server with the given id and registers the handler for every RPC message
    /// type it serves, returning it wrapped in an `Arc`.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept. Per tokio's `watch` docs, `send` fails while no
            // receiver exists (e.g. until someone calls `subscribe`), which is why callers
            // ignore its result.
            teardown: watch::channel(false).0,
        };

        server
            // Liveness, rooms, and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project state pushed by the host.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host. The read-only vs. mutating helpers are
            // defined elsewhere in this file; presumably they differ in which collaborator
            // roles may send the request — confirm there.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            // Notifications relayed from the host to the project's guests.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, channel chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following, plus read acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Assistant contexts, git, breakpoints, and LSP logs forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): judging by their names, `GitReset` and `GitCheckoutFiles` change
            // repository state, yet they are registered as read-only forwards — confirm intent.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Shared agent threads and streamed project search results.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
 455
    /// Spawns a detached cleanup task and returns immediately.
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task:
    /// - asks the database for the rooms and channel buffers it reports as stale for this
    ///   environment and `server_id`;
    /// - prunes their stale participants and collaborators;
    /// - notifies the affected connections (room/channel updates, `CallCanceled`, and contact
    ///   busy-status updates);
    /// - deletes LiveKit rooms that end up empty, when a LiveKit client is configured.
    ///
    /// It then calls the database to clear old worktree entries and delete stale servers.
    ///
    /// Every step is best-effort: failures are logged via `trace_err` and the task carries on.
    /// This function therefore always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        // Detached: nothing awaits this task, and its failures are only logged.
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then send the refreshed
                    // collaborator list to each connection listed on the refreshed buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: prune stale participants, then notify everyone affected.
                    for room_id in room_ids {
                        // Filled in only if the room refresh below succeeds; otherwise the
                        // notification and LiveKit steps that follow do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the
                        // `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry (including busy status)
                        // and push it to the connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once nobody is left in it (only when a
                        // LiveKit client is configured).
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the call made before the stale-resource sweep
                // above; confirm whether the second pass is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 626
 627    pub fn teardown(&self) {
 628        self.peer.teardown();
 629        self.connection_pool.lock().reset();
 630        let _ = self.teardown.send(true);
 631    }
 632
 633    #[cfg(feature = "test-support")]
 634    pub fn reset(&self, id: ServerId) {
 635        self.teardown();
 636        *self.id.lock() = id;
 637        self.peer.reset(id.0 as u32);
 638        let _ = self.teardown.send(false);
 639    }
 640
 641    #[cfg(feature = "test-support")]
 642    pub fn id(&self) -> ServerId {
 643        *self.id.lock()
 644    }
 645
    /// Registers `handler` as the dispatcher for messages of type `M`.
    ///
    /// The stored closure downcasts the incoming envelope to `TypedEnvelope<M>`, invokes
    /// `handler` with a `MessageContext` carrying the session and span, and returns a boxed
    /// future. When that future finishes, it records total, processing and queue durations
    /// (in milliseconds) on the span and logs whether handling succeeded.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are keyed by `TypeId::of::<M>()`, so this downcast is expected to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                // The handler is called immediately; only the future it returns is deferred.
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Total time is measured from receipt, processing time from handler
                    // invocation; the difference is time spent waiting before dispatch.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 689
 690    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 691    where
 692        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 693        Fut: 'static + Send + Future<Output = Result<()>>,
 694        M: EnvelopedMessage,
 695    {
 696        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 697        self
 698    }
 699
 700    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 701    where
 702        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 703        Fut: Send + Future<Output = Result<()>>,
 704        M: RequestMessage,
 705    {
 706        let handler = Arc::new(handler);
 707        self.add_handler(move |envelope, session| {
 708            let receipt = envelope.receipt();
 709            let handler = handler.clone();
 710            async move {
 711                let peer = session.peer.clone();
 712                let responded = Arc::new(AtomicBool::default());
 713                let response = Response {
 714                    peer: peer.clone(),
 715                    responded: responded.clone(),
 716                    receipt,
 717                };
 718                match (handler)(envelope.payload, response, session).await {
 719                    Ok(()) => {
 720                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 721                            Ok(())
 722                        } else {
 723                            Err(anyhow!("handler did not send a response"))?
 724                        }
 725                    }
 726                    Err(error) => {
 727                        let proto_err = match &error {
 728                            Error::Internal(err) => err.to_proto(),
 729                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 730                        };
 731                        peer.respond_with_error(receipt, proto_err)?;
 732                        Err(error)
 733                    }
 734                }
 735            }
 736        })
 737    }
 738
 739    pub fn handle_connection(
 740        self: &Arc<Self>,
 741        connection: Connection,
 742        address: String,
 743        principal: Principal,
 744        zed_version: ZedVersion,
 745        release_channel: Option<String>,
 746        user_agent: Option<String>,
 747        geoip_country_code: Option<String>,
 748        system_id: Option<String>,
 749        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 750        executor: Executor,
 751        connection_guard: Option<ConnectionGuard>,
 752    ) -> impl Future<Output = ()> + use<> {
 753        let this = self.clone();
 754        let span = info_span!("handle connection", %address,
 755            connection_id=field::Empty,
 756            user_id=field::Empty,
 757            login=field::Empty,
 758            user_agent=field::Empty,
 759            geoip_country_code=field::Empty,
 760            release_channel=field::Empty,
 761        );
 762        principal.update_span(&span);
 763        if let Some(user_agent) = user_agent {
 764            span.record("user_agent", user_agent);
 765        }
 766        if let Some(release_channel) = release_channel {
 767            span.record("release_channel", release_channel);
 768        }
 769
 770        if let Some(country_code) = geoip_country_code.as_ref() {
 771            span.record("geoip_country_code", country_code);
 772        }
 773
 774        let mut teardown = self.teardown.subscribe();
 775        async move {
 776            if *teardown.borrow() {
 777                tracing::error!("server is tearing down");
 778                return;
 779            }
 780
 781            let (connection_id, handle_io, mut incoming_rx) =
 782                this.peer.add_connection(connection, {
 783                    let executor = executor.clone();
 784                    move |duration| executor.sleep(duration)
 785                });
 786            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 787
 788            tracing::info!("connection opened");
 789
 790            let session = Session {
 791                principal: principal.clone(),
 792                connection_id,
 793                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 794                peer: this.peer.clone(),
 795                connection_pool: this.connection_pool.clone(),
 796                app_state: this.app_state.clone(),
 797                geoip_country_code,
 798                system_id,
 799                _executor: executor.clone(),
 800            };
 801
 802            if let Err(error) = this
 803                .send_initial_client_update(
 804                    connection_id,
 805                    zed_version,
 806                    send_connection_id,
 807                    &session,
 808                )
 809                .await
 810            {
 811                tracing::error!(?error, "failed to send initial client update");
 812                return;
 813            }
 814            drop(connection_guard);
 815
 816            let handle_io = handle_io.fuse();
 817            futures::pin_mut!(handle_io);
 818
 819            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 820            // This prevents deadlocks when e.g., client A performs a request to client B and
 821            // client B performs a request to client A. If both clients stop processing further
 822            // messages until their respective request completes, they won't have a chance to
 823            // respond to the other client's request and cause a deadlock.
 824            //
 825            // This arrangement ensures we will attempt to process earlier messages first, but fall
 826            // back to processing messages arrived later in the spirit of making progress.
 827            const MAX_CONCURRENT_HANDLERS: usize = 256;
 828            let mut foreground_message_handlers = FuturesUnordered::new();
 829            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 830            let get_concurrent_handlers = {
 831                let concurrent_handlers = concurrent_handlers.clone();
 832                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 833            };
 834            loop {
 835                let next_message = async {
 836                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 837                    let message = incoming_rx.next().await;
 838                    // Cache the concurrent_handlers here, so that we know what the
 839                    // queue looks like as each handler starts
 840                    (permit, message, get_concurrent_handlers())
 841                }
 842                .fuse();
 843                futures::pin_mut!(next_message);
 844                futures::select_biased! {
 845                    _ = teardown.changed().fuse() => return,
 846                    result = handle_io => {
 847                        if let Err(error) = result {
 848                            tracing::error!(?error, "error handling I/O");
 849                        }
 850                        break;
 851                    }
 852                    _ = foreground_message_handlers.next() => {}
 853                    next_message = next_message => {
 854                        let (permit, message, concurrent_handlers) = next_message;
 855                        if let Some(message) = message {
 856                            let type_name = message.payload_type_name();
 857                            // note: we copy all the fields from the parent span so we can query them in the logs.
 858                            // (https://github.com/tokio-rs/tracing/issues/2670).
 859                            let span = tracing::info_span!("receive message",
 860                                %connection_id,
 861                                %address,
 862                                type_name,
 863                                concurrent_handlers,
 864                                user_id=field::Empty,
 865                                login=field::Empty,
 866                                lsp_query_request=field::Empty,
 867                                release_channel=field::Empty,
 868                                { TOTAL_DURATION_MS }=field::Empty,
 869                                { PROCESSING_DURATION_MS }=field::Empty,
 870                                { QUEUE_DURATION_MS }=field::Empty,
 871                                { HOST_WAITING_MS }=field::Empty
 872                            );
 873                            principal.update_span(&span);
 874                            let span_enter = span.enter();
 875                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 876                                let is_background = message.is_background();
 877                                let handle_message = (handler)(message, session.clone(), span.clone());
 878                                drop(span_enter);
 879
 880                                let handle_message = async move {
 881                                    handle_message.await;
 882                                    drop(permit);
 883                                }.instrument(span);
 884                                if is_background {
 885                                    executor.spawn_detached(handle_message);
 886                                } else {
 887                                    foreground_message_handlers.push(handle_message);
 888                                }
 889                            } else {
 890                                tracing::error!("no message handler");
 891                            }
 892                        } else {
 893                            tracing::info!("connection closed");
 894                            break;
 895                        }
 896                    }
 897                }
 898            }
 899
 900            drop(foreground_message_handlers);
 901            let concurrent_handlers = get_concurrent_handlers();
 902            tracing::info!(concurrent_handlers, "signing out");
 903            if let Err(error) = connection_lost(session, teardown, executor).await {
 904                tracing::error!(?error, "error signing out");
 905            }
 906        }
 907        .instrument(span)
 908    }
 909
 910    async fn send_initial_client_update(
 911        &self,
 912        connection_id: ConnectionId,
 913        zed_version: ZedVersion,
 914        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 915        session: &Session,
 916    ) -> Result<()> {
 917        self.peer.send(
 918            connection_id,
 919            proto::Hello {
 920                peer_id: Some(connection_id.into()),
 921            },
 922        )?;
 923        tracing::info!("sent hello message");
 924        if let Some(send_connection_id) = send_connection_id.take() {
 925            let _ = send_connection_id.send(connection_id);
 926        }
 927
 928        match &session.principal {
 929            Principal::User(user) => {
 930                if !user.connected_once {
 931                    self.peer.send(connection_id, proto::ShowContacts {})?;
 932                    self.app_state
 933                        .db
 934                        .set_user_connected_once(user.id, true)
 935                        .await?;
 936                }
 937
 938                let contacts = self.app_state.db.get_contacts(user.id).await?;
 939
 940                {
 941                    let mut pool = self.connection_pool.lock();
 942                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 943                    self.peer.send(
 944                        connection_id,
 945                        build_initial_contacts_update(contacts, &pool),
 946                    )?;
 947                }
 948
 949                if should_auto_subscribe_to_channels(&zed_version) {
 950                    subscribe_user_to_channels(user.id, session).await?;
 951                }
 952
 953                if let Some(incoming_call) =
 954                    self.app_state.db.incoming_call_for_user(user.id).await?
 955                {
 956                    self.peer.send(connection_id, incoming_call)?;
 957                }
 958
 959                update_user_contacts(user.id, session).await?;
 960            }
 961        }
 962
 963        Ok(())
 964    }
 965}
 966
/// Exposes the locked `ConnectionPool` through the guard for read access.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
 974
/// Exposes the locked `ConnectionPool` through the guard for mutation.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
 980
/// When the guard is released in `test-support` builds, checks the pool's invariants.
/// In other builds, dropping the guard does nothing extra.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
 987
 988fn broadcast<F>(
 989    sender_id: Option<ConnectionId>,
 990    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 991    mut f: F,
 992) where
 993    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 994{
 995    for receiver_id in receiver_ids {
 996        if Some(receiver_id) != sender_id
 997            && let Err(error) = f(receiver_id)
 998        {
 999            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1000        }
1001    }
1002}
1003
1004pub struct ProtocolVersion(u32);
1005
1006impl Header for ProtocolVersion {
1007    fn name() -> &'static HeaderName {
1008        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1009        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1010    }
1011
1012    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1013    where
1014        Self: Sized,
1015        I: Iterator<Item = &'i axum::http::HeaderValue>,
1016    {
1017        let version = values
1018            .next()
1019            .ok_or_else(axum::headers::Error::invalid)?
1020            .to_str()
1021            .map_err(|_| axum::headers::Error::invalid())?
1022            .parse()
1023            .map_err(|_| axum::headers::Error::invalid())?;
1024        Ok(Self(version))
1025    }
1026
1027    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1028        values.extend([self.0.to_string().parse().unwrap()]);
1029    }
1030}
1031
1032pub struct AppVersionHeader(Version);
1033impl Header for AppVersionHeader {
1034    fn name() -> &'static HeaderName {
1035        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1036        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1037    }
1038
1039    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1040    where
1041        Self: Sized,
1042        I: Iterator<Item = &'i axum::http::HeaderValue>,
1043    {
1044        let version = values
1045            .next()
1046            .ok_or_else(axum::headers::Error::invalid)?
1047            .to_str()
1048            .map_err(|_| axum::headers::Error::invalid())?
1049            .parse()
1050            .map_err(|_| axum::headers::Error::invalid())?;
1051        Ok(Self(version))
1052    }
1053
1054    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1055        values.extend([self.0.to_string().parse().unwrap()]);
1056    }
1057}
1058
1059#[derive(Debug)]
1060pub struct ReleaseChannelHeader(String);
1061
1062impl Header for ReleaseChannelHeader {
1063    fn name() -> &'static HeaderName {
1064        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1065        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1066    }
1067
1068    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069    where
1070        Self: Sized,
1071        I: Iterator<Item = &'i axum::http::HeaderValue>,
1072    {
1073        Ok(Self(
1074            values
1075                .next()
1076                .ok_or_else(axum::headers::Error::invalid)?
1077                .to_str()
1078                .map_err(|_| axum::headers::Error::invalid())?
1079                .to_owned(),
1080        ))
1081    }
1082
1083    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084        values.extend([self.0.parse().unwrap()]);
1085    }
1086}
1087
/// Builds the router for the RPC endpoints.
///
/// `/rpc` gets a layer stack that provides the app state as an extension and runs
/// `auth::validate_header` as middleware. `/metrics` is registered after that `.layer(...)`
/// call, so that stack does not apply to it; axum's `Router::layer` only wraps routes added
/// before it. The outermost `Extension(server)` layer is added last and covers both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies to `/rpc` only: routes added below are not wrapped by this layer.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1099
1100pub async fn handle_websocket_request(
1101    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1102    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1103    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1104    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1105    Extension(server): Extension<Arc<Server>>,
1106    Extension(principal): Extension<Principal>,
1107    user_agent: Option<TypedHeader<UserAgent>>,
1108    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1109    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1110    ws: WebSocketUpgrade,
1111) -> axum::response::Response {
1112    if protocol_version != rpc::PROTOCOL_VERSION {
1113        return (
1114            StatusCode::UPGRADE_REQUIRED,
1115            "client must be upgraded".to_string(),
1116        )
1117            .into_response();
1118    }
1119
1120    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1121        return (
1122            StatusCode::UPGRADE_REQUIRED,
1123            "no version header found".to_string(),
1124        )
1125            .into_response();
1126    };
1127
1128    let release_channel = release_channel_header.map(|header| header.0.0);
1129
1130    if !version.can_collaborate() {
1131        return (
1132            StatusCode::UPGRADE_REQUIRED,
1133            "client must be upgraded".to_string(),
1134        )
1135            .into_response();
1136    }
1137
1138    let socket_address = socket_address.to_string();
1139
1140    // Acquire connection guard before WebSocket upgrade
1141    let connection_guard = match ConnectionGuard::try_acquire() {
1142        Ok(guard) => guard,
1143        Err(()) => {
1144            return (
1145                StatusCode::SERVICE_UNAVAILABLE,
1146                "Too many concurrent connections",
1147            )
1148                .into_response();
1149        }
1150    };
1151
1152    ws.on_upgrade(move |socket| {
1153        let socket = socket
1154            .map_ok(to_tungstenite_message)
1155            .err_into()
1156            .with(|message| async move { to_axum_message(message) });
1157        let connection = Connection::new(Box::pin(socket));
1158        async move {
1159            server
1160                .handle_connection(
1161                    connection,
1162                    socket_address,
1163                    principal,
1164                    version,
1165                    release_channel,
1166                    user_agent.map(|header| header.to_string()),
1167                    country_code_header.map(|header| header.to_string()),
1168                    system_id_header.map(|header| header.to_string()),
1169                    None,
1170                    Executor::Production,
1171                    Some(connection_guard),
1172                )
1173                .await;
1174        }
1175    })
1176}
1177
1178pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1179    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1180    let connections_metric = CONNECTIONS_METRIC
1181        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1182
1183    let connections = server
1184        .connection_pool
1185        .lock()
1186        .connections()
1187        .filter(|connection| !connection.admin)
1188        .count();
1189    connections_metric.set(connections as _);
1190
1191    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1192    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1193        register_int_gauge!(
1194            "shared_projects",
1195            "number of open projects with one or more guests"
1196        )
1197        .unwrap()
1198    });
1199
1200    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1201    shared_projects_metric.set(shared_projects as _);
1202
1203    let encoder = prometheus::TextEncoder::new();
1204    let metric_families = prometheus::gather();
1205    let encoded_metrics = encoder
1206        .encode_to_string(&metric_families)
1207        .map_err(|err| anyhow!("{err}"))?;
1208    Ok(encoded_metrics)
1209}
1210
1211#[instrument(err, skip(executor))]
1212async fn connection_lost(
1213    session: Session,
1214    mut teardown: watch::Receiver<bool>,
1215    executor: Executor,
1216) -> Result<()> {
1217    session.peer.disconnect(session.connection_id);
1218    session
1219        .connection_pool()
1220        .await
1221        .remove_connection(session.connection_id)?;
1222
1223    session
1224        .db()
1225        .await
1226        .connection_lost(session.connection_id)
1227        .await
1228        .trace_err();
1229
1230    futures::select_biased! {
1231        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1232
1233            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1234            leave_room_for_session(&session, session.connection_id).await.trace_err();
1235            leave_channel_buffers_for_session(&session)
1236                .await
1237                .trace_err();
1238
1239            if !session
1240                .connection_pool()
1241                .await
1242                .is_user_online(session.user_id())
1243            {
1244                let db = session.db().await;
1245                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1246                    room_updated(&room, &session.peer);
1247                }
1248            }
1249
1250            update_user_contacts(session.user_id(), &session).await?;
1251        },
1252        _ = teardown.changed().fuse() => {}
1253    }
1254
1255    Ok(())
1256}
1257
1258/// Acknowledges a ping from a client, used to keep the connection alive.
1259async fn ping(
1260    _: proto::Ping,
1261    response: Response<proto::Ping>,
1262    _session: MessageContext,
1263) -> Result<()> {
1264    response.send(proto::Ack {})?;
1265    Ok(())
1266}
1267
1268/// Creates a new room for calling (outside of channels)
1269async fn create_room(
1270    _request: proto::CreateRoom,
1271    response: Response<proto::CreateRoom>,
1272    session: MessageContext,
1273) -> Result<()> {
1274    let livekit_room = nanoid::nanoid!(30);
1275
1276    let live_kit_connection_info = util::maybe!(async {
1277        let live_kit = session.app_state.livekit_client.as_ref();
1278        let live_kit = live_kit?;
1279        let user_id = session.user_id().to_string();
1280
1281        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1282
1283        Some(proto::LiveKitConnectionInfo {
1284            server_url: live_kit.url().into(),
1285            token,
1286            can_publish: true,
1287        })
1288    })
1289    .await;
1290
1291    let room = session
1292        .db()
1293        .await
1294        .create_room(session.user_id(), session.connection_id, &livekit_room)
1295        .await?;
1296
1297    response.send(proto::CreateRoomResponse {
1298        room: Some(room.clone()),
1299        live_kit_connection_info,
1300    })?;
1301
1302    update_user_contacts(session.user_id(), &session).await?;
1303    Ok(())
1304}
1305
1306/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1307async fn join_room(
1308    request: proto::JoinRoom,
1309    response: Response<proto::JoinRoom>,
1310    session: MessageContext,
1311) -> Result<()> {
1312    let room_id = RoomId::from_proto(request.id);
1313
1314    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1315
1316    if let Some(channel_id) = channel_id {
1317        return join_channel_internal(channel_id, Box::new(response), session).await;
1318    }
1319
1320    let joined_room = {
1321        let room = session
1322            .db()
1323            .await
1324            .join_room(room_id, session.user_id(), session.connection_id)
1325            .await?;
1326        room_updated(&room.room, &session.peer);
1327        room.into_inner()
1328    };
1329
1330    for connection_id in session
1331        .connection_pool()
1332        .await
1333        .user_connection_ids(session.user_id())
1334    {
1335        session
1336            .peer
1337            .send(
1338                connection_id,
1339                proto::CallCanceled {
1340                    room_id: room_id.to_proto(),
1341                },
1342            )
1343            .trace_err();
1344    }
1345
1346    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1347    {
1348        live_kit
1349            .room_token(
1350                &joined_room.room.livekit_room,
1351                &session.user_id().to_string(),
1352            )
1353            .trace_err()
1354            .map(|token| proto::LiveKitConnectionInfo {
1355                server_url: live_kit.url().into(),
1356                token,
1357                can_publish: true,
1358            })
1359    } else {
1360        None
1361    };
1362
1363    response.send(proto::JoinRoomResponse {
1364        room: Some(joined_room.room),
1365        channel_id: None,
1366        live_kit_connection_info,
1367    })?;
1368
1369    update_user_contacts(session.user_id(), &session).await?;
1370    Ok(())
1371}
1372
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-enters the room through the database, then does the following in order:
/// - replies with the room plus the reshared and rejoined projects;
/// - broadcasts the room update;
/// - tells collaborators of each reshared project about the caller's new peer id and
///   forwards that project's worktrees to them;
/// - streams state for the rejoined projects back to the caller;
/// - if the room belongs to a channel, broadcasts a channel update;
/// - refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // The inner block confines the database result; only `room` and `channel` escape it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Remap the caller's old peer id to the new connection for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1464
1465fn notify_rejoined_projects(
1466    rejoined_projects: &mut Vec<RejoinedProject>,
1467    session: &Session,
1468) -> Result<()> {
1469    for project in rejoined_projects.iter() {
1470        for collaborator in &project.collaborators {
1471            session
1472                .peer
1473                .send(
1474                    collaborator.connection_id,
1475                    proto::UpdateProjectCollaborator {
1476                        project_id: project.id.to_proto(),
1477                        old_peer_id: Some(project.old_connection_id.into()),
1478                        new_peer_id: Some(session.connection_id.into()),
1479                    },
1480                )
1481                .trace_err();
1482        }
1483    }
1484
1485    for project in rejoined_projects {
1486        for worktree in mem::take(&mut project.worktrees) {
1487            // Stream this worktree's entries.
1488            let message = proto::UpdateWorktree {
1489                project_id: project.id.to_proto(),
1490                worktree_id: worktree.id,
1491                abs_path: worktree.abs_path.clone(),
1492                root_name: worktree.root_name,
1493                updated_entries: worktree.updated_entries,
1494                removed_entries: worktree.removed_entries,
1495                scan_id: worktree.scan_id,
1496                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1497                updated_repositories: worktree.updated_repositories,
1498                removed_repositories: worktree.removed_repositories,
1499            };
1500            for update in proto::split_worktree_update(message) {
1501                session.peer.send(session.connection_id, update)?;
1502            }
1503
1504            // Stream this worktree's diagnostics.
1505            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1506            if let Some(summary) = worktree_diagnostics.next() {
1507                let message = proto::UpdateDiagnosticSummary {
1508                    project_id: project.id.to_proto(),
1509                    worktree_id: worktree.id,
1510                    summary: Some(summary),
1511                    more_summaries: worktree_diagnostics.collect(),
1512                };
1513                session.peer.send(session.connection_id, message)?;
1514            }
1515
1516            for settings_file in worktree.settings_files {
1517                session.peer.send(
1518                    session.connection_id,
1519                    proto::UpdateWorktreeSettings {
1520                        project_id: project.id.to_proto(),
1521                        worktree_id: worktree.id,
1522                        path: settings_file.path,
1523                        content: Some(settings_file.content),
1524                        kind: Some(settings_file.kind.to_proto().into()),
1525                        outside_worktree: Some(settings_file.outside_worktree),
1526                    },
1527                )?;
1528            }
1529        }
1530
1531        for repository in mem::take(&mut project.updated_repositories) {
1532            for update in split_repository_update(repository) {
1533                session.peer.send(session.connection_id, update)?;
1534            }
1535        }
1536
1537        for id in mem::take(&mut project.removed_repositories) {
1538            session.peer.send(
1539                session.connection_id,
1540                proto::RemoveRepository {
1541                    project_id: project.id.to_proto(),
1542                    id,
1543                },
1544            )?;
1545        }
1546    }
1547
1548    Ok(())
1549}
1550
1551/// leave room disconnects from the room.
1552async fn leave_room(
1553    _: proto::LeaveRoom,
1554    response: Response<proto::LeaveRoom>,
1555    session: MessageContext,
1556) -> Result<()> {
1557    leave_room_for_session(&session, session.connection_id).await?;
1558    response.send(proto::Ack {})?;
1559    Ok(())
1560}
1561
1562/// Updates the permissions of someone else in the room.
1563async fn set_room_participant_role(
1564    request: proto::SetRoomParticipantRole,
1565    response: Response<proto::SetRoomParticipantRole>,
1566    session: MessageContext,
1567) -> Result<()> {
1568    let user_id = UserId::from_proto(request.user_id);
1569    let role = ChannelRole::from(request.role());
1570
1571    let (livekit_room, can_publish) = {
1572        let room = session
1573            .db()
1574            .await
1575            .set_room_participant_role(
1576                session.user_id(),
1577                RoomId::from_proto(request.room_id),
1578                user_id,
1579                role,
1580            )
1581            .await?;
1582
1583        let livekit_room = room.livekit_room.clone();
1584        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1585        room_updated(&room, &session.peer);
1586        (livekit_room, can_publish)
1587    };
1588
1589    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1590        live_kit
1591            .update_participant(
1592                livekit_room.clone(),
1593                request.user_id.to_string(),
1594                livekit_api::proto::ParticipantPermission {
1595                    can_subscribe: true,
1596                    can_publish,
1597                    can_publish_data: can_publish,
1598                    hidden: false,
1599                    recorder: false,
1600                },
1601            )
1602            .await
1603            .trace_err();
1604    }
1605
1606    response.send(proto::Ack {})?;
1607    Ok(())
1608}
1609
1610/// Call someone else into the current room
1611async fn call(
1612    request: proto::Call,
1613    response: Response<proto::Call>,
1614    session: MessageContext,
1615) -> Result<()> {
1616    let room_id = RoomId::from_proto(request.room_id);
1617    let calling_user_id = session.user_id();
1618    let calling_connection_id = session.connection_id;
1619    let called_user_id = UserId::from_proto(request.called_user_id);
1620    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1621    if !session
1622        .db()
1623        .await
1624        .has_contact(calling_user_id, called_user_id)
1625        .await?
1626    {
1627        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1628    }
1629
1630    let incoming_call = {
1631        let (room, incoming_call) = &mut *session
1632            .db()
1633            .await
1634            .call(
1635                room_id,
1636                calling_user_id,
1637                calling_connection_id,
1638                called_user_id,
1639                initial_project_id,
1640            )
1641            .await?;
1642        room_updated(room, &session.peer);
1643        mem::take(incoming_call)
1644    };
1645    update_user_contacts(called_user_id, &session).await?;
1646
1647    let mut calls = session
1648        .connection_pool()
1649        .await
1650        .user_connection_ids(called_user_id)
1651        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1652        .collect::<FuturesUnordered<_>>();
1653
1654    while let Some(call_response) = calls.next().await {
1655        match call_response.as_ref() {
1656            Ok(_) => {
1657                response.send(proto::Ack {})?;
1658                return Ok(());
1659            }
1660            Err(_) => {
1661                call_response.trace_err();
1662            }
1663        }
1664    }
1665
1666    {
1667        let room = session
1668            .db()
1669            .await
1670            .call_failed(room_id, called_user_id)
1671            .await?;
1672        room_updated(&room, &session.peer);
1673    }
1674    update_user_contacts(called_user_id, &session).await?;
1675
1676    Err(anyhow!("failed to ring user"))?
1677}
1678
1679/// Cancel an outgoing call.
1680async fn cancel_call(
1681    request: proto::CancelCall,
1682    response: Response<proto::CancelCall>,
1683    session: MessageContext,
1684) -> Result<()> {
1685    let called_user_id = UserId::from_proto(request.called_user_id);
1686    let room_id = RoomId::from_proto(request.room_id);
1687    {
1688        let room = session
1689            .db()
1690            .await
1691            .cancel_call(room_id, session.connection_id, called_user_id)
1692            .await?;
1693        room_updated(&room, &session.peer);
1694    }
1695
1696    for connection_id in session
1697        .connection_pool()
1698        .await
1699        .user_connection_ids(called_user_id)
1700    {
1701        session
1702            .peer
1703            .send(
1704                connection_id,
1705                proto::CallCanceled {
1706                    room_id: room_id.to_proto(),
1707                },
1708            )
1709            .trace_err();
1710    }
1711    response.send(proto::Ack {})?;
1712
1713    update_user_contacts(called_user_id, &session).await?;
1714    Ok(())
1715}
1716
1717/// Decline an incoming call.
1718async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1719    let room_id = RoomId::from_proto(message.room_id);
1720    {
1721        let room = session
1722            .db()
1723            .await
1724            .decline_call(Some(room_id), session.user_id())
1725            .await?
1726            .context("declining call")?;
1727        room_updated(&room, &session.peer);
1728    }
1729
1730    for connection_id in session
1731        .connection_pool()
1732        .await
1733        .user_connection_ids(session.user_id())
1734    {
1735        session
1736            .peer
1737            .send(
1738                connection_id,
1739                proto::CallCanceled {
1740                    room_id: room_id.to_proto(),
1741                },
1742            )
1743            .trace_err();
1744    }
1745    update_user_contacts(session.user_id(), &session).await?;
1746    Ok(())
1747}
1748
1749/// Updates other participants in the room with your current location.
1750async fn update_participant_location(
1751    request: proto::UpdateParticipantLocation,
1752    response: Response<proto::UpdateParticipantLocation>,
1753    session: MessageContext,
1754) -> Result<()> {
1755    let room_id = RoomId::from_proto(request.room_id);
1756    let location = request.location.context("invalid location")?;
1757
1758    let db = session.db().await;
1759    let room = db
1760        .update_room_participant_location(room_id, session.connection_id, location)
1761        .await?;
1762
1763    room_updated(&room, &session.peer);
1764    response.send(proto::Ack {})?;
1765    Ok(())
1766}
1767
1768/// Share a project into the room.
1769async fn share_project(
1770    request: proto::ShareProject,
1771    response: Response<proto::ShareProject>,
1772    session: MessageContext,
1773) -> Result<()> {
1774    let (project_id, room) = &*session
1775        .db()
1776        .await
1777        .share_project(
1778            RoomId::from_proto(request.room_id),
1779            session.connection_id,
1780            &request.worktrees,
1781            request.is_ssh_project,
1782            request.windows_paths.unwrap_or(false),
1783        )
1784        .await?;
1785    response.send(proto::ShareProjectResponse {
1786        project_id: project_id.to_proto(),
1787    })?;
1788    room_updated(room, &session.peer);
1789
1790    Ok(())
1791}
1792
1793/// Unshare a project from the room.
1794async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1795    let project_id = ProjectId::from_proto(message.project_id);
1796    unshare_project_internal(project_id, session.connection_id, &session).await
1797}
1798
1799async fn unshare_project_internal(
1800    project_id: ProjectId,
1801    connection_id: ConnectionId,
1802    session: &Session,
1803) -> Result<()> {
1804    let delete = {
1805        let room_guard = session
1806            .db()
1807            .await
1808            .unshare_project(project_id, connection_id)
1809            .await?;
1810
1811        let (delete, room, guest_connection_ids) = &*room_guard;
1812
1813        let message = proto::UnshareProject {
1814            project_id: project_id.to_proto(),
1815        };
1816
1817        broadcast(
1818            Some(connection_id),
1819            guest_connection_ids.iter().copied(),
1820            |conn_id| session.peer.send(conn_id, message.clone()),
1821        );
1822        if let Some(room) = room {
1823            room_updated(room, &session.peer);
1824        }
1825
1826        *delete
1827    };
1828
1829    if delete {
1830        let db = session.db().await;
1831        db.delete_project(project_id).await?;
1832    }
1833
1834    Ok(())
1835}
1836
1837/// Join someone elses shared project.
1838async fn join_project(
1839    request: proto::JoinProject,
1840    response: Response<proto::JoinProject>,
1841    session: MessageContext,
1842) -> Result<()> {
1843    let project_id = ProjectId::from_proto(request.project_id);
1844
1845    tracing::info!(%project_id, "join project");
1846
1847    let db = session.db().await;
1848    let (project, replica_id) = &mut *db
1849        .join_project(
1850            project_id,
1851            session.connection_id,
1852            session.user_id(),
1853            request.committer_name.clone(),
1854            request.committer_email.clone(),
1855        )
1856        .await?;
1857    drop(db);
1858    tracing::info!(%project_id, "join remote project");
1859    let collaborators = project
1860        .collaborators
1861        .iter()
1862        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1863        .map(|collaborator| collaborator.to_proto())
1864        .collect::<Vec<_>>();
1865    let project_id = project.id;
1866    let guest_user_id = session.user_id();
1867
1868    let worktrees = project
1869        .worktrees
1870        .iter()
1871        .map(|(id, worktree)| proto::WorktreeMetadata {
1872            id: *id,
1873            root_name: worktree.root_name.clone(),
1874            visible: worktree.visible,
1875            abs_path: worktree.abs_path.clone(),
1876        })
1877        .collect::<Vec<_>>();
1878
1879    let add_project_collaborator = proto::AddProjectCollaborator {
1880        project_id: project_id.to_proto(),
1881        collaborator: Some(proto::Collaborator {
1882            peer_id: Some(session.connection_id.into()),
1883            replica_id: replica_id.0 as u32,
1884            user_id: guest_user_id.to_proto(),
1885            is_host: false,
1886            committer_name: request.committer_name.clone(),
1887            committer_email: request.committer_email.clone(),
1888        }),
1889    };
1890
1891    for collaborator in &collaborators {
1892        session
1893            .peer
1894            .send(
1895                collaborator.peer_id.unwrap().into(),
1896                add_project_collaborator.clone(),
1897            )
1898            .trace_err();
1899    }
1900
1901    // First, we send the metadata associated with each worktree.
1902    let (language_servers, language_server_capabilities) = project
1903        .language_servers
1904        .clone()
1905        .into_iter()
1906        .map(|server| (server.server, server.capabilities))
1907        .unzip();
1908    response.send(proto::JoinProjectResponse {
1909        project_id: project.id.0 as u64,
1910        worktrees,
1911        replica_id: replica_id.0 as u32,
1912        collaborators,
1913        language_servers,
1914        language_server_capabilities,
1915        role: project.role.into(),
1916        windows_paths: project.path_style == PathStyle::Windows,
1917    })?;
1918
1919    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1920        // Stream this worktree's entries.
1921        let message = proto::UpdateWorktree {
1922            project_id: project_id.to_proto(),
1923            worktree_id,
1924            abs_path: worktree.abs_path.clone(),
1925            root_name: worktree.root_name,
1926            updated_entries: worktree.entries,
1927            removed_entries: Default::default(),
1928            scan_id: worktree.scan_id,
1929            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1930            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1931            removed_repositories: Default::default(),
1932        };
1933        for update in proto::split_worktree_update(message) {
1934            session.peer.send(session.connection_id, update.clone())?;
1935        }
1936
1937        // Stream this worktree's diagnostics.
1938        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1939        if let Some(summary) = worktree_diagnostics.next() {
1940            let message = proto::UpdateDiagnosticSummary {
1941                project_id: project.id.to_proto(),
1942                worktree_id: worktree.id,
1943                summary: Some(summary),
1944                more_summaries: worktree_diagnostics.collect(),
1945            };
1946            session.peer.send(session.connection_id, message)?;
1947        }
1948
1949        for settings_file in worktree.settings_files {
1950            session.peer.send(
1951                session.connection_id,
1952                proto::UpdateWorktreeSettings {
1953                    project_id: project_id.to_proto(),
1954                    worktree_id: worktree.id,
1955                    path: settings_file.path,
1956                    content: Some(settings_file.content),
1957                    kind: Some(settings_file.kind.to_proto() as i32),
1958                    outside_worktree: Some(settings_file.outside_worktree),
1959                },
1960            )?;
1961        }
1962    }
1963
1964    for repository in mem::take(&mut project.repositories) {
1965        for update in split_repository_update(repository) {
1966            session.peer.send(session.connection_id, update)?;
1967        }
1968    }
1969
1970    for language_server in &project.language_servers {
1971        session.peer.send(
1972            session.connection_id,
1973            proto::UpdateLanguageServer {
1974                project_id: project_id.to_proto(),
1975                server_name: Some(language_server.server.name.clone()),
1976                language_server_id: language_server.server.id,
1977                variant: Some(
1978                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1979                        proto::LspDiskBasedDiagnosticsUpdated {},
1980                    ),
1981                ),
1982            },
1983        )?;
1984    }
1985
1986    Ok(())
1987}
1988
1989/// Leave someone elses shared project.
1990async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
1991    let sender_id = session.connection_id;
1992    let project_id = ProjectId::from_proto(request.project_id);
1993    let db = session.db().await;
1994
1995    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1996    tracing::info!(
1997        %project_id,
1998        "leave project"
1999    );
2000
2001    project_left(project, &session);
2002    if let Some(room) = room {
2003        room_updated(room, &session.peer);
2004    }
2005
2006    Ok(())
2007}
2008
2009/// Updates other participants with changes to the project
2010async fn update_project(
2011    request: proto::UpdateProject,
2012    response: Response<proto::UpdateProject>,
2013    session: MessageContext,
2014) -> Result<()> {
2015    let project_id = ProjectId::from_proto(request.project_id);
2016    let (room, guest_connection_ids) = &*session
2017        .db()
2018        .await
2019        .update_project(project_id, session.connection_id, &request.worktrees)
2020        .await?;
2021    broadcast(
2022        Some(session.connection_id),
2023        guest_connection_ids.iter().copied(),
2024        |connection_id| {
2025            session
2026                .peer
2027                .forward_send(session.connection_id, connection_id, request.clone())
2028        },
2029    );
2030    if let Some(room) = room {
2031        room_updated(room, &session.peer);
2032    }
2033    response.send(proto::Ack {})?;
2034
2035    Ok(())
2036}
2037
2038/// Updates other participants with changes to the worktree
2039async fn update_worktree(
2040    request: proto::UpdateWorktree,
2041    response: Response<proto::UpdateWorktree>,
2042    session: MessageContext,
2043) -> Result<()> {
2044    let guest_connection_ids = session
2045        .db()
2046        .await
2047        .update_worktree(&request, session.connection_id)
2048        .await?;
2049
2050    broadcast(
2051        Some(session.connection_id),
2052        guest_connection_ids.iter().copied(),
2053        |connection_id| {
2054            session
2055                .peer
2056                .forward_send(session.connection_id, connection_id, request.clone())
2057        },
2058    );
2059    response.send(proto::Ack {})?;
2060    Ok(())
2061}
2062
2063async fn update_repository(
2064    request: proto::UpdateRepository,
2065    response: Response<proto::UpdateRepository>,
2066    session: MessageContext,
2067) -> Result<()> {
2068    let guest_connection_ids = session
2069        .db()
2070        .await
2071        .update_repository(&request, session.connection_id)
2072        .await?;
2073
2074    broadcast(
2075        Some(session.connection_id),
2076        guest_connection_ids.iter().copied(),
2077        |connection_id| {
2078            session
2079                .peer
2080                .forward_send(session.connection_id, connection_id, request.clone())
2081        },
2082    );
2083    response.send(proto::Ack {})?;
2084    Ok(())
2085}
2086
2087async fn remove_repository(
2088    request: proto::RemoveRepository,
2089    response: Response<proto::RemoveRepository>,
2090    session: MessageContext,
2091) -> Result<()> {
2092    let guest_connection_ids = session
2093        .db()
2094        .await
2095        .remove_repository(&request, session.connection_id)
2096        .await?;
2097
2098    broadcast(
2099        Some(session.connection_id),
2100        guest_connection_ids.iter().copied(),
2101        |connection_id| {
2102            session
2103                .peer
2104                .forward_send(session.connection_id, connection_id, request.clone())
2105        },
2106    );
2107    response.send(proto::Ack {})?;
2108    Ok(())
2109}
2110
2111/// Updates other participants with changes to the diagnostics
2112async fn update_diagnostic_summary(
2113    message: proto::UpdateDiagnosticSummary,
2114    session: MessageContext,
2115) -> Result<()> {
2116    let guest_connection_ids = session
2117        .db()
2118        .await
2119        .update_diagnostic_summary(&message, session.connection_id)
2120        .await?;
2121
2122    broadcast(
2123        Some(session.connection_id),
2124        guest_connection_ids.iter().copied(),
2125        |connection_id| {
2126            session
2127                .peer
2128                .forward_send(session.connection_id, connection_id, message.clone())
2129        },
2130    );
2131
2132    Ok(())
2133}
2134
2135/// Updates other participants with changes to the worktree settings
2136async fn update_worktree_settings(
2137    message: proto::UpdateWorktreeSettings,
2138    session: MessageContext,
2139) -> Result<()> {
2140    let guest_connection_ids = session
2141        .db()
2142        .await
2143        .update_worktree_settings(&message, session.connection_id)
2144        .await?;
2145
2146    broadcast(
2147        Some(session.connection_id),
2148        guest_connection_ids.iter().copied(),
2149        |connection_id| {
2150            session
2151                .peer
2152                .forward_send(session.connection_id, connection_id, message.clone())
2153        },
2154    );
2155
2156    Ok(())
2157}
2158
2159/// Notify other participants that a language server has started.
2160async fn start_language_server(
2161    request: proto::StartLanguageServer,
2162    session: MessageContext,
2163) -> Result<()> {
2164    let guest_connection_ids = session
2165        .db()
2166        .await
2167        .start_language_server(&request, session.connection_id)
2168        .await?;
2169
2170    broadcast(
2171        Some(session.connection_id),
2172        guest_connection_ids.iter().copied(),
2173        |connection_id| {
2174            session
2175                .peer
2176                .forward_send(session.connection_id, connection_id, request.clone())
2177        },
2178    );
2179    Ok(())
2180}
2181
2182/// Notify other participants that a language server has changed.
2183async fn update_language_server(
2184    request: proto::UpdateLanguageServer,
2185    session: MessageContext,
2186) -> Result<()> {
2187    let project_id = ProjectId::from_proto(request.project_id);
2188    let db = session.db().await;
2189
2190    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2191        && let Some(capabilities) = update.capabilities.clone()
2192    {
2193        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2194            .await?;
2195    }
2196
2197    let project_connection_ids = db
2198        .project_connection_ids(project_id, session.connection_id, true)
2199        .await?;
2200    broadcast(
2201        Some(session.connection_id),
2202        project_connection_ids.iter().copied(),
2203        |connection_id| {
2204            session
2205                .peer
2206                .forward_send(session.connection_id, connection_id, request.clone())
2207        },
2208    );
2209    Ok(())
2210}
2211
2212/// forward a project request to the host. These requests should be read only
2213/// as guests are allowed to send them.
2214async fn forward_read_only_project_request<T>(
2215    request: T,
2216    response: Response<T>,
2217    session: MessageContext,
2218) -> Result<()>
2219where
2220    T: EntityMessage + RequestMessage,
2221{
2222    let project_id = ProjectId::from_proto(request.remote_entity_id());
2223    let host_connection_id = session
2224        .db()
2225        .await
2226        .host_for_read_only_project_request(project_id, session.connection_id)
2227        .await?;
2228    let payload = session.forward_request(host_connection_id, request).await?;
2229    response.send(payload)?;
2230    Ok(())
2231}
2232
2233/// forward a project request to the host. These requests are disallowed
2234/// for guests.
2235async fn forward_mutating_project_request<T>(
2236    request: T,
2237    response: Response<T>,
2238    session: MessageContext,
2239) -> Result<()>
2240where
2241    T: EntityMessage + RequestMessage,
2242{
2243    let project_id = ProjectId::from_proto(request.remote_entity_id());
2244
2245    let host_connection_id = session
2246        .db()
2247        .await
2248        .host_for_mutating_project_request(project_id, session.connection_id)
2249        .await?;
2250    let payload = session.forward_request(host_connection_id, request).await?;
2251    response.send(payload)?;
2252    Ok(())
2253}
2254
2255async fn disallow_guest_request<T>(
2256    _request: T,
2257    response: Response<T>,
2258    _session: MessageContext,
2259) -> Result<()>
2260where
2261    T: RequestMessage,
2262{
2263    response.peer.respond_with_error(
2264        response.receipt,
2265        ErrorCode::Forbidden
2266            .message("request is not allowed for guests".to_string())
2267            .to_proto(),
2268    )?;
2269    response.responded.store(true, SeqCst);
2270    Ok(())
2271}
2272
2273async fn lsp_query(
2274    request: proto::LspQuery,
2275    response: Response<proto::LspQuery>,
2276    session: MessageContext,
2277) -> Result<()> {
2278    let (name, should_write) = request.query_name_and_write_permissions();
2279    tracing::Span::current().record("lsp_query_request", name);
2280    tracing::info!("lsp_query message received");
2281    if should_write {
2282        forward_mutating_project_request(request, response, session).await
2283    } else {
2284        forward_read_only_project_request(request, response, session).await
2285    }
2286}
2287
2288/// Notify other participants that a new buffer has been created
2289async fn create_buffer_for_peer(
2290    request: proto::CreateBufferForPeer,
2291    session: MessageContext,
2292) -> Result<()> {
2293    session
2294        .db()
2295        .await
2296        .check_user_is_project_host(
2297            ProjectId::from_proto(request.project_id),
2298            session.connection_id,
2299        )
2300        .await?;
2301    let peer_id = request.peer_id.context("invalid peer id")?;
2302    session
2303        .peer
2304        .forward_send(session.connection_id, peer_id.into(), request)?;
2305    Ok(())
2306}
2307
2308/// Notify other participants that a new image has been created
2309async fn create_image_for_peer(
2310    request: proto::CreateImageForPeer,
2311    session: MessageContext,
2312) -> Result<()> {
2313    session
2314        .db()
2315        .await
2316        .check_user_is_project_host(
2317            ProjectId::from_proto(request.project_id),
2318            session.connection_id,
2319        )
2320        .await?;
2321    let peer_id = request.peer_id.context("invalid peer id")?;
2322    session
2323        .peer
2324        .forward_send(session.connection_id, peer_id.into(), request)?;
2325    Ok(())
2326}
2327
2328/// Notify other participants that a buffer has been updated. This is
2329/// allowed for guests as long as the update is limited to selections.
2330async fn update_buffer(
2331    request: proto::UpdateBuffer,
2332    response: Response<proto::UpdateBuffer>,
2333    session: MessageContext,
2334) -> Result<()> {
2335    let project_id = ProjectId::from_proto(request.project_id);
2336    let mut capability = Capability::ReadOnly;
2337
2338    for op in request.operations.iter() {
2339        match op.variant {
2340            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2341            Some(_) => capability = Capability::ReadWrite,
2342        }
2343    }
2344
2345    let host = {
2346        let guard = session
2347            .db()
2348            .await
2349            .connections_for_buffer_update(project_id, session.connection_id, capability)
2350            .await?;
2351
2352        let (host, guests) = &*guard;
2353
2354        broadcast(
2355            Some(session.connection_id),
2356            guests.clone(),
2357            |connection_id| {
2358                session
2359                    .peer
2360                    .forward_send(session.connection_id, connection_id, request.clone())
2361            },
2362        );
2363
2364        *host
2365    };
2366
2367    if host != session.connection_id {
2368        session.forward_request(host, request.clone()).await?;
2369    }
2370
2371    response.send(proto::Ack {})?;
2372    Ok(())
2373}
2374
2375async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2376    let project_id = ProjectId::from_proto(message.project_id);
2377
2378    let operation = message.operation.as_ref().context("invalid operation")?;
2379    let capability = match operation.variant.as_ref() {
2380        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2381            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2382                match buffer_op.variant {
2383                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2384                        Capability::ReadOnly
2385                    }
2386                    _ => Capability::ReadWrite,
2387                }
2388            } else {
2389                Capability::ReadWrite
2390            }
2391        }
2392        Some(_) => Capability::ReadWrite,
2393        None => Capability::ReadOnly,
2394    };
2395
2396    let guard = session
2397        .db()
2398        .await
2399        .connections_for_buffer_update(project_id, session.connection_id, capability)
2400        .await?;
2401
2402    let (host, guests) = &*guard;
2403
2404    broadcast(
2405        Some(session.connection_id),
2406        guests.iter().chain([host]).copied(),
2407        |connection_id| {
2408            session
2409                .peer
2410                .forward_send(session.connection_id, connection_id, message.clone())
2411        },
2412    );
2413
2414    Ok(())
2415}
2416
2417async fn forward_project_search_chunk(
2418    message: proto::FindSearchCandidatesChunk,
2419    response: Response<proto::FindSearchCandidatesChunk>,
2420    session: MessageContext,
2421) -> Result<()> {
2422    let peer_id = message.peer_id.context("missing peer_id")?;
2423    let payload = session
2424        .peer
2425        .forward_request(session.connection_id, peer_id.into(), message)
2426        .await?;
2427    response.send(payload)?;
2428    Ok(())
2429}
2430
2431/// Notify other participants that a project has been updated.
2432async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2433    request: T,
2434    session: MessageContext,
2435) -> Result<()> {
2436    let project_id = ProjectId::from_proto(request.remote_entity_id());
2437    let project_connection_ids = session
2438        .db()
2439        .await
2440        .project_connection_ids(project_id, session.connection_id, false)
2441        .await?;
2442
2443    broadcast(
2444        Some(session.connection_id),
2445        project_connection_ids.iter().copied(),
2446        |connection_id| {
2447            session
2448                .peer
2449                .forward_send(session.connection_id, connection_id, request.clone())
2450        },
2451    );
2452    Ok(())
2453}
2454
2455/// Start following another user in a call.
2456async fn follow(
2457    request: proto::Follow,
2458    response: Response<proto::Follow>,
2459    session: MessageContext,
2460) -> Result<()> {
2461    let room_id = RoomId::from_proto(request.room_id);
2462    let project_id = request.project_id.map(ProjectId::from_proto);
2463    let leader_id = request.leader_id.context("invalid leader id")?.into();
2464    let follower_id = session.connection_id;
2465
2466    session
2467        .db()
2468        .await
2469        .check_room_participants(room_id, leader_id, session.connection_id)
2470        .await?;
2471
2472    let response_payload = session.forward_request(leader_id, request).await?;
2473    response.send(response_payload)?;
2474
2475    if let Some(project_id) = project_id {
2476        let room = session
2477            .db()
2478            .await
2479            .follow(room_id, project_id, leader_id, follower_id)
2480            .await?;
2481        room_updated(&room, &session.peer);
2482    }
2483
2484    Ok(())
2485}
2486
2487/// Stop following another user in a call.
2488async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2489    let room_id = RoomId::from_proto(request.room_id);
2490    let project_id = request.project_id.map(ProjectId::from_proto);
2491    let leader_id = request.leader_id.context("invalid leader id")?.into();
2492    let follower_id = session.connection_id;
2493
2494    session
2495        .db()
2496        .await
2497        .check_room_participants(room_id, leader_id, session.connection_id)
2498        .await?;
2499
2500    session
2501        .peer
2502        .forward_send(session.connection_id, leader_id, request)?;
2503
2504    if let Some(project_id) = project_id {
2505        let room = session
2506            .db()
2507            .await
2508            .unfollow(room_id, project_id, leader_id, follower_id)
2509            .await?;
2510        room_updated(&room, &session.peer);
2511    }
2512
2513    Ok(())
2514}
2515
2516/// Notify everyone following you of your current location.
2517async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2518    let room_id = RoomId::from_proto(request.room_id);
2519    let database = session.db.lock().await;
2520
2521    let connection_ids = if let Some(project_id) = request.project_id {
2522        let project_id = ProjectId::from_proto(project_id);
2523        database
2524            .project_connection_ids(project_id, session.connection_id, true)
2525            .await?
2526    } else {
2527        database
2528            .room_connection_ids(room_id, session.connection_id)
2529            .await?
2530    };
2531
2532    // For now, don't send view update messages back to that view's current leader.
2533    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2534        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2535        _ => None,
2536    });
2537
2538    for connection_id in connection_ids.iter().cloned() {
2539        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2540            session
2541                .peer
2542                .forward_send(session.connection_id, connection_id, request.clone())?;
2543        }
2544    }
2545    Ok(())
2546}
2547
2548/// Get public data about users.
2549async fn get_users(
2550    request: proto::GetUsers,
2551    response: Response<proto::GetUsers>,
2552    session: MessageContext,
2553) -> Result<()> {
2554    let user_ids = request
2555        .user_ids
2556        .into_iter()
2557        .map(UserId::from_proto)
2558        .collect();
2559    let users = session
2560        .db()
2561        .await
2562        .get_users_by_ids(user_ids)
2563        .await?
2564        .into_iter()
2565        .map(|user| proto::User {
2566            id: user.id.to_proto(),
2567            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2568            github_login: user.github_login,
2569            name: user.name,
2570        })
2571        .collect();
2572    response.send(proto::UsersResponse { users })?;
2573    Ok(())
2574}
2575
2576/// Search for users (to invite) buy Github login
2577async fn fuzzy_search_users(
2578    request: proto::FuzzySearchUsers,
2579    response: Response<proto::FuzzySearchUsers>,
2580    session: MessageContext,
2581) -> Result<()> {
2582    let query = request.query;
2583    let users = match query.len() {
2584        0 => vec![],
2585        1 | 2 => session
2586            .db()
2587            .await
2588            .get_user_by_github_login(&query)
2589            .await?
2590            .into_iter()
2591            .collect(),
2592        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2593    };
2594    let users = users
2595        .into_iter()
2596        .filter(|user| user.id != session.user_id())
2597        .map(|user| proto::User {
2598            id: user.id.to_proto(),
2599            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2600            github_login: user.github_login,
2601            name: user.name,
2602        })
2603        .collect();
2604    response.send(proto::UsersResponse { users })?;
2605    Ok(())
2606}
2607
2608/// Send a contact request to another user.
2609async fn request_contact(
2610    request: proto::RequestContact,
2611    response: Response<proto::RequestContact>,
2612    session: MessageContext,
2613) -> Result<()> {
2614    let requester_id = session.user_id();
2615    let responder_id = UserId::from_proto(request.responder_id);
2616    if requester_id == responder_id {
2617        return Err(anyhow!("cannot add yourself as a contact"))?;
2618    }
2619
2620    let notifications = session
2621        .db()
2622        .await
2623        .send_contact_request(requester_id, responder_id)
2624        .await?;
2625
2626    // Update outgoing contact requests of requester
2627    let mut update = proto::UpdateContacts::default();
2628    update.outgoing_requests.push(responder_id.to_proto());
2629    for connection_id in session
2630        .connection_pool()
2631        .await
2632        .user_connection_ids(requester_id)
2633    {
2634        session.peer.send(connection_id, update.clone())?;
2635    }
2636
2637    // Update incoming contact requests of responder
2638    let mut update = proto::UpdateContacts::default();
2639    update
2640        .incoming_requests
2641        .push(proto::IncomingContactRequest {
2642            requester_id: requester_id.to_proto(),
2643        });
2644    let connection_pool = session.connection_pool().await;
2645    for connection_id in connection_pool.user_connection_ids(responder_id) {
2646        session.peer.send(connection_id, update.clone())?;
2647    }
2648
2649    send_notifications(&connection_pool, &session.peer, notifications);
2650
2651    response.send(proto::Ack {})?;
2652    Ok(())
2653}
2654
2655/// Accept or decline a contact request
2656async fn respond_to_contact_request(
2657    request: proto::RespondToContactRequest,
2658    response: Response<proto::RespondToContactRequest>,
2659    session: MessageContext,
2660) -> Result<()> {
2661    let responder_id = session.user_id();
2662    let requester_id = UserId::from_proto(request.requester_id);
2663    let db = session.db().await;
2664    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2665        db.dismiss_contact_notification(responder_id, requester_id)
2666            .await?;
2667    } else {
2668        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2669
2670        let notifications = db
2671            .respond_to_contact_request(responder_id, requester_id, accept)
2672            .await?;
2673        let requester_busy = db.is_user_busy(requester_id).await?;
2674        let responder_busy = db.is_user_busy(responder_id).await?;
2675
2676        let pool = session.connection_pool().await;
2677        // Update responder with new contact
2678        let mut update = proto::UpdateContacts::default();
2679        if accept {
2680            update
2681                .contacts
2682                .push(contact_for_user(requester_id, requester_busy, &pool));
2683        }
2684        update
2685            .remove_incoming_requests
2686            .push(requester_id.to_proto());
2687        for connection_id in pool.user_connection_ids(responder_id) {
2688            session.peer.send(connection_id, update.clone())?;
2689        }
2690
2691        // Update requester with new contact
2692        let mut update = proto::UpdateContacts::default();
2693        if accept {
2694            update
2695                .contacts
2696                .push(contact_for_user(responder_id, responder_busy, &pool));
2697        }
2698        update
2699            .remove_outgoing_requests
2700            .push(responder_id.to_proto());
2701
2702        for connection_id in pool.user_connection_ids(requester_id) {
2703            session.peer.send(connection_id, update.clone())?;
2704        }
2705
2706        send_notifications(&pool, &session.peer, notifications);
2707    }
2708
2709    response.send(proto::Ack {})?;
2710    Ok(())
2711}
2712
2713/// Remove a contact.
2714async fn remove_contact(
2715    request: proto::RemoveContact,
2716    response: Response<proto::RemoveContact>,
2717    session: MessageContext,
2718) -> Result<()> {
2719    let requester_id = session.user_id();
2720    let responder_id = UserId::from_proto(request.user_id);
2721    let db = session.db().await;
2722    let (contact_accepted, deleted_notification_id) =
2723        db.remove_contact(requester_id, responder_id).await?;
2724
2725    let pool = session.connection_pool().await;
2726    // Update outgoing contact requests of requester
2727    let mut update = proto::UpdateContacts::default();
2728    if contact_accepted {
2729        update.remove_contacts.push(responder_id.to_proto());
2730    } else {
2731        update
2732            .remove_outgoing_requests
2733            .push(responder_id.to_proto());
2734    }
2735    for connection_id in pool.user_connection_ids(requester_id) {
2736        session.peer.send(connection_id, update.clone())?;
2737    }
2738
2739    // Update incoming contact requests of responder
2740    let mut update = proto::UpdateContacts::default();
2741    if contact_accepted {
2742        update.remove_contacts.push(requester_id.to_proto());
2743    } else {
2744        update
2745            .remove_incoming_requests
2746            .push(requester_id.to_proto());
2747    }
2748    for connection_id in pool.user_connection_ids(responder_id) {
2749        session.peer.send(connection_id, update.clone())?;
2750        if let Some(notification_id) = deleted_notification_id {
2751            session.peer.send(
2752                connection_id,
2753                proto::DeleteNotification {
2754                    notification_id: notification_id.to_proto(),
2755                },
2756            )?;
2757        }
2758    }
2759
2760    response.send(proto::Ack {})?;
2761    Ok(())
2762}
2763
2764fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2765    version.0.minor < 139
2766}
2767
2768async fn subscribe_to_channels(
2769    _: proto::SubscribeToChannels,
2770    session: MessageContext,
2771) -> Result<()> {
2772    subscribe_user_to_channels(session.user_id(), &session).await?;
2773    Ok(())
2774}
2775
2776async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2777    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2778    let mut pool = session.connection_pool().await;
2779    for membership in &channels_for_user.channel_memberships {
2780        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2781    }
2782    session.peer.send(
2783        session.connection_id,
2784        build_update_user_channels(&channels_for_user),
2785    )?;
2786    session.peer.send(
2787        session.connection_id,
2788        build_channels_update(channels_for_user),
2789    )?;
2790    Ok(())
2791}
2792
2793/// Creates a new channel.
2794async fn create_channel(
2795    request: proto::CreateChannel,
2796    response: Response<proto::CreateChannel>,
2797    session: MessageContext,
2798) -> Result<()> {
2799    let db = session.db().await;
2800
2801    let parent_id = request.parent_id.map(ChannelId::from_proto);
2802    let (channel, membership) = db
2803        .create_channel(&request.name, parent_id, session.user_id())
2804        .await?;
2805
2806    let root_id = channel.root_id();
2807    let channel = Channel::from_model(channel);
2808
2809    response.send(proto::CreateChannelResponse {
2810        channel: Some(channel.to_proto()),
2811        parent_id: request.parent_id,
2812    })?;
2813
2814    let mut connection_pool = session.connection_pool().await;
2815    if let Some(membership) = membership {
2816        connection_pool.subscribe_to_channel(
2817            membership.user_id,
2818            membership.channel_id,
2819            membership.role,
2820        );
2821        let update = proto::UpdateUserChannels {
2822            channel_memberships: vec![proto::ChannelMembership {
2823                channel_id: membership.channel_id.to_proto(),
2824                role: membership.role.into(),
2825            }],
2826            ..Default::default()
2827        };
2828        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2829            session.peer.send(connection_id, update.clone())?;
2830        }
2831    }
2832
2833    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2834        if !role.can_see_channel(channel.visibility) {
2835            continue;
2836        }
2837
2838        let update = proto::UpdateChannels {
2839            channels: vec![channel.to_proto()],
2840            ..Default::default()
2841        };
2842        session.peer.send(connection_id, update.clone())?;
2843    }
2844
2845    Ok(())
2846}
2847
2848/// Delete a channel
2849async fn delete_channel(
2850    request: proto::DeleteChannel,
2851    response: Response<proto::DeleteChannel>,
2852    session: MessageContext,
2853) -> Result<()> {
2854    let db = session.db().await;
2855
2856    let channel_id = request.channel_id;
2857    let (root_channel, removed_channels) = db
2858        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2859        .await?;
2860    response.send(proto::Ack {})?;
2861
2862    // Notify members of removed channels
2863    let mut update = proto::UpdateChannels::default();
2864    update
2865        .delete_channels
2866        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2867
2868    let connection_pool = session.connection_pool().await;
2869    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2870        session.peer.send(connection_id, update.clone())?;
2871    }
2872
2873    Ok(())
2874}
2875
2876/// Invite someone to join a channel.
2877async fn invite_channel_member(
2878    request: proto::InviteChannelMember,
2879    response: Response<proto::InviteChannelMember>,
2880    session: MessageContext,
2881) -> Result<()> {
2882    let db = session.db().await;
2883    let channel_id = ChannelId::from_proto(request.channel_id);
2884    let invitee_id = UserId::from_proto(request.user_id);
2885    let InviteMemberResult {
2886        channel,
2887        notifications,
2888    } = db
2889        .invite_channel_member(
2890            channel_id,
2891            invitee_id,
2892            session.user_id(),
2893            request.role().into(),
2894        )
2895        .await?;
2896
2897    let update = proto::UpdateChannels {
2898        channel_invitations: vec![channel.to_proto()],
2899        ..Default::default()
2900    };
2901
2902    let connection_pool = session.connection_pool().await;
2903    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2904        session.peer.send(connection_id, update.clone())?;
2905    }
2906
2907    send_notifications(&connection_pool, &session.peer, notifications);
2908
2909    response.send(proto::Ack {})?;
2910    Ok(())
2911}
2912
2913/// remove someone from a channel
2914async fn remove_channel_member(
2915    request: proto::RemoveChannelMember,
2916    response: Response<proto::RemoveChannelMember>,
2917    session: MessageContext,
2918) -> Result<()> {
2919    let db = session.db().await;
2920    let channel_id = ChannelId::from_proto(request.channel_id);
2921    let member_id = UserId::from_proto(request.user_id);
2922
2923    let RemoveChannelMemberResult {
2924        membership_update,
2925        notification_id,
2926    } = db
2927        .remove_channel_member(channel_id, member_id, session.user_id())
2928        .await?;
2929
2930    let mut connection_pool = session.connection_pool().await;
2931    notify_membership_updated(
2932        &mut connection_pool,
2933        membership_update,
2934        member_id,
2935        &session.peer,
2936    );
2937    for connection_id in connection_pool.user_connection_ids(member_id) {
2938        if let Some(notification_id) = notification_id {
2939            session
2940                .peer
2941                .send(
2942                    connection_id,
2943                    proto::DeleteNotification {
2944                        notification_id: notification_id.to_proto(),
2945                    },
2946                )
2947                .trace_err();
2948        }
2949    }
2950
2951    response.send(proto::Ack {})?;
2952    Ok(())
2953}
2954
2955/// Toggle the channel between public and private.
2956/// Care is taken to maintain the invariant that public channels only descend from public channels,
2957/// (though members-only channels can appear at any point in the hierarchy).
2958async fn set_channel_visibility(
2959    request: proto::SetChannelVisibility,
2960    response: Response<proto::SetChannelVisibility>,
2961    session: MessageContext,
2962) -> Result<()> {
2963    let db = session.db().await;
2964    let channel_id = ChannelId::from_proto(request.channel_id);
2965    let visibility = request.visibility().into();
2966
2967    let channel_model = db
2968        .set_channel_visibility(channel_id, visibility, session.user_id())
2969        .await?;
2970    let root_id = channel_model.root_id();
2971    let channel = Channel::from_model(channel_model);
2972
2973    let mut connection_pool = session.connection_pool().await;
2974    for (user_id, role) in connection_pool
2975        .channel_user_ids(root_id)
2976        .collect::<Vec<_>>()
2977        .into_iter()
2978    {
2979        let update = if role.can_see_channel(channel.visibility) {
2980            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2981            proto::UpdateChannels {
2982                channels: vec![channel.to_proto()],
2983                ..Default::default()
2984            }
2985        } else {
2986            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2987            proto::UpdateChannels {
2988                delete_channels: vec![channel.id.to_proto()],
2989                ..Default::default()
2990            }
2991        };
2992
2993        for connection_id in connection_pool.user_connection_ids(user_id) {
2994            session.peer.send(connection_id, update.clone())?;
2995        }
2996    }
2997
2998    response.send(proto::Ack {})?;
2999    Ok(())
3000}
3001
3002/// Alter the role for a user in the channel.
3003async fn set_channel_member_role(
3004    request: proto::SetChannelMemberRole,
3005    response: Response<proto::SetChannelMemberRole>,
3006    session: MessageContext,
3007) -> Result<()> {
3008    let db = session.db().await;
3009    let channel_id = ChannelId::from_proto(request.channel_id);
3010    let member_id = UserId::from_proto(request.user_id);
3011    let result = db
3012        .set_channel_member_role(
3013            channel_id,
3014            session.user_id(),
3015            member_id,
3016            request.role().into(),
3017        )
3018        .await?;
3019
3020    match result {
3021        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3022            let mut connection_pool = session.connection_pool().await;
3023            notify_membership_updated(
3024                &mut connection_pool,
3025                membership_update,
3026                member_id,
3027                &session.peer,
3028            )
3029        }
3030        db::SetMemberRoleResult::InviteUpdated(channel) => {
3031            let update = proto::UpdateChannels {
3032                channel_invitations: vec![channel.to_proto()],
3033                ..Default::default()
3034            };
3035
3036            for connection_id in session
3037                .connection_pool()
3038                .await
3039                .user_connection_ids(member_id)
3040            {
3041                session.peer.send(connection_id, update.clone())?;
3042            }
3043        }
3044    }
3045
3046    response.send(proto::Ack {})?;
3047    Ok(())
3048}
3049
3050/// Change the name of a channel
3051async fn rename_channel(
3052    request: proto::RenameChannel,
3053    response: Response<proto::RenameChannel>,
3054    session: MessageContext,
3055) -> Result<()> {
3056    let db = session.db().await;
3057    let channel_id = ChannelId::from_proto(request.channel_id);
3058    let channel_model = db
3059        .rename_channel(channel_id, session.user_id(), &request.name)
3060        .await?;
3061    let root_id = channel_model.root_id();
3062    let channel = Channel::from_model(channel_model);
3063
3064    response.send(proto::RenameChannelResponse {
3065        channel: Some(channel.to_proto()),
3066    })?;
3067
3068    let connection_pool = session.connection_pool().await;
3069    let update = proto::UpdateChannels {
3070        channels: vec![channel.to_proto()],
3071        ..Default::default()
3072    };
3073    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3074        if role.can_see_channel(channel.visibility) {
3075            session.peer.send(connection_id, update.clone())?;
3076        }
3077    }
3078
3079    Ok(())
3080}
3081
3082/// Move a channel to a new parent.
3083async fn move_channel(
3084    request: proto::MoveChannel,
3085    response: Response<proto::MoveChannel>,
3086    session: MessageContext,
3087) -> Result<()> {
3088    let channel_id = ChannelId::from_proto(request.channel_id);
3089    let to = ChannelId::from_proto(request.to);
3090
3091    let (root_id, channels) = session
3092        .db()
3093        .await
3094        .move_channel(channel_id, to, session.user_id())
3095        .await?;
3096
3097    let connection_pool = session.connection_pool().await;
3098    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3099        let channels = channels
3100            .iter()
3101            .filter_map(|channel| {
3102                if role.can_see_channel(channel.visibility) {
3103                    Some(channel.to_proto())
3104                } else {
3105                    None
3106                }
3107            })
3108            .collect::<Vec<_>>();
3109        if channels.is_empty() {
3110            continue;
3111        }
3112
3113        let update = proto::UpdateChannels {
3114            channels,
3115            ..Default::default()
3116        };
3117
3118        session.peer.send(connection_id, update.clone())?;
3119    }
3120
3121    response.send(Ack {})?;
3122    Ok(())
3123}
3124
3125async fn reorder_channel(
3126    request: proto::ReorderChannel,
3127    response: Response<proto::ReorderChannel>,
3128    session: MessageContext,
3129) -> Result<()> {
3130    let channel_id = ChannelId::from_proto(request.channel_id);
3131    let direction = request.direction();
3132
3133    let updated_channels = session
3134        .db()
3135        .await
3136        .reorder_channel(channel_id, direction, session.user_id())
3137        .await?;
3138
3139    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3140        let connection_pool = session.connection_pool().await;
3141        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3142            let channels = updated_channels
3143                .iter()
3144                .filter_map(|channel| {
3145                    if role.can_see_channel(channel.visibility) {
3146                        Some(channel.to_proto())
3147                    } else {
3148                        None
3149                    }
3150                })
3151                .collect::<Vec<_>>();
3152
3153            if channels.is_empty() {
3154                continue;
3155            }
3156
3157            let update = proto::UpdateChannels {
3158                channels,
3159                ..Default::default()
3160            };
3161
3162            session.peer.send(connection_id, update.clone())?;
3163        }
3164    }
3165
3166    response.send(Ack {})?;
3167    Ok(())
3168}
3169
3170/// Get the list of channel members
3171async fn get_channel_members(
3172    request: proto::GetChannelMembers,
3173    response: Response<proto::GetChannelMembers>,
3174    session: MessageContext,
3175) -> Result<()> {
3176    let db = session.db().await;
3177    let channel_id = ChannelId::from_proto(request.channel_id);
3178    let limit = if request.limit == 0 {
3179        u16::MAX as u64
3180    } else {
3181        request.limit
3182    };
3183    let (members, users) = db
3184        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3185        .await?;
3186    response.send(proto::GetChannelMembersResponse { members, users })?;
3187    Ok(())
3188}
3189
3190/// Accept or decline a channel invitation.
3191async fn respond_to_channel_invite(
3192    request: proto::RespondToChannelInvite,
3193    response: Response<proto::RespondToChannelInvite>,
3194    session: MessageContext,
3195) -> Result<()> {
3196    let db = session.db().await;
3197    let channel_id = ChannelId::from_proto(request.channel_id);
3198    let RespondToChannelInvite {
3199        membership_update,
3200        notifications,
3201    } = db
3202        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3203        .await?;
3204
3205    let mut connection_pool = session.connection_pool().await;
3206    if let Some(membership_update) = membership_update {
3207        notify_membership_updated(
3208            &mut connection_pool,
3209            membership_update,
3210            session.user_id(),
3211            &session.peer,
3212        );
3213    } else {
3214        let update = proto::UpdateChannels {
3215            remove_channel_invitations: vec![channel_id.to_proto()],
3216            ..Default::default()
3217        };
3218
3219        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3220            session.peer.send(connection_id, update.clone())?;
3221        }
3222    };
3223
3224    send_notifications(&connection_pool, &session.peer, notifications);
3225
3226    response.send(proto::Ack {})?;
3227
3228    Ok(())
3229}
3230
3231/// Join the channels' room
3232async fn join_channel(
3233    request: proto::JoinChannel,
3234    response: Response<proto::JoinChannel>,
3235    session: MessageContext,
3236) -> Result<()> {
3237    let channel_id = ChannelId::from_proto(request.channel_id);
3238    join_channel_internal(channel_id, Box::new(response), session).await
3239}
3240
/// A response handle that can carry a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply through
/// either one.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The impls below name `Response::<...>::send` with a fully-qualified path,
// which presumably resolves to `Response`'s own (inherent) `send` rather than
// recursing into this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3254
/// Joins the room of `channel_id` on behalf of the session's user and replies
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, it is
///    cleaned up with `leave_room_for_session` first.
/// 2. The user joins the channel's room. The returned role selects the LiveKit
///    credentials:
///    - guests get a `guest_token` with `can_publish: false`;
///    - every other role gets a `room_token` with `can_publish: true`.
///    The info is `None` when no LiveKit client is configured or token
///    creation fails.
/// 3. The reply is sent, any membership change is propagated, and the room
///    update is broadcast with `room_updated`.
/// 4. After the database and connection-pool guards are released, the channel
///    update is sent with `channel_updated` and the user's contacts are
///    refreshed with `update_user_contacts`.
///
/// # Errors
/// Propagates database and send failures. Fails with "channel not returned"
/// when the joined room has no channel; note that the reply has already been
/// sent by the time that check runs.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // This block owns the database guard and the connection-pool guard, so
    // both are released before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database lock while leaving the stale room
            // (presumably `leave_room_for_session` needs it), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when LiveKit is not configured; a token error is traced and
        // also yields `None` (via `trace_err()?` inside the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3348
3349/// Start editing the channel notes
3350async fn join_channel_buffer(
3351    request: proto::JoinChannelBuffer,
3352    response: Response<proto::JoinChannelBuffer>,
3353    session: MessageContext,
3354) -> Result<()> {
3355    let db = session.db().await;
3356    let channel_id = ChannelId::from_proto(request.channel_id);
3357
3358    let open_response = db
3359        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3360        .await?;
3361
3362    let collaborators = open_response.collaborators.clone();
3363    response.send(open_response)?;
3364
3365    let update = UpdateChannelBufferCollaborators {
3366        channel_id: channel_id.to_proto(),
3367        collaborators: collaborators.clone(),
3368    };
3369    channel_buffer_updated(
3370        session.connection_id,
3371        collaborators
3372            .iter()
3373            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3374        &update,
3375        &session.peer,
3376    );
3377
3378    Ok(())
3379}
3380
3381/// Edit the channel notes
3382async fn update_channel_buffer(
3383    request: proto::UpdateChannelBuffer,
3384    session: MessageContext,
3385) -> Result<()> {
3386    let db = session.db().await;
3387    let channel_id = ChannelId::from_proto(request.channel_id);
3388
3389    let (collaborators, epoch, version) = db
3390        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3391        .await?;
3392
3393    channel_buffer_updated(
3394        session.connection_id,
3395        collaborators.clone(),
3396        &proto::UpdateChannelBuffer {
3397            channel_id: channel_id.to_proto(),
3398            operations: request.operations,
3399        },
3400        &session.peer,
3401    );
3402
3403    let pool = &*session.connection_pool().await;
3404
3405    let non_collaborators =
3406        pool.channel_connection_ids(channel_id)
3407            .filter_map(|(connection_id, _)| {
3408                if collaborators.contains(&connection_id) {
3409                    None
3410                } else {
3411                    Some(connection_id)
3412                }
3413            });
3414
3415    broadcast(None, non_collaborators, |peer_id| {
3416        session.peer.send(
3417            peer_id,
3418            proto::UpdateChannels {
3419                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3420                    channel_id: channel_id.to_proto(),
3421                    epoch: epoch as u64,
3422                    version: version.clone(),
3423                }],
3424                ..Default::default()
3425            },
3426        )
3427    });
3428
3429    Ok(())
3430}
3431
3432/// Rejoin the channel notes after a connection blip
3433async fn rejoin_channel_buffers(
3434    request: proto::RejoinChannelBuffers,
3435    response: Response<proto::RejoinChannelBuffers>,
3436    session: MessageContext,
3437) -> Result<()> {
3438    let db = session.db().await;
3439    let buffers = db
3440        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3441        .await?;
3442
3443    for rejoined_buffer in &buffers {
3444        let collaborators_to_notify = rejoined_buffer
3445            .buffer
3446            .collaborators
3447            .iter()
3448            .filter_map(|c| Some(c.peer_id?.into()));
3449        channel_buffer_updated(
3450            session.connection_id,
3451            collaborators_to_notify,
3452            &proto::UpdateChannelBufferCollaborators {
3453                channel_id: rejoined_buffer.buffer.channel_id,
3454                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3455            },
3456            &session.peer,
3457        );
3458    }
3459
3460    response.send(proto::RejoinChannelBuffersResponse {
3461        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3462    })?;
3463
3464    Ok(())
3465}
3466
3467/// Stop editing the channel notes
3468async fn leave_channel_buffer(
3469    request: proto::LeaveChannelBuffer,
3470    response: Response<proto::LeaveChannelBuffer>,
3471    session: MessageContext,
3472) -> Result<()> {
3473    let db = session.db().await;
3474    let channel_id = ChannelId::from_proto(request.channel_id);
3475
3476    let left_buffer = db
3477        .leave_channel_buffer(channel_id, session.connection_id)
3478        .await?;
3479
3480    response.send(Ack {})?;
3481
3482    channel_buffer_updated(
3483        session.connection_id,
3484        left_buffer.connections,
3485        &proto::UpdateChannelBufferCollaborators {
3486            channel_id: channel_id.to_proto(),
3487            collaborators: left_buffer.collaborators,
3488        },
3489        &session.peer,
3490    );
3491
3492    Ok(())
3493}
3494
3495fn channel_buffer_updated<T: EnvelopedMessage>(
3496    sender_id: ConnectionId,
3497    collaborators: impl IntoIterator<Item = ConnectionId>,
3498    message: &T,
3499    peer: &Peer,
3500) {
3501    broadcast(Some(sender_id), collaborators, |peer_id| {
3502        peer.send(peer_id, message.clone())
3503    });
3504}
3505
3506fn send_notifications(
3507    connection_pool: &ConnectionPool,
3508    peer: &Peer,
3509    notifications: db::NotificationBatch,
3510) {
3511    for (user_id, notification) in notifications {
3512        for connection_id in connection_pool.user_connection_ids(user_id) {
3513            if let Err(error) = peer.send(
3514                connection_id,
3515                proto::AddNotification {
3516                    notification: Some(notification.clone()),
3517                },
3518            ) {
3519                tracing::error!(
3520                    "failed to send notification to {:?} {}",
3521                    connection_id,
3522                    error
3523                );
3524            }
3525        }
3526    }
3527}
3528
3529/// Send a message to the channel
3530async fn send_channel_message(
3531    _request: proto::SendChannelMessage,
3532    _response: Response<proto::SendChannelMessage>,
3533    _session: MessageContext,
3534) -> Result<()> {
3535    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3536}
3537
3538/// Delete a channel message
3539async fn remove_channel_message(
3540    _request: proto::RemoveChannelMessage,
3541    _response: Response<proto::RemoveChannelMessage>,
3542    _session: MessageContext,
3543) -> Result<()> {
3544    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3545}
3546
3547async fn update_channel_message(
3548    _request: proto::UpdateChannelMessage,
3549    _response: Response<proto::UpdateChannelMessage>,
3550    _session: MessageContext,
3551) -> Result<()> {
3552    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3553}
3554
3555/// Mark a channel message as read
3556async fn acknowledge_channel_message(
3557    _request: proto::AckChannelMessage,
3558    _session: MessageContext,
3559) -> Result<()> {
3560    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3561}
3562
3563/// Mark a buffer version as synced
3564async fn acknowledge_buffer_version(
3565    request: proto::AckBufferOperation,
3566    session: MessageContext,
3567) -> Result<()> {
3568    let buffer_id = BufferId::from_proto(request.buffer_id);
3569    session
3570        .db()
3571        .await
3572        .observe_buffer_version(
3573            buffer_id,
3574            session.user_id(),
3575            request.epoch as i32,
3576            &request.version,
3577        )
3578        .await?;
3579    Ok(())
3580}
3581
3582/// Start receiving chat updates for a channel
3583async fn join_channel_chat(
3584    _request: proto::JoinChannelChat,
3585    _response: Response<proto::JoinChannelChat>,
3586    _session: MessageContext,
3587) -> Result<()> {
3588    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3589}
3590
3591/// Stop receiving chat updates for a channel
3592async fn leave_channel_chat(
3593    _request: proto::LeaveChannelChat,
3594    _session: MessageContext,
3595) -> Result<()> {
3596    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3597}
3598
3599/// Retrieve the chat history for a channel
3600async fn get_channel_messages(
3601    _request: proto::GetChannelMessages,
3602    _response: Response<proto::GetChannelMessages>,
3603    _session: MessageContext,
3604) -> Result<()> {
3605    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3606}
3607
3608/// Retrieve specific chat messages
3609async fn get_channel_messages_by_id(
3610    _request: proto::GetChannelMessagesById,
3611    _response: Response<proto::GetChannelMessagesById>,
3612    _session: MessageContext,
3613) -> Result<()> {
3614    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3615}
3616
3617/// Retrieve the current users notifications
3618async fn get_notifications(
3619    request: proto::GetNotifications,
3620    response: Response<proto::GetNotifications>,
3621    session: MessageContext,
3622) -> Result<()> {
3623    let notifications = session
3624        .db()
3625        .await
3626        .get_notifications(
3627            session.user_id(),
3628            NOTIFICATION_COUNT_PER_PAGE,
3629            request.before_id.map(db::NotificationId::from_proto),
3630        )
3631        .await?;
3632    response.send(proto::GetNotificationsResponse {
3633        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3634        notifications,
3635    })?;
3636    Ok(())
3637}
3638
3639/// Mark notifications as read
3640async fn mark_notification_as_read(
3641    request: proto::MarkNotificationRead,
3642    response: Response<proto::MarkNotificationRead>,
3643    session: MessageContext,
3644) -> Result<()> {
3645    let database = &session.db().await;
3646    let notifications = database
3647        .mark_notification_as_read_by_id(
3648            session.user_id(),
3649            NotificationId::from_proto(request.notification_id),
3650        )
3651        .await?;
3652    send_notifications(
3653        &*session.connection_pool().await,
3654        &session.peer,
3655        notifications,
3656    );
3657    response.send(proto::Ack {})?;
3658    Ok(())
3659}
3660
3661fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3662    let message = match message {
3663        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3664        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3665        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3666        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3667        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3668            code: frame.code.into(),
3669            reason: frame.reason.as_str().to_owned().into(),
3670        })),
3671        // We should never receive a frame while reading the message, according
3672        // to the `tungstenite` maintainers:
3673        //
3674        // > It cannot occur when you read messages from the WebSocket, but it
3675        // > can be used when you want to send the raw frames (e.g. you want to
3676        // > send the frames to the WebSocket without composing the full message first).
3677        // >
3678        // > — https://github.com/snapview/tungstenite-rs/issues/268
3679        TungsteniteMessage::Frame(_) => {
3680            bail!("received an unexpected frame while reading the message")
3681        }
3682    };
3683
3684    Ok(message)
3685}
3686
3687fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3688    match message {
3689        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3690        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3691        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3692        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3693        AxumMessage::Close(frame) => {
3694            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3695                code: frame.code.into(),
3696                reason: frame.reason.as_ref().into(),
3697            }))
3698        }
3699    }
3700}
3701
3702fn notify_membership_updated(
3703    connection_pool: &mut ConnectionPool,
3704    result: MembershipUpdated,
3705    user_id: UserId,
3706    peer: &Peer,
3707) {
3708    for membership in &result.new_channels.channel_memberships {
3709        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3710    }
3711    for channel_id in &result.removed_channels {
3712        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3713    }
3714
3715    let user_channels_update = proto::UpdateUserChannels {
3716        channel_memberships: result
3717            .new_channels
3718            .channel_memberships
3719            .iter()
3720            .map(|cm| proto::ChannelMembership {
3721                channel_id: cm.channel_id.to_proto(),
3722                role: cm.role.into(),
3723            })
3724            .collect(),
3725        ..Default::default()
3726    };
3727
3728    let mut update = build_channels_update(result.new_channels);
3729    update.delete_channels = result
3730        .removed_channels
3731        .into_iter()
3732        .map(|id| id.to_proto())
3733        .collect();
3734    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3735
3736    for connection_id in connection_pool.user_connection_ids(user_id) {
3737        peer.send(connection_id, user_channels_update.clone())
3738            .trace_err();
3739        peer.send(connection_id, update.clone()).trace_err();
3740    }
3741}
3742
3743fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3744    proto::UpdateUserChannels {
3745        channel_memberships: channels
3746            .channel_memberships
3747            .iter()
3748            .map(|m| proto::ChannelMembership {
3749                channel_id: m.channel_id.to_proto(),
3750                role: m.role.into(),
3751            })
3752            .collect(),
3753        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3754    }
3755}
3756
3757fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3758    let mut update = proto::UpdateChannels::default();
3759
3760    for channel in channels.channels {
3761        update.channels.push(channel.to_proto());
3762    }
3763
3764    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3765
3766    for (channel_id, participants) in channels.channel_participants {
3767        update
3768            .channel_participants
3769            .push(proto::ChannelParticipants {
3770                channel_id: channel_id.to_proto(),
3771                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3772            });
3773    }
3774
3775    for channel in channels.invited_channels {
3776        update.channel_invitations.push(channel.to_proto());
3777    }
3778
3779    update
3780}
3781
3782fn build_initial_contacts_update(
3783    contacts: Vec<db::Contact>,
3784    pool: &ConnectionPool,
3785) -> proto::UpdateContacts {
3786    let mut update = proto::UpdateContacts::default();
3787
3788    for contact in contacts {
3789        match contact {
3790            db::Contact::Accepted { user_id, busy } => {
3791                update.contacts.push(contact_for_user(user_id, busy, pool));
3792            }
3793            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3794            db::Contact::Incoming { user_id } => {
3795                update
3796                    .incoming_requests
3797                    .push(proto::IncomingContactRequest {
3798                        requester_id: user_id.to_proto(),
3799                    })
3800            }
3801        }
3802    }
3803
3804    update
3805}
3806
3807fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3808    proto::Contact {
3809        user_id: user_id.to_proto(),
3810        online: pool.is_user_online(user_id),
3811        busy,
3812    }
3813}
3814
3815fn room_updated(room: &proto::Room, peer: &Peer) {
3816    broadcast(
3817        None,
3818        room.participants
3819            .iter()
3820            .filter_map(|participant| Some(participant.peer_id?.into())),
3821        |peer_id| {
3822            peer.send(
3823                peer_id,
3824                proto::RoomUpdated {
3825                    room: Some(room.clone()),
3826                },
3827            )
3828        },
3829    );
3830}
3831
3832fn channel_updated(
3833    channel: &db::channel::Model,
3834    room: &proto::Room,
3835    peer: &Peer,
3836    pool: &ConnectionPool,
3837) {
3838    let participants = room
3839        .participants
3840        .iter()
3841        .map(|p| p.user_id)
3842        .collect::<Vec<_>>();
3843
3844    broadcast(
3845        None,
3846        pool.channel_connection_ids(channel.root_id())
3847            .filter_map(|(channel_id, role)| {
3848                role.can_see_channel(channel.visibility)
3849                    .then_some(channel_id)
3850            }),
3851        |peer_id| {
3852            peer.send(
3853                peer_id,
3854                proto::UpdateChannels {
3855                    channel_participants: vec![proto::ChannelParticipants {
3856                        channel_id: channel.id.to_proto(),
3857                        participant_user_ids: participants.clone(),
3858                    }],
3859                    ..Default::default()
3860                },
3861            )
3862        },
3863    );
3864}
3865
3866async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3867    let db = session.db().await;
3868
3869    let contacts = db.get_contacts(user_id).await?;
3870    let busy = db.is_user_busy(user_id).await?;
3871
3872    let pool = session.connection_pool().await;
3873    let updated_contact = contact_for_user(user_id, busy, &pool);
3874    for contact in contacts {
3875        if let db::Contact::Accepted {
3876            user_id: contact_user_id,
3877            ..
3878        } = contact
3879        {
3880            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3881                session
3882                    .peer
3883                    .send(
3884                        contact_conn_id,
3885                        proto::UpdateContacts {
3886                            contacts: vec![updated_contact.clone()],
3887                            remove_contacts: Default::default(),
3888                            incoming_requests: Default::default(),
3889                            remove_incoming_requests: Default::default(),
3890                            outgoing_requests: Default::default(),
3891                            remove_outgoing_requests: Default::default(),
3892                        },
3893                    )
3894                    .trace_err();
3895            }
3896        }
3897    }
3898    Ok(())
3899}
3900
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// Returns `Ok(())` without doing anything else if the database reports that the
/// connection was not in a room. Otherwise it:
/// - notifies the collaborators of every project the connection left (`project_left`),
/// - broadcasts the updated room to its participants and, if the room belongs to a
///   channel, the channel's new participant list,
/// - sends `CallCanceled` to every connection of each user whose pending call was canceled,
/// - refreshes the contact status of the session's user and of those callees,
/// - removes the session's user from the LiveKit room, and deletes that LiveKit room when
///   the database marked the room as deleted.
///
/// `connection_id` may differ from `session.connection_id`, e.g. when
/// `join_channel_internal` cleans up a stale connection.
///
/// # Errors
///
/// Propagates database errors from leaving the room and from refreshing contacts. LiveKit
/// failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here but assigned inside the `if let` below, so that `left_room` and the
    // temporary database guard in its condition go out of scope before the broadcasts
    // and awaits that follow. NOTE(review): presumably `left_room` holds a room lock —
    // confirm before restructuring this.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // NOTE(review): `project_left` reports `session.connection_id` as the departing
        // peer. When `connection_id` is a different (stale) connection, collaborators may
        // be told to remove the wrong peer — confirm this is intended.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below carries a
        // default (presumably empty) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`, which acquires
    // the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: errors are logged by `trace_err` and otherwise ignored.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3974
3975async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3976    let left_channel_buffers = session
3977        .db()
3978        .await
3979        .leave_channel_buffers(session.connection_id)
3980        .await?;
3981
3982    for left_buffer in left_channel_buffers {
3983        channel_buffer_updated(
3984            session.connection_id,
3985            left_buffer.connections,
3986            &proto::UpdateChannelBufferCollaborators {
3987                channel_id: left_buffer.channel_id.to_proto(),
3988                collaborators: left_buffer.collaborators,
3989            },
3990            &session.peer,
3991        );
3992    }
3993
3994    Ok(())
3995}
3996
3997fn project_left(project: &db::LeftProject, session: &Session) {
3998    for connection_id in &project.connection_ids {
3999        if project.should_unshare {
4000            session
4001                .peer
4002                .send(
4003                    *connection_id,
4004                    proto::UnshareProject {
4005                        project_id: project.id.to_proto(),
4006                    },
4007                )
4008                .trace_err();
4009        } else {
4010            session
4011                .peer
4012                .send(
4013                    *connection_id,
4014                    proto::RemoveProjectCollaborator {
4015                        project_id: project.id.to_proto(),
4016                        peer_id: Some(session.connection_id.into()),
4017                    },
4018                )
4019                .trace_err();
4020        }
4021    }
4022}
4023
4024async fn share_agent_thread(
4025    request: proto::ShareAgentThread,
4026    response: Response<proto::ShareAgentThread>,
4027    session: MessageContext,
4028) -> Result<()> {
4029    let user_id = session.user_id();
4030
4031    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4032        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4033
4034    session
4035        .db()
4036        .await
4037        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4038        .await?;
4039
4040    response.send(proto::Ack {})?;
4041
4042    Ok(())
4043}
4044
4045async fn get_shared_agent_thread(
4046    request: proto::GetSharedAgentThread,
4047    response: Response<proto::GetSharedAgentThread>,
4048    session: MessageContext,
4049) -> Result<()> {
4050    let share_id = SharedThreadId::from_proto(request.session_id)
4051        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4052
4053    let result = session.db().await.get_shared_thread(share_id).await?;
4054
4055    match result {
4056        Some((thread, username)) => {
4057            response.send(proto::GetSharedAgentThreadResponse {
4058                title: thread.title,
4059                thread_data: thread.data,
4060                sharer_username: username,
4061                created_at: thread.created_at.and_utc().to_rfc3339(),
4062            })?;
4063        }
4064        None => {
4065            return Err(anyhow!("Shared thread not found").into());
4066        }
4067    }
4068
4069    Ok(())
4070}
4071
/// Extension trait for consuming a fallible value while reporting the failure.
///
/// The `Result` implementation in this module logs the error through `tracing` at error
/// level and returns `None`.
pub trait ResultExt {
    /// Returns the success value, or `None` after reporting the error.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The success type produced by [`ResultExt::trace_err`].
    type Ok;
}
4077
4078impl<T, E> ResultExt for Result<T, E>
4079where
4080    E: std::fmt::Debug,
4081{
4082    type Ok = T;
4083
4084    #[track_caller]
4085    fn trace_err(self) -> Option<T> {
4086        match self {
4087            Ok(value) => Some(value),
4088            Err(error) => {
4089                tracing::error!("{:?}", error);
4090                None
4091            }
4092        }
4093    }
4094}