rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
/// Timeout related to client reconnection.
/// NOTE(review): not referenced in this part of the file; presumably how long a
/// disconnected client may take to reconnect — confirm at its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long the cleanup task spawned by `Server::start` sleeps before reclaiming
/// stale server resources.
// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size used when fetching notifications.
/// NOTE(review): not referenced in this part of the file — confirm at its call sites.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of simultaneously held `ConnectionGuard`s; `try_acquire` fails beyond it.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently held `ConnectionGuard`s: incremented by `try_acquire`,
/// decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names used to record per-message timings (see `add_handler` and
// `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`: takes the boxed
/// envelope, the sender's `Session`, and the message's tracing span, and returns
/// a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
/// The RPC server: owns the peer, the connection pool, the shared application
/// state, and the table of per-message-type handlers.
pub struct Server {
    /// This server's id. Read by `start`; replaced in place by `reset` under `test-support`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connection bookkeeping shared with sessions; cleared by `teardown`.
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by `reset`.
    /// NOTE(review): receivers are not visible here; presumably connection tasks watch it.
    teardown: watch::Sender<bool>,
}
 262
/// Lock guard for the shared [`ConnectionPool`], returned by `Session::connection_pool`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes this guard `!Send`. Holding it across an
    // `.await` therefore makes the enclosing future `!Send`, which the `Send`
    // bound on handler futures in `add_handler` rejects.
    _not_send: PhantomData<Rc<()>>,
}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 416            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 417            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 419            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 437            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
 439            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 440            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 441            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 442            .add_request_handler(forward_mutating_project_request::<proto::GitEditRef>)
 443            .add_request_handler(forward_mutating_project_request::<proto::GitRepairWorktrees>)
 444            .add_request_handler(disallow_guest_request::<proto::GitCreateArchiveCheckpoint>)
 445            .add_request_handler(disallow_guest_request::<proto::GitRestoreArchiveCheckpoint>)
 446            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 447            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 448            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 449            .add_request_handler(share_agent_thread)
 450            .add_request_handler(get_shared_agent_thread)
 451            .add_request_handler(forward_project_search_chunk);
 452
 453        Arc::new(server)
 454    }
 455
    /// Schedules cleanup of server resources that the database reports as stale
    /// relative to this server's id, in the current `zed_environment`.
    ///
    /// Spawns a detached task, instrumented with a `start server` span. The task
    /// first waits out [`CLEANUP_TIMEOUT`] and then:
    /// 1. deletes stale channel-chat participants;
    /// 2. for each stale channel buffer, clears stale collaborators and sends
    ///    `UpdateChannelBufferCollaborators` to the remaining connections;
    /// 3. for each stale room, clears stale participants and broadcasts the
    ///    room (and channel, if any) update. It also sends `CallCanceled` for
    ///    canceled calls and pushes refreshed contact status to the accepted
    ///    contacts of affected users. Finally, it deletes the LiveKit room if
    ///    the room has no participants left;
    /// 4. deletes stale channel-chat participants again, clears old worktree
    ///    entries, and deletes stale servers.
    ///
    /// Every database/network result in the task is passed through `trace_err`
    /// and otherwise ignored, so one failing step does not stop the others.
    /// This method itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        // Clone everything the detached task needs so it does not borrow `self`.
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): this call is repeated after the room cleanup below;
                // the second pass looks redundant unless new stale rows can appear
                // in between — confirm before removing either.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // A set, so each affected user is refreshed only once.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move out what the steps below need, since
                            // `refreshed_room` goes out of scope at the end of this block.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to
                        // their accepted contacts. The pool is locked only after
                        // both lookups have been awaited.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once no participants remain.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 626
 627    pub fn teardown(&self) {
 628        self.peer.teardown();
 629        self.connection_pool.lock().reset();
 630        let _ = self.teardown.send(true);
 631    }
 632
 633    #[cfg(feature = "test-support")]
 634    pub fn reset(&self, id: ServerId) {
 635        self.teardown();
 636        *self.id.lock() = id;
 637        self.peer.reset(id.0 as u32);
 638        let _ = self.teardown.send(false);
 639    }
 640
 641    #[cfg(feature = "test-support")]
 642    pub fn id(&self) -> ServerId {
 643        *self.id.lock()
 644    }
 645
 646    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 647    where
 648        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 649        Fut: 'static + Send + Future<Output = Result<()>>,
 650        M: EnvelopedMessage,
 651    {
 652        let prev_handler = self.handlers.insert(
 653            TypeId::of::<M>(),
 654            Box::new(move |envelope, session, span| {
 655                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 656                let received_at = envelope.received_at;
 657                tracing::info!("message received");
 658                let start_time = Instant::now();
 659                let future = (handler)(
 660                    *envelope,
 661                    MessageContext {
 662                        session,
 663                        span: span.clone(),
 664                    },
 665                );
 666                async move {
 667                    let result = future.await;
 668                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 669                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 670                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 671                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 672                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 673                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 674                    match result {
 675                        Err(error) => {
 676                            tracing::error!(?error, "error handling message")
 677                        }
 678                        Ok(()) => tracing::info!("finished handling message"),
 679                    }
 680                }
 681                .boxed()
 682            }),
 683        );
 684        if prev_handler.is_some() {
 685            panic!("registered a handler for the same message twice");
 686        }
 687        self
 688    }
 689
 690    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 691    where
 692        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 693        Fut: 'static + Send + Future<Output = Result<()>>,
 694        M: EnvelopedMessage,
 695    {
 696        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 697        self
 698    }
 699
 700    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 701    where
 702        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 703        Fut: Send + Future<Output = Result<()>>,
 704        M: RequestMessage,
 705    {
 706        let handler = Arc::new(handler);
 707        self.add_handler(move |envelope, session| {
 708            let receipt = envelope.receipt();
 709            let handler = handler.clone();
 710            async move {
 711                let peer = session.peer.clone();
 712                let responded = Arc::new(AtomicBool::default());
 713                let response = Response {
 714                    peer: peer.clone(),
 715                    responded: responded.clone(),
 716                    receipt,
 717                };
 718                match (handler)(envelope.payload, response, session).await {
 719                    Ok(()) => {
 720                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 721                            Ok(())
 722                        } else {
 723                            Err(anyhow!("handler did not send a response"))?
 724                        }
 725                    }
 726                    Err(error) => {
 727                        let proto_err = match &error {
 728                            Error::Internal(err) => err.to_proto(),
 729                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 730                        };
 731                        peer.respond_with_error(receipt, proto_err)?;
 732                        Err(error)
 733                    }
 734                }
 735            }
 736        })
 737    }
 738
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Span creation and the teardown subscription happen eagerly, when this is called.
    /// Everything else happens inside the returned future, which:
    /// 1. returns immediately if the teardown flag is already `true`;
    /// 2. registers `connection` with the peer and builds the [`Session`];
    /// 3. sends the initial client update, then drops `connection_guard`;
    /// 4. dispatches incoming messages to the registered handlers until the I/O future
    ///    completes, the incoming stream ends, or teardown is signalled;
    /// 5. calls `connection_lost`, except on teardown, which returns without it.
    ///
    /// `send_connection_id`, when provided, receives the [`ConnectionId`] assigned here.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly. It is checked once when the future starts and then raced in
        // the message loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The guard is held only through the handshake above. NOTE(review): presumably
            // it caps concurrent connection setups; see `ConnectionGuard::try_acquire`.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A permit is acquired before each message is read and released when that
            // message's handler future completes (foreground or background). So at most
            // MAX_CONCURRENT_HANDLERS handlers are in flight per connection, and reading
            // pauses while all permits are held.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased order: teardown, then I/O completion, then finished foreground
                // handlers, and only then reading a new message.
                futures::select_biased! {
                    // Teardown returns straight away and skips `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is built;
                            // afterwards the future is instrumented with it instead.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit moves into the handler future, so it is released
                                // only when that handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still pending; detached
            // background handlers are unaffected.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 909
 910    async fn send_initial_client_update(
 911        &self,
 912        connection_id: ConnectionId,
 913        zed_version: ZedVersion,
 914        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 915        session: &Session,
 916    ) -> Result<()> {
 917        self.peer.send(
 918            connection_id,
 919            proto::Hello {
 920                peer_id: Some(connection_id.into()),
 921            },
 922        )?;
 923        tracing::info!("sent hello message");
 924        if let Some(send_connection_id) = send_connection_id.take() {
 925            let _ = send_connection_id.send(connection_id);
 926        }
 927
 928        match &session.principal {
 929            Principal::User(user) => {
 930                if !user.connected_once {
 931                    self.peer.send(connection_id, proto::ShowContacts {})?;
 932                    self.app_state
 933                        .db
 934                        .set_user_connected_once(user.id, true)
 935                        .await?;
 936                }
 937
 938                let contacts = self.app_state.db.get_contacts(user.id).await?;
 939
 940                {
 941                    let mut pool = self.connection_pool.lock();
 942                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 943                    self.peer.send(
 944                        connection_id,
 945                        build_initial_contacts_update(contacts, &pool),
 946                    )?;
 947                }
 948
 949                if should_auto_subscribe_to_channels(&zed_version) {
 950                    subscribe_user_to_channels(user.id, session).await?;
 951                }
 952
 953                if let Some(incoming_call) =
 954                    self.app_state.db.incoming_call_for_user(user.id).await?
 955                {
 956                    self.peer.send(connection_id, incoming_call)?;
 957                }
 958
 959                update_user_contacts(user.id, session).await?;
 960            }
 961        }
 962
 963        Ok(())
 964    }
 965}
 966
/// Lets a `ConnectionPoolGuard` be read as a `ConnectionPool` by delegating to its
/// `guard` field.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
 974
/// Mutable access to the `ConnectionPool` behind the guard's `guard` field.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
 980
/// When the guard is released, validates the pool's invariants. This only happens in
/// builds with the `test-support` feature; otherwise dropping the guard does nothing extra.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
 987
 988fn broadcast<F>(
 989    sender_id: Option<ConnectionId>,
 990    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 991    mut f: F,
 992) where
 993    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 994{
 995    for receiver_id in receiver_ids {
 996        if Some(receiver_id) != sender_id
 997            && let Err(error) = f(receiver_id)
 998        {
 999            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1000        }
1001    }
1002}
1003
1004pub struct ProtocolVersion(u32);
1005
1006impl Header for ProtocolVersion {
1007    fn name() -> &'static HeaderName {
1008        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1009        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1010    }
1011
1012    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1013    where
1014        Self: Sized,
1015        I: Iterator<Item = &'i axum::http::HeaderValue>,
1016    {
1017        let version = values
1018            .next()
1019            .ok_or_else(axum::headers::Error::invalid)?
1020            .to_str()
1021            .map_err(|_| axum::headers::Error::invalid())?
1022            .parse()
1023            .map_err(|_| axum::headers::Error::invalid())?;
1024        Ok(Self(version))
1025    }
1026
1027    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1028        values.extend([self.0.to_string().parse().unwrap()]);
1029    }
1030}
1031
1032pub struct AppVersionHeader(Version);
1033impl Header for AppVersionHeader {
1034    fn name() -> &'static HeaderName {
1035        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1036        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1037    }
1038
1039    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1040    where
1041        Self: Sized,
1042        I: Iterator<Item = &'i axum::http::HeaderValue>,
1043    {
1044        let version = values
1045            .next()
1046            .ok_or_else(axum::headers::Error::invalid)?
1047            .to_str()
1048            .map_err(|_| axum::headers::Error::invalid())?
1049            .parse()
1050            .map_err(|_| axum::headers::Error::invalid())?;
1051        Ok(Self(version))
1052    }
1053
1054    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1055        values.extend([self.0.to_string().parse().unwrap()]);
1056    }
1057}
1058
1059#[derive(Debug)]
1060pub struct ReleaseChannelHeader(String);
1061
1062impl Header for ReleaseChannelHeader {
1063    fn name() -> &'static HeaderName {
1064        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1065        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1066    }
1067
1068    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069    where
1070        Self: Sized,
1071        I: Iterator<Item = &'i axum::http::HeaderValue>,
1072    {
1073        Ok(Self(
1074            values
1075                .next()
1076                .ok_or_else(axum::headers::Error::invalid)?
1077                .to_str()
1078                .map_err(|_| axum::headers::Error::invalid())?
1079                .to_owned(),
1080        ))
1081    }
1082
1083    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084        values.extend([self.0.parse().unwrap()]);
1085    }
1086}
1087
1088pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1089    Router::new()
1090        .route("/rpc", get(handle_websocket_request))
1091        .layer(
1092            ServiceBuilder::new()
1093                .layer(Extension(server.app_state.clone()))
1094                .layer(middleware::from_fn(auth::validate_header)),
1095        )
1096        .route("/metrics", get(handle_metrics))
1097        .layer(Extension(server))
1098}
1099
1100pub async fn handle_websocket_request(
1101    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1102    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1103    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1104    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1105    Extension(server): Extension<Arc<Server>>,
1106    Extension(principal): Extension<Principal>,
1107    user_agent: Option<TypedHeader<UserAgent>>,
1108    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1109    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1110    ws: WebSocketUpgrade,
1111) -> axum::response::Response {
1112    if protocol_version != rpc::PROTOCOL_VERSION {
1113        return (
1114            StatusCode::UPGRADE_REQUIRED,
1115            "client must be upgraded".to_string(),
1116        )
1117            .into_response();
1118    }
1119
1120    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1121        return (
1122            StatusCode::UPGRADE_REQUIRED,
1123            "no version header found".to_string(),
1124        )
1125            .into_response();
1126    };
1127
1128    let release_channel = release_channel_header.map(|header| header.0.0);
1129
1130    if !version.can_collaborate() {
1131        return (
1132            StatusCode::UPGRADE_REQUIRED,
1133            "client must be upgraded".to_string(),
1134        )
1135            .into_response();
1136    }
1137
1138    let socket_address = socket_address.to_string();
1139
1140    // Acquire connection guard before WebSocket upgrade
1141    let connection_guard = match ConnectionGuard::try_acquire() {
1142        Ok(guard) => guard,
1143        Err(()) => {
1144            return (
1145                StatusCode::SERVICE_UNAVAILABLE,
1146                "Too many concurrent connections",
1147            )
1148                .into_response();
1149        }
1150    };
1151
1152    ws.on_upgrade(move |socket| {
1153        let socket = socket
1154            .map_ok(to_tungstenite_message)
1155            .err_into()
1156            .with(|message| async move { to_axum_message(message) });
1157        let connection = Connection::new(Box::pin(socket));
1158        async move {
1159            server
1160                .handle_connection(
1161                    connection,
1162                    socket_address,
1163                    principal,
1164                    version,
1165                    release_channel,
1166                    user_agent.map(|header| header.to_string()),
1167                    country_code_header.map(|header| header.to_string()),
1168                    system_id_header.map(|header| header.to_string()),
1169                    None,
1170                    Executor::Production,
1171                    Some(connection_guard),
1172                )
1173                .await;
1174        }
1175    })
1176}
1177
1178pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1179    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1180    let connections_metric = CONNECTIONS_METRIC
1181        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1182
1183    let connections = server
1184        .connection_pool
1185        .lock()
1186        .connections()
1187        .filter(|connection| !connection.admin)
1188        .count();
1189    connections_metric.set(connections as _);
1190
1191    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1192    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1193        register_int_gauge!(
1194            "shared_projects",
1195            "number of open projects with one or more guests"
1196        )
1197        .unwrap()
1198    });
1199
1200    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1201    shared_projects_metric.set(shared_projects as _);
1202
1203    let encoder = prometheus::TextEncoder::new();
1204    let metric_families = prometheus::gather();
1205    let encoded_metrics = encoder
1206        .encode_to_string(&metric_families)
1207        .map_err(|err| anyhow!("{err}"))?;
1208    Ok(encoded_metrics)
1209}
1210
/// Cleans up after a connection has ended.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (an error here is returned);
/// - marks it lost in the database (errors are only traced).
///
/// It then waits `RECONNECT_TIMEOUT`. NOTE(review): presumably this is the window for the
/// client to come back via `rejoin_room`; confirm. If that wait finishes first, it:
/// - leaves the room and channel buffers for this session;
/// - declines any pending call if the user has no other connection online;
/// - refreshes the user's contacts.
///
/// If the teardown receiver changes first, all of that is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the timeout branch (full cleanup) wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call if no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip the per-user cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1257
1258/// Acknowledges a ping from a client, used to keep the connection alive.
1259async fn ping(
1260    _: proto::Ping,
1261    response: Response<proto::Ping>,
1262    _session: MessageContext,
1263) -> Result<()> {
1264    response.send(proto::Ack {})?;
1265    Ok(())
1266}
1267
1268/// Creates a new room for calling (outside of channels)
1269async fn create_room(
1270    _request: proto::CreateRoom,
1271    response: Response<proto::CreateRoom>,
1272    session: MessageContext,
1273) -> Result<()> {
1274    let livekit_room = nanoid::nanoid!(30);
1275
1276    let live_kit_connection_info = util::maybe!(async {
1277        let live_kit = session.app_state.livekit_client.as_ref();
1278        let live_kit = live_kit?;
1279        let user_id = session.user_id().to_string();
1280
1281        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1282
1283        Some(proto::LiveKitConnectionInfo {
1284            server_url: live_kit.url().into(),
1285            token,
1286            can_publish: true,
1287        })
1288    })
1289    .await;
1290
1291    let room = session
1292        .db()
1293        .await
1294        .create_room(session.user_id(), session.connection_id, &livekit_room)
1295        .await?;
1296
1297    response.send(proto::CreateRoomResponse {
1298        room: Some(room.clone()),
1299        live_kit_connection_info,
1300    })?;
1301
1302    update_user_contacts(session.user_id(), &session).await?;
1303    Ok(())
1304}
1305
1306/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1307async fn join_room(
1308    request: proto::JoinRoom,
1309    response: Response<proto::JoinRoom>,
1310    session: MessageContext,
1311) -> Result<()> {
1312    let room_id = RoomId::from_proto(request.id);
1313
1314    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1315
1316    if let Some(channel_id) = channel_id {
1317        return join_channel_internal(channel_id, Box::new(response), session).await;
1318    }
1319
1320    let joined_room = {
1321        let room = session
1322            .db()
1323            .await
1324            .join_room(room_id, session.user_id(), session.connection_id)
1325            .await?;
1326        room_updated(&room.room, &session.peer);
1327        room.into_inner()
1328    };
1329
1330    for connection_id in session
1331        .connection_pool()
1332        .await
1333        .user_connection_ids(session.user_id())
1334    {
1335        session
1336            .peer
1337            .send(
1338                connection_id,
1339                proto::CallCanceled {
1340                    room_id: room_id.to_proto(),
1341                },
1342            )
1343            .trace_err();
1344    }
1345
1346    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1347    {
1348        live_kit
1349            .room_token(
1350                &joined_room.room.livekit_room,
1351                &session.user_id().to_string(),
1352            )
1353            .trace_err()
1354            .map(|token| proto::LiveKitConnectionInfo {
1355                server_url: live_kit.url().into(),
1356                token,
1357                can_publish: true,
1358            })
1359    } else {
1360        None
1361    };
1362
1363    response.send(proto::JoinRoomResponse {
1364        room: Some(joined_room.room),
1365        channel_id: None,
1366        live_kit_connection_info,
1367    })?;
1368
1369    update_user_contacts(session.user_id(), &session).await?;
1370    Ok(())
1371}
1372
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Replies with the room plus the reshared and rejoined projects, then broadcasts the
/// room update. It also tells project collaborators about this connection's new peer id
/// and resends project state. Afterwards it broadcasts a channel update when the room
/// belongs to a channel, and refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Assigned inside the scope below and used after it.
    let room;
    let channel;
    {
        // Everything in this scope runs before `rejoined_room.into_inner()` is called.
        // NOTE(review): presumably the value wraps a database guard on the room; confirm.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project:
        // - tell every collaborator that this user's peer id changed from the old
        //   connection to this one;
        // - forward the current worktree list to every collaborator except this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Notifies collaborators of rejoined projects and streams their worktree state back
        // to this connection. Takes `&mut` because the worktrees are moved out as they
        // are sent.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1464
1465fn notify_rejoined_projects(
1466    rejoined_projects: &mut Vec<RejoinedProject>,
1467    session: &Session,
1468) -> Result<()> {
1469    for project in rejoined_projects.iter() {
1470        for collaborator in &project.collaborators {
1471            session
1472                .peer
1473                .send(
1474                    collaborator.connection_id,
1475                    proto::UpdateProjectCollaborator {
1476                        project_id: project.id.to_proto(),
1477                        old_peer_id: Some(project.old_connection_id.into()),
1478                        new_peer_id: Some(session.connection_id.into()),
1479                    },
1480                )
1481                .trace_err();
1482        }
1483    }
1484
1485    for project in rejoined_projects {
1486        for worktree in mem::take(&mut project.worktrees) {
1487            // Stream this worktree's entries.
1488            let message = proto::UpdateWorktree {
1489                project_id: project.id.to_proto(),
1490                worktree_id: worktree.id,
1491                abs_path: worktree.abs_path.clone(),
1492                root_name: worktree.root_name,
1493                root_repo_common_dir: worktree.root_repo_common_dir,
1494                updated_entries: worktree.updated_entries,
1495                removed_entries: worktree.removed_entries,
1496                scan_id: worktree.scan_id,
1497                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1498                updated_repositories: worktree.updated_repositories,
1499                removed_repositories: worktree.removed_repositories,
1500            };
1501            for update in proto::split_worktree_update(message) {
1502                session.peer.send(session.connection_id, update)?;
1503            }
1504
1505            // Stream this worktree's diagnostics.
1506            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1507            if let Some(summary) = worktree_diagnostics.next() {
1508                let message = proto::UpdateDiagnosticSummary {
1509                    project_id: project.id.to_proto(),
1510                    worktree_id: worktree.id,
1511                    summary: Some(summary),
1512                    more_summaries: worktree_diagnostics.collect(),
1513                };
1514                session.peer.send(session.connection_id, message)?;
1515            }
1516
1517            for settings_file in worktree.settings_files {
1518                session.peer.send(
1519                    session.connection_id,
1520                    proto::UpdateWorktreeSettings {
1521                        project_id: project.id.to_proto(),
1522                        worktree_id: worktree.id,
1523                        path: settings_file.path,
1524                        content: Some(settings_file.content),
1525                        kind: Some(settings_file.kind.to_proto().into()),
1526                        outside_worktree: Some(settings_file.outside_worktree),
1527                    },
1528                )?;
1529            }
1530        }
1531
1532        for repository in mem::take(&mut project.updated_repositories) {
1533            for update in split_repository_update(repository) {
1534                session.peer.send(session.connection_id, update)?;
1535            }
1536        }
1537
1538        for id in mem::take(&mut project.removed_repositories) {
1539            session.peer.send(
1540                session.connection_id,
1541                proto::RemoveRepository {
1542                    project_id: project.id.to_proto(),
1543                    id,
1544                },
1545            )?;
1546        }
1547    }
1548
1549    Ok(())
1550}
1551
1552/// leave room disconnects from the room.
1553async fn leave_room(
1554    _: proto::LeaveRoom,
1555    response: Response<proto::LeaveRoom>,
1556    session: MessageContext,
1557) -> Result<()> {
1558    leave_room_for_session(&session, session.connection_id).await?;
1559    response.send(proto::Ack {})?;
1560    Ok(())
1561}
1562
1563/// Updates the permissions of someone else in the room.
1564async fn set_room_participant_role(
1565    request: proto::SetRoomParticipantRole,
1566    response: Response<proto::SetRoomParticipantRole>,
1567    session: MessageContext,
1568) -> Result<()> {
1569    let user_id = UserId::from_proto(request.user_id);
1570    let role = ChannelRole::from(request.role());
1571
1572    let (livekit_room, can_publish) = {
1573        let room = session
1574            .db()
1575            .await
1576            .set_room_participant_role(
1577                session.user_id(),
1578                RoomId::from_proto(request.room_id),
1579                user_id,
1580                role,
1581            )
1582            .await?;
1583
1584        let livekit_room = room.livekit_room.clone();
1585        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1586        room_updated(&room, &session.peer);
1587        (livekit_room, can_publish)
1588    };
1589
1590    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1591        live_kit
1592            .update_participant(
1593                livekit_room.clone(),
1594                request.user_id.to_string(),
1595                livekit_api::proto::ParticipantPermission {
1596                    can_subscribe: true,
1597                    can_publish,
1598                    can_publish_data: can_publish,
1599                    hidden: false,
1600                    recorder: false,
1601                },
1602            )
1603            .await
1604            .trace_err();
1605    }
1606
1607    response.send(proto::Ack {})?;
1608    Ok(())
1609}
1610
1611/// Call someone else into the current room
1612async fn call(
1613    request: proto::Call,
1614    response: Response<proto::Call>,
1615    session: MessageContext,
1616) -> Result<()> {
1617    let room_id = RoomId::from_proto(request.room_id);
1618    let calling_user_id = session.user_id();
1619    let calling_connection_id = session.connection_id;
1620    let called_user_id = UserId::from_proto(request.called_user_id);
1621    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1622    if !session
1623        .db()
1624        .await
1625        .has_contact(calling_user_id, called_user_id)
1626        .await?
1627    {
1628        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1629    }
1630
1631    let incoming_call = {
1632        let (room, incoming_call) = &mut *session
1633            .db()
1634            .await
1635            .call(
1636                room_id,
1637                calling_user_id,
1638                calling_connection_id,
1639                called_user_id,
1640                initial_project_id,
1641            )
1642            .await?;
1643        room_updated(room, &session.peer);
1644        mem::take(incoming_call)
1645    };
1646    update_user_contacts(called_user_id, &session).await?;
1647
1648    let mut calls = session
1649        .connection_pool()
1650        .await
1651        .user_connection_ids(called_user_id)
1652        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1653        .collect::<FuturesUnordered<_>>();
1654
1655    while let Some(call_response) = calls.next().await {
1656        match call_response.as_ref() {
1657            Ok(_) => {
1658                response.send(proto::Ack {})?;
1659                return Ok(());
1660            }
1661            Err(_) => {
1662                call_response.trace_err();
1663            }
1664        }
1665    }
1666
1667    {
1668        let room = session
1669            .db()
1670            .await
1671            .call_failed(room_id, called_user_id)
1672            .await?;
1673        room_updated(&room, &session.peer);
1674    }
1675    update_user_contacts(called_user_id, &session).await?;
1676
1677    Err(anyhow!("failed to ring user"))?
1678}
1679
1680/// Cancel an outgoing call.
1681async fn cancel_call(
1682    request: proto::CancelCall,
1683    response: Response<proto::CancelCall>,
1684    session: MessageContext,
1685) -> Result<()> {
1686    let called_user_id = UserId::from_proto(request.called_user_id);
1687    let room_id = RoomId::from_proto(request.room_id);
1688    {
1689        let room = session
1690            .db()
1691            .await
1692            .cancel_call(room_id, session.connection_id, called_user_id)
1693            .await?;
1694        room_updated(&room, &session.peer);
1695    }
1696
1697    for connection_id in session
1698        .connection_pool()
1699        .await
1700        .user_connection_ids(called_user_id)
1701    {
1702        session
1703            .peer
1704            .send(
1705                connection_id,
1706                proto::CallCanceled {
1707                    room_id: room_id.to_proto(),
1708                },
1709            )
1710            .trace_err();
1711    }
1712    response.send(proto::Ack {})?;
1713
1714    update_user_contacts(called_user_id, &session).await?;
1715    Ok(())
1716}
1717
1718/// Decline an incoming call.
1719async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1720    let room_id = RoomId::from_proto(message.room_id);
1721    {
1722        let room = session
1723            .db()
1724            .await
1725            .decline_call(Some(room_id), session.user_id())
1726            .await?
1727            .context("declining call")?;
1728        room_updated(&room, &session.peer);
1729    }
1730
1731    for connection_id in session
1732        .connection_pool()
1733        .await
1734        .user_connection_ids(session.user_id())
1735    {
1736        session
1737            .peer
1738            .send(
1739                connection_id,
1740                proto::CallCanceled {
1741                    room_id: room_id.to_proto(),
1742                },
1743            )
1744            .trace_err();
1745    }
1746    update_user_contacts(session.user_id(), &session).await?;
1747    Ok(())
1748}
1749
1750/// Updates other participants in the room with your current location.
1751async fn update_participant_location(
1752    request: proto::UpdateParticipantLocation,
1753    response: Response<proto::UpdateParticipantLocation>,
1754    session: MessageContext,
1755) -> Result<()> {
1756    let room_id = RoomId::from_proto(request.room_id);
1757    let location = request.location.context("invalid location")?;
1758
1759    let db = session.db().await;
1760    let room = db
1761        .update_room_participant_location(room_id, session.connection_id, location)
1762        .await?;
1763
1764    room_updated(&room, &session.peer);
1765    response.send(proto::Ack {})?;
1766    Ok(())
1767}
1768
1769/// Share a project into the room.
1770async fn share_project(
1771    request: proto::ShareProject,
1772    response: Response<proto::ShareProject>,
1773    session: MessageContext,
1774) -> Result<()> {
1775    let (project_id, room) = &*session
1776        .db()
1777        .await
1778        .share_project(
1779            RoomId::from_proto(request.room_id),
1780            session.connection_id,
1781            &request.worktrees,
1782            request.is_ssh_project,
1783            request.windows_paths.unwrap_or(false),
1784            &request.features,
1785        )
1786        .await?;
1787    response.send(proto::ShareProjectResponse {
1788        project_id: project_id.to_proto(),
1789    })?;
1790    room_updated(room, &session.peer);
1791
1792    Ok(())
1793}
1794
1795/// Unshare a project from the room.
1796async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1797    let project_id = ProjectId::from_proto(message.project_id);
1798    unshare_project_internal(project_id, session.connection_id, &session).await
1799}
1800
1801async fn unshare_project_internal(
1802    project_id: ProjectId,
1803    connection_id: ConnectionId,
1804    session: &Session,
1805) -> Result<()> {
1806    let delete = {
1807        let room_guard = session
1808            .db()
1809            .await
1810            .unshare_project(project_id, connection_id)
1811            .await?;
1812
1813        let (delete, room, guest_connection_ids) = &*room_guard;
1814
1815        let message = proto::UnshareProject {
1816            project_id: project_id.to_proto(),
1817        };
1818
1819        broadcast(
1820            Some(connection_id),
1821            guest_connection_ids.iter().copied(),
1822            |conn_id| session.peer.send(conn_id, message.clone()),
1823        );
1824        if let Some(room) = room {
1825            room_updated(room, &session.peer);
1826        }
1827
1828        *delete
1829    };
1830
1831    if delete {
1832        let db = session.db().await;
1833        db.delete_project(project_id).await?;
1834    }
1835
1836    Ok(())
1837}
1838
1839/// Join someone elses shared project.
1840async fn join_project(
1841    request: proto::JoinProject,
1842    response: Response<proto::JoinProject>,
1843    session: MessageContext,
1844) -> Result<()> {
1845    let project_id = ProjectId::from_proto(request.project_id);
1846
1847    tracing::info!(%project_id, "join project");
1848
1849    let db = session.db().await;
1850    let project_model = db.get_project(project_id).await?;
1851    let host_features: Vec<String> =
1852        serde_json::from_str(&project_model.features).unwrap_or_default();
1853    let guest_features: HashSet<_> = request.features.iter().collect();
1854    let host_features_set: HashSet<_> = host_features.iter().collect();
1855    if guest_features != host_features_set {
1856        let host_connection_id = project_model.host_connection()?;
1857        let mut pool = session.connection_pool().await;
1858        let host_version = pool
1859            .connection(host_connection_id)
1860            .map(|c| c.zed_version.to_string());
1861        let guest_version = pool
1862            .connection(session.connection_id)
1863            .map(|c| c.zed_version.to_string());
1864        drop(pool);
1865        Err(anyhow!(
1866            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1867            host_version.as_deref().unwrap_or("unknown"),
1868            guest_version.as_deref().unwrap_or("unknown"),
1869        ))?;
1870    }
1871
1872    let (project, replica_id) = &mut *db
1873        .join_project(
1874            project_id,
1875            session.connection_id,
1876            session.user_id(),
1877            request.committer_name.clone(),
1878            request.committer_email.clone(),
1879        )
1880        .await?;
1881    drop(db);
1882
1883    tracing::info!(%project_id, "join remote project");
1884    let collaborators = project
1885        .collaborators
1886        .iter()
1887        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1888        .map(|collaborator| collaborator.to_proto())
1889        .collect::<Vec<_>>();
1890    let project_id = project.id;
1891    let guest_user_id = session.user_id();
1892
1893    let worktrees = project
1894        .worktrees
1895        .iter()
1896        .map(|(id, worktree)| proto::WorktreeMetadata {
1897            id: *id,
1898            root_name: worktree.root_name.clone(),
1899            visible: worktree.visible,
1900            abs_path: worktree.abs_path.clone(),
1901            root_repo_common_dir: None,
1902        })
1903        .collect::<Vec<_>>();
1904
1905    let add_project_collaborator = proto::AddProjectCollaborator {
1906        project_id: project_id.to_proto(),
1907        collaborator: Some(proto::Collaborator {
1908            peer_id: Some(session.connection_id.into()),
1909            replica_id: replica_id.0 as u32,
1910            user_id: guest_user_id.to_proto(),
1911            is_host: false,
1912            committer_name: request.committer_name.clone(),
1913            committer_email: request.committer_email.clone(),
1914        }),
1915    };
1916
1917    for collaborator in &collaborators {
1918        session
1919            .peer
1920            .send(
1921                collaborator.peer_id.unwrap().into(),
1922                add_project_collaborator.clone(),
1923            )
1924            .trace_err();
1925    }
1926
1927    // First, we send the metadata associated with each worktree.
1928    let (language_servers, language_server_capabilities) = project
1929        .language_servers
1930        .clone()
1931        .into_iter()
1932        .map(|server| (server.server, server.capabilities))
1933        .unzip();
1934    response.send(proto::JoinProjectResponse {
1935        project_id: project.id.0 as u64,
1936        worktrees,
1937        replica_id: replica_id.0 as u32,
1938        collaborators,
1939        language_servers,
1940        language_server_capabilities,
1941        role: project.role.into(),
1942        windows_paths: project.path_style == PathStyle::Windows,
1943        features: project.features.clone(),
1944    })?;
1945
1946    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1947        // Stream this worktree's entries.
1948        let message = proto::UpdateWorktree {
1949            project_id: project_id.to_proto(),
1950            worktree_id,
1951            abs_path: worktree.abs_path.clone(),
1952            root_name: worktree.root_name,
1953            root_repo_common_dir: worktree.root_repo_common_dir,
1954            updated_entries: worktree.entries,
1955            removed_entries: Default::default(),
1956            scan_id: worktree.scan_id,
1957            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1958            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1959            removed_repositories: Default::default(),
1960        };
1961        for update in proto::split_worktree_update(message) {
1962            session.peer.send(session.connection_id, update.clone())?;
1963        }
1964
1965        // Stream this worktree's diagnostics.
1966        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1967        if let Some(summary) = worktree_diagnostics.next() {
1968            let message = proto::UpdateDiagnosticSummary {
1969                project_id: project.id.to_proto(),
1970                worktree_id: worktree.id,
1971                summary: Some(summary),
1972                more_summaries: worktree_diagnostics.collect(),
1973            };
1974            session.peer.send(session.connection_id, message)?;
1975        }
1976
1977        for settings_file in worktree.settings_files {
1978            session.peer.send(
1979                session.connection_id,
1980                proto::UpdateWorktreeSettings {
1981                    project_id: project_id.to_proto(),
1982                    worktree_id: worktree.id,
1983                    path: settings_file.path,
1984                    content: Some(settings_file.content),
1985                    kind: Some(settings_file.kind.to_proto() as i32),
1986                    outside_worktree: Some(settings_file.outside_worktree),
1987                },
1988            )?;
1989        }
1990    }
1991
1992    for repository in mem::take(&mut project.repositories) {
1993        for update in split_repository_update(repository) {
1994            session.peer.send(session.connection_id, update)?;
1995        }
1996    }
1997
1998    for language_server in &project.language_servers {
1999        session.peer.send(
2000            session.connection_id,
2001            proto::UpdateLanguageServer {
2002                project_id: project_id.to_proto(),
2003                server_name: Some(language_server.server.name.clone()),
2004                language_server_id: language_server.server.id,
2005                variant: Some(
2006                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2007                        proto::LspDiskBasedDiagnosticsUpdated {},
2008                    ),
2009                ),
2010            },
2011        )?;
2012    }
2013
2014    Ok(())
2015}
2016
2017/// Leave someone elses shared project.
2018async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2019    let sender_id = session.connection_id;
2020    let project_id = ProjectId::from_proto(request.project_id);
2021    let db = session.db().await;
2022
2023    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2024    tracing::info!(
2025        %project_id,
2026        "leave project"
2027    );
2028
2029    project_left(project, &session);
2030    if let Some(room) = room {
2031        room_updated(room, &session.peer);
2032    }
2033
2034    Ok(())
2035}
2036
2037/// Updates other participants with changes to the project
2038async fn update_project(
2039    request: proto::UpdateProject,
2040    response: Response<proto::UpdateProject>,
2041    session: MessageContext,
2042) -> Result<()> {
2043    let project_id = ProjectId::from_proto(request.project_id);
2044    let (room, guest_connection_ids) = &*session
2045        .db()
2046        .await
2047        .update_project(project_id, session.connection_id, &request.worktrees)
2048        .await?;
2049    broadcast(
2050        Some(session.connection_id),
2051        guest_connection_ids.iter().copied(),
2052        |connection_id| {
2053            session
2054                .peer
2055                .forward_send(session.connection_id, connection_id, request.clone())
2056        },
2057    );
2058    if let Some(room) = room {
2059        room_updated(room, &session.peer);
2060    }
2061    response.send(proto::Ack {})?;
2062
2063    Ok(())
2064}
2065
2066/// Updates other participants with changes to the worktree
2067async fn update_worktree(
2068    request: proto::UpdateWorktree,
2069    response: Response<proto::UpdateWorktree>,
2070    session: MessageContext,
2071) -> Result<()> {
2072    let guest_connection_ids = session
2073        .db()
2074        .await
2075        .update_worktree(&request, session.connection_id)
2076        .await?;
2077
2078    broadcast(
2079        Some(session.connection_id),
2080        guest_connection_ids.iter().copied(),
2081        |connection_id| {
2082            session
2083                .peer
2084                .forward_send(session.connection_id, connection_id, request.clone())
2085        },
2086    );
2087    response.send(proto::Ack {})?;
2088    Ok(())
2089}
2090
2091async fn update_repository(
2092    request: proto::UpdateRepository,
2093    response: Response<proto::UpdateRepository>,
2094    session: MessageContext,
2095) -> Result<()> {
2096    let guest_connection_ids = session
2097        .db()
2098        .await
2099        .update_repository(&request, session.connection_id)
2100        .await?;
2101
2102    broadcast(
2103        Some(session.connection_id),
2104        guest_connection_ids.iter().copied(),
2105        |connection_id| {
2106            session
2107                .peer
2108                .forward_send(session.connection_id, connection_id, request.clone())
2109        },
2110    );
2111    response.send(proto::Ack {})?;
2112    Ok(())
2113}
2114
2115async fn remove_repository(
2116    request: proto::RemoveRepository,
2117    response: Response<proto::RemoveRepository>,
2118    session: MessageContext,
2119) -> Result<()> {
2120    let guest_connection_ids = session
2121        .db()
2122        .await
2123        .remove_repository(&request, session.connection_id)
2124        .await?;
2125
2126    broadcast(
2127        Some(session.connection_id),
2128        guest_connection_ids.iter().copied(),
2129        |connection_id| {
2130            session
2131                .peer
2132                .forward_send(session.connection_id, connection_id, request.clone())
2133        },
2134    );
2135    response.send(proto::Ack {})?;
2136    Ok(())
2137}
2138
2139/// Updates other participants with changes to the diagnostics
2140async fn update_diagnostic_summary(
2141    message: proto::UpdateDiagnosticSummary,
2142    session: MessageContext,
2143) -> Result<()> {
2144    let guest_connection_ids = session
2145        .db()
2146        .await
2147        .update_diagnostic_summary(&message, session.connection_id)
2148        .await?;
2149
2150    broadcast(
2151        Some(session.connection_id),
2152        guest_connection_ids.iter().copied(),
2153        |connection_id| {
2154            session
2155                .peer
2156                .forward_send(session.connection_id, connection_id, message.clone())
2157        },
2158    );
2159
2160    Ok(())
2161}
2162
2163/// Updates other participants with changes to the worktree settings
2164async fn update_worktree_settings(
2165    message: proto::UpdateWorktreeSettings,
2166    session: MessageContext,
2167) -> Result<()> {
2168    let guest_connection_ids = session
2169        .db()
2170        .await
2171        .update_worktree_settings(&message, session.connection_id)
2172        .await?;
2173
2174    broadcast(
2175        Some(session.connection_id),
2176        guest_connection_ids.iter().copied(),
2177        |connection_id| {
2178            session
2179                .peer
2180                .forward_send(session.connection_id, connection_id, message.clone())
2181        },
2182    );
2183
2184    Ok(())
2185}
2186
2187/// Notify other participants that a language server has started.
2188async fn start_language_server(
2189    request: proto::StartLanguageServer,
2190    session: MessageContext,
2191) -> Result<()> {
2192    let guest_connection_ids = session
2193        .db()
2194        .await
2195        .start_language_server(&request, session.connection_id)
2196        .await?;
2197
2198    broadcast(
2199        Some(session.connection_id),
2200        guest_connection_ids.iter().copied(),
2201        |connection_id| {
2202            session
2203                .peer
2204                .forward_send(session.connection_id, connection_id, request.clone())
2205        },
2206    );
2207    Ok(())
2208}
2209
2210/// Notify other participants that a language server has changed.
2211async fn update_language_server(
2212    request: proto::UpdateLanguageServer,
2213    session: MessageContext,
2214) -> Result<()> {
2215    let project_id = ProjectId::from_proto(request.project_id);
2216    let db = session.db().await;
2217
2218    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2219        && let Some(capabilities) = update.capabilities.clone()
2220    {
2221        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2222            .await?;
2223    }
2224
2225    let project_connection_ids = db
2226        .project_connection_ids(project_id, session.connection_id, true)
2227        .await?;
2228    broadcast(
2229        Some(session.connection_id),
2230        project_connection_ids.iter().copied(),
2231        |connection_id| {
2232            session
2233                .peer
2234                .forward_send(session.connection_id, connection_id, request.clone())
2235        },
2236    );
2237    Ok(())
2238}
2239
2240/// forward a project request to the host. These requests should be read only
2241/// as guests are allowed to send them.
2242async fn forward_read_only_project_request<T>(
2243    request: T,
2244    response: Response<T>,
2245    session: MessageContext,
2246) -> Result<()>
2247where
2248    T: EntityMessage + RequestMessage,
2249{
2250    let project_id = ProjectId::from_proto(request.remote_entity_id());
2251    let host_connection_id = session
2252        .db()
2253        .await
2254        .host_for_read_only_project_request(project_id, session.connection_id)
2255        .await?;
2256    let payload = session.forward_request(host_connection_id, request).await?;
2257    response.send(payload)?;
2258    Ok(())
2259}
2260
2261/// forward a project request to the host. These requests are disallowed
2262/// for guests.
2263async fn forward_mutating_project_request<T>(
2264    request: T,
2265    response: Response<T>,
2266    session: MessageContext,
2267) -> Result<()>
2268where
2269    T: EntityMessage + RequestMessage,
2270{
2271    let project_id = ProjectId::from_proto(request.remote_entity_id());
2272
2273    let host_connection_id = session
2274        .db()
2275        .await
2276        .host_for_mutating_project_request(project_id, session.connection_id)
2277        .await?;
2278    let payload = session.forward_request(host_connection_id, request).await?;
2279    response.send(payload)?;
2280    Ok(())
2281}
2282
2283async fn disallow_guest_request<T>(
2284    _request: T,
2285    response: Response<T>,
2286    _session: MessageContext,
2287) -> Result<()>
2288where
2289    T: RequestMessage,
2290{
2291    response.peer.respond_with_error(
2292        response.receipt,
2293        ErrorCode::Forbidden
2294            .message("request is not allowed for guests".to_string())
2295            .to_proto(),
2296    )?;
2297    response.responded.store(true, SeqCst);
2298    Ok(())
2299}
2300
2301async fn lsp_query(
2302    request: proto::LspQuery,
2303    response: Response<proto::LspQuery>,
2304    session: MessageContext,
2305) -> Result<()> {
2306    let (name, should_write) = request.query_name_and_write_permissions();
2307    tracing::Span::current().record("lsp_query_request", name);
2308    tracing::info!("lsp_query message received");
2309    if should_write {
2310        forward_mutating_project_request(request, response, session).await
2311    } else {
2312        forward_read_only_project_request(request, response, session).await
2313    }
2314}
2315
2316/// Notify other participants that a new buffer has been created
2317async fn create_buffer_for_peer(
2318    request: proto::CreateBufferForPeer,
2319    session: MessageContext,
2320) -> Result<()> {
2321    session
2322        .db()
2323        .await
2324        .check_user_is_project_host(
2325            ProjectId::from_proto(request.project_id),
2326            session.connection_id,
2327        )
2328        .await?;
2329    let peer_id = request.peer_id.context("invalid peer id")?;
2330    session
2331        .peer
2332        .forward_send(session.connection_id, peer_id.into(), request)?;
2333    Ok(())
2334}
2335
2336/// Notify other participants that a new image has been created
2337async fn create_image_for_peer(
2338    request: proto::CreateImageForPeer,
2339    session: MessageContext,
2340) -> Result<()> {
2341    session
2342        .db()
2343        .await
2344        .check_user_is_project_host(
2345            ProjectId::from_proto(request.project_id),
2346            session.connection_id,
2347        )
2348        .await?;
2349    let peer_id = request.peer_id.context("invalid peer id")?;
2350    session
2351        .peer
2352        .forward_send(session.connection_id, peer_id.into(), request)?;
2353    Ok(())
2354}
2355
2356/// Notify other participants that a buffer has been updated. This is
2357/// allowed for guests as long as the update is limited to selections.
2358async fn update_buffer(
2359    request: proto::UpdateBuffer,
2360    response: Response<proto::UpdateBuffer>,
2361    session: MessageContext,
2362) -> Result<()> {
2363    let project_id = ProjectId::from_proto(request.project_id);
2364    let mut capability = Capability::ReadOnly;
2365
2366    for op in request.operations.iter() {
2367        match op.variant {
2368            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2369            Some(_) => capability = Capability::ReadWrite,
2370        }
2371    }
2372
2373    let host = {
2374        let guard = session
2375            .db()
2376            .await
2377            .connections_for_buffer_update(project_id, session.connection_id, capability)
2378            .await?;
2379
2380        let (host, guests) = &*guard;
2381
2382        broadcast(
2383            Some(session.connection_id),
2384            guests.clone(),
2385            |connection_id| {
2386                session
2387                    .peer
2388                    .forward_send(session.connection_id, connection_id, request.clone())
2389            },
2390        );
2391
2392        *host
2393    };
2394
2395    if host != session.connection_id {
2396        session.forward_request(host, request.clone()).await?;
2397    }
2398
2399    response.send(proto::Ack {})?;
2400    Ok(())
2401}
2402
2403async fn forward_project_search_chunk(
2404    message: proto::FindSearchCandidatesChunk,
2405    response: Response<proto::FindSearchCandidatesChunk>,
2406    session: MessageContext,
2407) -> Result<()> {
2408    let peer_id = message.peer_id.context("missing peer_id")?;
2409    let payload = session
2410        .peer
2411        .forward_request(session.connection_id, peer_id.into(), message)
2412        .await?;
2413    response.send(payload)?;
2414    Ok(())
2415}
2416
2417/// Notify other participants that a project has been updated.
2418async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2419    request: T,
2420    session: MessageContext,
2421) -> Result<()> {
2422    let project_id = ProjectId::from_proto(request.remote_entity_id());
2423    let project_connection_ids = session
2424        .db()
2425        .await
2426        .project_connection_ids(project_id, session.connection_id, false)
2427        .await?;
2428
2429    broadcast(
2430        Some(session.connection_id),
2431        project_connection_ids.iter().copied(),
2432        |connection_id| {
2433            session
2434                .peer
2435                .forward_send(session.connection_id, connection_id, request.clone())
2436        },
2437    );
2438    Ok(())
2439}
2440
2441/// Start following another user in a call.
2442async fn follow(
2443    request: proto::Follow,
2444    response: Response<proto::Follow>,
2445    session: MessageContext,
2446) -> Result<()> {
2447    let room_id = RoomId::from_proto(request.room_id);
2448    let project_id = request.project_id.map(ProjectId::from_proto);
2449    let leader_id = request.leader_id.context("invalid leader id")?.into();
2450    let follower_id = session.connection_id;
2451
2452    session
2453        .db()
2454        .await
2455        .check_room_participants(room_id, leader_id, session.connection_id)
2456        .await?;
2457
2458    let response_payload = session.forward_request(leader_id, request).await?;
2459    response.send(response_payload)?;
2460
2461    if let Some(project_id) = project_id {
2462        let room = session
2463            .db()
2464            .await
2465            .follow(room_id, project_id, leader_id, follower_id)
2466            .await?;
2467        room_updated(&room, &session.peer);
2468    }
2469
2470    Ok(())
2471}
2472
2473/// Stop following another user in a call.
2474async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2475    let room_id = RoomId::from_proto(request.room_id);
2476    let project_id = request.project_id.map(ProjectId::from_proto);
2477    let leader_id = request.leader_id.context("invalid leader id")?.into();
2478    let follower_id = session.connection_id;
2479
2480    session
2481        .db()
2482        .await
2483        .check_room_participants(room_id, leader_id, session.connection_id)
2484        .await?;
2485
2486    session
2487        .peer
2488        .forward_send(session.connection_id, leader_id, request)?;
2489
2490    if let Some(project_id) = project_id {
2491        let room = session
2492            .db()
2493            .await
2494            .unfollow(room_id, project_id, leader_id, follower_id)
2495            .await?;
2496        room_updated(&room, &session.peer);
2497    }
2498
2499    Ok(())
2500}
2501
2502/// Notify everyone following you of your current location.
2503async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2504    let room_id = RoomId::from_proto(request.room_id);
2505    let database = session.db.lock().await;
2506
2507    let connection_ids = if let Some(project_id) = request.project_id {
2508        let project_id = ProjectId::from_proto(project_id);
2509        database
2510            .project_connection_ids(project_id, session.connection_id, true)
2511            .await?
2512    } else {
2513        database
2514            .room_connection_ids(room_id, session.connection_id)
2515            .await?
2516    };
2517
2518    // For now, don't send view update messages back to that view's current leader.
2519    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2520        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2521        _ => None,
2522    });
2523
2524    for connection_id in connection_ids.iter().cloned() {
2525        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2526            session
2527                .peer
2528                .forward_send(session.connection_id, connection_id, request.clone())?;
2529        }
2530    }
2531    Ok(())
2532}
2533
2534/// Get public data about users.
2535async fn get_users(
2536    request: proto::GetUsers,
2537    response: Response<proto::GetUsers>,
2538    session: MessageContext,
2539) -> Result<()> {
2540    let user_ids = request
2541        .user_ids
2542        .into_iter()
2543        .map(UserId::from_proto)
2544        .collect();
2545    let users = session
2546        .db()
2547        .await
2548        .get_users_by_ids(user_ids)
2549        .await?
2550        .into_iter()
2551        .map(|user| proto::User {
2552            id: user.id.to_proto(),
2553            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2554            github_login: user.github_login,
2555            name: user.name,
2556        })
2557        .collect();
2558    response.send(proto::UsersResponse { users })?;
2559    Ok(())
2560}
2561
2562/// Search for users (to invite) buy Github login
2563async fn fuzzy_search_users(
2564    request: proto::FuzzySearchUsers,
2565    response: Response<proto::FuzzySearchUsers>,
2566    session: MessageContext,
2567) -> Result<()> {
2568    let query = request.query;
2569    let users = match query.len() {
2570        0 => vec![],
2571        1 | 2 => session
2572            .db()
2573            .await
2574            .get_user_by_github_login(&query)
2575            .await?
2576            .into_iter()
2577            .collect(),
2578        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2579    };
2580    let users = users
2581        .into_iter()
2582        .filter(|user| user.id != session.user_id())
2583        .map(|user| proto::User {
2584            id: user.id.to_proto(),
2585            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2586            github_login: user.github_login,
2587            name: user.name,
2588        })
2589        .collect();
2590    response.send(proto::UsersResponse { users })?;
2591    Ok(())
2592}
2593
2594/// Send a contact request to another user.
2595async fn request_contact(
2596    request: proto::RequestContact,
2597    response: Response<proto::RequestContact>,
2598    session: MessageContext,
2599) -> Result<()> {
2600    let requester_id = session.user_id();
2601    let responder_id = UserId::from_proto(request.responder_id);
2602    if requester_id == responder_id {
2603        return Err(anyhow!("cannot add yourself as a contact"))?;
2604    }
2605
2606    let notifications = session
2607        .db()
2608        .await
2609        .send_contact_request(requester_id, responder_id)
2610        .await?;
2611
2612    // Update outgoing contact requests of requester
2613    let mut update = proto::UpdateContacts::default();
2614    update.outgoing_requests.push(responder_id.to_proto());
2615    for connection_id in session
2616        .connection_pool()
2617        .await
2618        .user_connection_ids(requester_id)
2619    {
2620        session.peer.send(connection_id, update.clone())?;
2621    }
2622
2623    // Update incoming contact requests of responder
2624    let mut update = proto::UpdateContacts::default();
2625    update
2626        .incoming_requests
2627        .push(proto::IncomingContactRequest {
2628            requester_id: requester_id.to_proto(),
2629        });
2630    let connection_pool = session.connection_pool().await;
2631    for connection_id in connection_pool.user_connection_ids(responder_id) {
2632        session.peer.send(connection_id, update.clone())?;
2633    }
2634
2635    send_notifications(&connection_pool, &session.peer, notifications);
2636
2637    response.send(proto::Ack {})?;
2638    Ok(())
2639}
2640
2641/// Accept or decline a contact request
2642async fn respond_to_contact_request(
2643    request: proto::RespondToContactRequest,
2644    response: Response<proto::RespondToContactRequest>,
2645    session: MessageContext,
2646) -> Result<()> {
2647    let responder_id = session.user_id();
2648    let requester_id = UserId::from_proto(request.requester_id);
2649    let db = session.db().await;
2650    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2651        db.dismiss_contact_notification(responder_id, requester_id)
2652            .await?;
2653    } else {
2654        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2655
2656        let notifications = db
2657            .respond_to_contact_request(responder_id, requester_id, accept)
2658            .await?;
2659        let requester_busy = db.is_user_busy(requester_id).await?;
2660        let responder_busy = db.is_user_busy(responder_id).await?;
2661
2662        let pool = session.connection_pool().await;
2663        // Update responder with new contact
2664        let mut update = proto::UpdateContacts::default();
2665        if accept {
2666            update
2667                .contacts
2668                .push(contact_for_user(requester_id, requester_busy, &pool));
2669        }
2670        update
2671            .remove_incoming_requests
2672            .push(requester_id.to_proto());
2673        for connection_id in pool.user_connection_ids(responder_id) {
2674            session.peer.send(connection_id, update.clone())?;
2675        }
2676
2677        // Update requester with new contact
2678        let mut update = proto::UpdateContacts::default();
2679        if accept {
2680            update
2681                .contacts
2682                .push(contact_for_user(responder_id, responder_busy, &pool));
2683        }
2684        update
2685            .remove_outgoing_requests
2686            .push(responder_id.to_proto());
2687
2688        for connection_id in pool.user_connection_ids(requester_id) {
2689            session.peer.send(connection_id, update.clone())?;
2690        }
2691
2692        send_notifications(&pool, &session.peer, notifications);
2693    }
2694
2695    response.send(proto::Ack {})?;
2696    Ok(())
2697}
2698
2699/// Remove a contact.
2700async fn remove_contact(
2701    request: proto::RemoveContact,
2702    response: Response<proto::RemoveContact>,
2703    session: MessageContext,
2704) -> Result<()> {
2705    let requester_id = session.user_id();
2706    let responder_id = UserId::from_proto(request.user_id);
2707    let db = session.db().await;
2708    let (contact_accepted, deleted_notification_id) =
2709        db.remove_contact(requester_id, responder_id).await?;
2710
2711    let pool = session.connection_pool().await;
2712    // Update outgoing contact requests of requester
2713    let mut update = proto::UpdateContacts::default();
2714    if contact_accepted {
2715        update.remove_contacts.push(responder_id.to_proto());
2716    } else {
2717        update
2718            .remove_outgoing_requests
2719            .push(responder_id.to_proto());
2720    }
2721    for connection_id in pool.user_connection_ids(requester_id) {
2722        session.peer.send(connection_id, update.clone())?;
2723    }
2724
2725    // Update incoming contact requests of responder
2726    let mut update = proto::UpdateContacts::default();
2727    if contact_accepted {
2728        update.remove_contacts.push(requester_id.to_proto());
2729    } else {
2730        update
2731            .remove_incoming_requests
2732            .push(requester_id.to_proto());
2733    }
2734    for connection_id in pool.user_connection_ids(responder_id) {
2735        session.peer.send(connection_id, update.clone())?;
2736        if let Some(notification_id) = deleted_notification_id {
2737            session.peer.send(
2738                connection_id,
2739                proto::DeleteNotification {
2740                    notification_id: notification_id.to_proto(),
2741                },
2742            )?;
2743        }
2744    }
2745
2746    response.send(proto::Ack {})?;
2747    Ok(())
2748}
2749
2750fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2751    version.0.minor < 139
2752}
2753
2754async fn subscribe_to_channels(
2755    _: proto::SubscribeToChannels,
2756    session: MessageContext,
2757) -> Result<()> {
2758    subscribe_user_to_channels(session.user_id(), &session).await?;
2759    Ok(())
2760}
2761
2762async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2763    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2764    let mut pool = session.connection_pool().await;
2765    for membership in &channels_for_user.channel_memberships {
2766        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2767    }
2768    session.peer.send(
2769        session.connection_id,
2770        build_update_user_channels(&channels_for_user),
2771    )?;
2772    session.peer.send(
2773        session.connection_id,
2774        build_channels_update(channels_for_user),
2775    )?;
2776    Ok(())
2777}
2778
2779/// Creates a new channel.
2780async fn create_channel(
2781    request: proto::CreateChannel,
2782    response: Response<proto::CreateChannel>,
2783    session: MessageContext,
2784) -> Result<()> {
2785    let db = session.db().await;
2786
2787    let parent_id = request.parent_id.map(ChannelId::from_proto);
2788    let (channel, membership) = db
2789        .create_channel(&request.name, parent_id, session.user_id())
2790        .await?;
2791
2792    let root_id = channel.root_id();
2793    let channel = Channel::from_model(channel);
2794
2795    response.send(proto::CreateChannelResponse {
2796        channel: Some(channel.to_proto()),
2797        parent_id: request.parent_id,
2798    })?;
2799
2800    let mut connection_pool = session.connection_pool().await;
2801    if let Some(membership) = membership {
2802        connection_pool.subscribe_to_channel(
2803            membership.user_id,
2804            membership.channel_id,
2805            membership.role,
2806        );
2807        let update = proto::UpdateUserChannels {
2808            channel_memberships: vec![proto::ChannelMembership {
2809                channel_id: membership.channel_id.to_proto(),
2810                role: membership.role.into(),
2811            }],
2812            ..Default::default()
2813        };
2814        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2815            session.peer.send(connection_id, update.clone())?;
2816        }
2817    }
2818
2819    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2820        if !role.can_see_channel(channel.visibility) {
2821            continue;
2822        }
2823
2824        let update = proto::UpdateChannels {
2825            channels: vec![channel.to_proto()],
2826            ..Default::default()
2827        };
2828        session.peer.send(connection_id, update.clone())?;
2829    }
2830
2831    Ok(())
2832}
2833
2834/// Delete a channel
2835async fn delete_channel(
2836    request: proto::DeleteChannel,
2837    response: Response<proto::DeleteChannel>,
2838    session: MessageContext,
2839) -> Result<()> {
2840    let db = session.db().await;
2841
2842    let channel_id = request.channel_id;
2843    let (root_channel, removed_channels) = db
2844        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2845        .await?;
2846    response.send(proto::Ack {})?;
2847
2848    // Notify members of removed channels
2849    let mut update = proto::UpdateChannels::default();
2850    update
2851        .delete_channels
2852        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2853
2854    let connection_pool = session.connection_pool().await;
2855    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2856        session.peer.send(connection_id, update.clone())?;
2857    }
2858
2859    Ok(())
2860}
2861
2862/// Invite someone to join a channel.
2863async fn invite_channel_member(
2864    request: proto::InviteChannelMember,
2865    response: Response<proto::InviteChannelMember>,
2866    session: MessageContext,
2867) -> Result<()> {
2868    let db = session.db().await;
2869    let channel_id = ChannelId::from_proto(request.channel_id);
2870    let invitee_id = UserId::from_proto(request.user_id);
2871    let InviteMemberResult {
2872        channel,
2873        notifications,
2874    } = db
2875        .invite_channel_member(
2876            channel_id,
2877            invitee_id,
2878            session.user_id(),
2879            request.role().into(),
2880        )
2881        .await?;
2882
2883    let update = proto::UpdateChannels {
2884        channel_invitations: vec![channel.to_proto()],
2885        ..Default::default()
2886    };
2887
2888    let connection_pool = session.connection_pool().await;
2889    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2890        session.peer.send(connection_id, update.clone())?;
2891    }
2892
2893    send_notifications(&connection_pool, &session.peer, notifications);
2894
2895    response.send(proto::Ack {})?;
2896    Ok(())
2897}
2898
2899/// remove someone from a channel
2900async fn remove_channel_member(
2901    request: proto::RemoveChannelMember,
2902    response: Response<proto::RemoveChannelMember>,
2903    session: MessageContext,
2904) -> Result<()> {
2905    let db = session.db().await;
2906    let channel_id = ChannelId::from_proto(request.channel_id);
2907    let member_id = UserId::from_proto(request.user_id);
2908
2909    let RemoveChannelMemberResult {
2910        membership_update,
2911        notification_id,
2912    } = db
2913        .remove_channel_member(channel_id, member_id, session.user_id())
2914        .await?;
2915
2916    let mut connection_pool = session.connection_pool().await;
2917    notify_membership_updated(
2918        &mut connection_pool,
2919        membership_update,
2920        member_id,
2921        &session.peer,
2922    );
2923    for connection_id in connection_pool.user_connection_ids(member_id) {
2924        if let Some(notification_id) = notification_id {
2925            session
2926                .peer
2927                .send(
2928                    connection_id,
2929                    proto::DeleteNotification {
2930                        notification_id: notification_id.to_proto(),
2931                    },
2932                )
2933                .trace_err();
2934        }
2935    }
2936
2937    response.send(proto::Ack {})?;
2938    Ok(())
2939}
2940
2941/// Toggle the channel between public and private.
2942/// Care is taken to maintain the invariant that public channels only descend from public channels,
2943/// (though members-only channels can appear at any point in the hierarchy).
2944async fn set_channel_visibility(
2945    request: proto::SetChannelVisibility,
2946    response: Response<proto::SetChannelVisibility>,
2947    session: MessageContext,
2948) -> Result<()> {
2949    let db = session.db().await;
2950    let channel_id = ChannelId::from_proto(request.channel_id);
2951    let visibility = request.visibility().into();
2952
2953    let channel_model = db
2954        .set_channel_visibility(channel_id, visibility, session.user_id())
2955        .await?;
2956    let root_id = channel_model.root_id();
2957    let channel = Channel::from_model(channel_model);
2958
2959    let mut connection_pool = session.connection_pool().await;
2960    for (user_id, role) in connection_pool
2961        .channel_user_ids(root_id)
2962        .collect::<Vec<_>>()
2963        .into_iter()
2964    {
2965        let update = if role.can_see_channel(channel.visibility) {
2966            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2967            proto::UpdateChannels {
2968                channels: vec![channel.to_proto()],
2969                ..Default::default()
2970            }
2971        } else {
2972            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2973            proto::UpdateChannels {
2974                delete_channels: vec![channel.id.to_proto()],
2975                ..Default::default()
2976            }
2977        };
2978
2979        for connection_id in connection_pool.user_connection_ids(user_id) {
2980            session.peer.send(connection_id, update.clone())?;
2981        }
2982    }
2983
2984    response.send(proto::Ack {})?;
2985    Ok(())
2986}
2987
2988/// Alter the role for a user in the channel.
2989async fn set_channel_member_role(
2990    request: proto::SetChannelMemberRole,
2991    response: Response<proto::SetChannelMemberRole>,
2992    session: MessageContext,
2993) -> Result<()> {
2994    let db = session.db().await;
2995    let channel_id = ChannelId::from_proto(request.channel_id);
2996    let member_id = UserId::from_proto(request.user_id);
2997    let result = db
2998        .set_channel_member_role(
2999            channel_id,
3000            session.user_id(),
3001            member_id,
3002            request.role().into(),
3003        )
3004        .await?;
3005
3006    match result {
3007        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3008            let mut connection_pool = session.connection_pool().await;
3009            notify_membership_updated(
3010                &mut connection_pool,
3011                membership_update,
3012                member_id,
3013                &session.peer,
3014            )
3015        }
3016        db::SetMemberRoleResult::InviteUpdated(channel) => {
3017            let update = proto::UpdateChannels {
3018                channel_invitations: vec![channel.to_proto()],
3019                ..Default::default()
3020            };
3021
3022            for connection_id in session
3023                .connection_pool()
3024                .await
3025                .user_connection_ids(member_id)
3026            {
3027                session.peer.send(connection_id, update.clone())?;
3028            }
3029        }
3030    }
3031
3032    response.send(proto::Ack {})?;
3033    Ok(())
3034}
3035
3036/// Change the name of a channel
3037async fn rename_channel(
3038    request: proto::RenameChannel,
3039    response: Response<proto::RenameChannel>,
3040    session: MessageContext,
3041) -> Result<()> {
3042    let db = session.db().await;
3043    let channel_id = ChannelId::from_proto(request.channel_id);
3044    let channel_model = db
3045        .rename_channel(channel_id, session.user_id(), &request.name)
3046        .await?;
3047    let root_id = channel_model.root_id();
3048    let channel = Channel::from_model(channel_model);
3049
3050    response.send(proto::RenameChannelResponse {
3051        channel: Some(channel.to_proto()),
3052    })?;
3053
3054    let connection_pool = session.connection_pool().await;
3055    let update = proto::UpdateChannels {
3056        channels: vec![channel.to_proto()],
3057        ..Default::default()
3058    };
3059    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3060        if role.can_see_channel(channel.visibility) {
3061            session.peer.send(connection_id, update.clone())?;
3062        }
3063    }
3064
3065    Ok(())
3066}
3067
3068/// Move a channel to a new parent.
3069async fn move_channel(
3070    request: proto::MoveChannel,
3071    response: Response<proto::MoveChannel>,
3072    session: MessageContext,
3073) -> Result<()> {
3074    let channel_id = ChannelId::from_proto(request.channel_id);
3075    let to = ChannelId::from_proto(request.to);
3076
3077    let (root_id, channels) = session
3078        .db()
3079        .await
3080        .move_channel(channel_id, to, session.user_id())
3081        .await?;
3082
3083    let connection_pool = session.connection_pool().await;
3084    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3085        let channels = channels
3086            .iter()
3087            .filter_map(|channel| {
3088                if role.can_see_channel(channel.visibility) {
3089                    Some(channel.to_proto())
3090                } else {
3091                    None
3092                }
3093            })
3094            .collect::<Vec<_>>();
3095        if channels.is_empty() {
3096            continue;
3097        }
3098
3099        let update = proto::UpdateChannels {
3100            channels,
3101            ..Default::default()
3102        };
3103
3104        session.peer.send(connection_id, update.clone())?;
3105    }
3106
3107    response.send(Ack {})?;
3108    Ok(())
3109}
3110
3111async fn reorder_channel(
3112    request: proto::ReorderChannel,
3113    response: Response<proto::ReorderChannel>,
3114    session: MessageContext,
3115) -> Result<()> {
3116    let channel_id = ChannelId::from_proto(request.channel_id);
3117    let direction = request.direction();
3118
3119    let updated_channels = session
3120        .db()
3121        .await
3122        .reorder_channel(channel_id, direction, session.user_id())
3123        .await?;
3124
3125    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3126        let connection_pool = session.connection_pool().await;
3127        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3128            let channels = updated_channels
3129                .iter()
3130                .filter_map(|channel| {
3131                    if role.can_see_channel(channel.visibility) {
3132                        Some(channel.to_proto())
3133                    } else {
3134                        None
3135                    }
3136                })
3137                .collect::<Vec<_>>();
3138
3139            if channels.is_empty() {
3140                continue;
3141            }
3142
3143            let update = proto::UpdateChannels {
3144                channels,
3145                ..Default::default()
3146            };
3147
3148            session.peer.send(connection_id, update.clone())?;
3149        }
3150    }
3151
3152    response.send(Ack {})?;
3153    Ok(())
3154}
3155
3156/// Get the list of channel members
3157async fn get_channel_members(
3158    request: proto::GetChannelMembers,
3159    response: Response<proto::GetChannelMembers>,
3160    session: MessageContext,
3161) -> Result<()> {
3162    let db = session.db().await;
3163    let channel_id = ChannelId::from_proto(request.channel_id);
3164    let limit = if request.limit == 0 {
3165        u16::MAX as u64
3166    } else {
3167        request.limit
3168    };
3169    let (members, users) = db
3170        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3171        .await?;
3172    response.send(proto::GetChannelMembersResponse { members, users })?;
3173    Ok(())
3174}
3175
3176/// Accept or decline a channel invitation.
3177async fn respond_to_channel_invite(
3178    request: proto::RespondToChannelInvite,
3179    response: Response<proto::RespondToChannelInvite>,
3180    session: MessageContext,
3181) -> Result<()> {
3182    let db = session.db().await;
3183    let channel_id = ChannelId::from_proto(request.channel_id);
3184    let RespondToChannelInvite {
3185        membership_update,
3186        notifications,
3187    } = db
3188        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3189        .await?;
3190
3191    let mut connection_pool = session.connection_pool().await;
3192    if let Some(membership_update) = membership_update {
3193        notify_membership_updated(
3194            &mut connection_pool,
3195            membership_update,
3196            session.user_id(),
3197            &session.peer,
3198        );
3199    } else {
3200        let update = proto::UpdateChannels {
3201            remove_channel_invitations: vec![channel_id.to_proto()],
3202            ..Default::default()
3203        };
3204
3205        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3206            session.peer.send(connection_id, update.clone())?;
3207        }
3208    };
3209
3210    send_notifications(&connection_pool, &session.peer, notifications);
3211
3212    response.send(proto::Ack {})?;
3213
3214    Ok(())
3215}
3216
3217/// Join the channels' room
3218async fn join_channel(
3219    request: proto::JoinChannel,
3220    response: Response<proto::JoinChannel>,
3221    session: MessageContext,
3222) -> Result<()> {
3223    let channel_id = ChannelId::from_proto(request.channel_id);
3224    join_channel_internal(channel_id, Box::new(response), session).await
3225}
3226
3227trait JoinChannelInternalResponse {
3228    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3229}
3230impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3231    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3232        Response::<proto::JoinChannel>::send(self, result)
3233    }
3234}
3235impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3236    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3237        Response::<proto::JoinRoom>::send(self, result)
3238    }
3239}
3240
/// Join the room belonging to `channel_id` on behalf of the session's user and
/// reply through `response` (a `JoinChannel` or `JoinRoom` response handle).
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first via `leave_room_for_session`.
/// 2. Join the channel's room in the database, which also yields an optional
///    membership update and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token: `Guest` gets a guest
///    token with `can_publish = false`; any other role gets a room token with
///    `can_publish = true`. A token error goes through `trace_err()` and leaves
///    `live_kit_connection_info` as `None` rather than failing the join.
/// 4. Send the `JoinRoomResponse`, propagate any membership change, and call
///    `room_updated` with the joined room.
/// 5. Call `channel_updated` for the channel and refresh the user's contacts.
///
/// # Errors
/// Returns an error if a database call, `leave_room_for_session`, sending the
/// response, or `update_user_contacts` fails, or if the joined room has no
/// associated channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard around `leave_room_for_session`
            // (presumably it takes the lock itself) and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; all other roles may.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops the database and connection-pool guards.
        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3334
3335/// Start editing the channel notes
3336async fn join_channel_buffer(
3337    request: proto::JoinChannelBuffer,
3338    response: Response<proto::JoinChannelBuffer>,
3339    session: MessageContext,
3340) -> Result<()> {
3341    let db = session.db().await;
3342    let channel_id = ChannelId::from_proto(request.channel_id);
3343
3344    let open_response = db
3345        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3346        .await?;
3347
3348    let collaborators = open_response.collaborators.clone();
3349    response.send(open_response)?;
3350
3351    let update = UpdateChannelBufferCollaborators {
3352        channel_id: channel_id.to_proto(),
3353        collaborators: collaborators.clone(),
3354    };
3355    channel_buffer_updated(
3356        session.connection_id,
3357        collaborators
3358            .iter()
3359            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3360        &update,
3361        &session.peer,
3362    );
3363
3364    Ok(())
3365}
3366
3367/// Edit the channel notes
3368async fn update_channel_buffer(
3369    request: proto::UpdateChannelBuffer,
3370    session: MessageContext,
3371) -> Result<()> {
3372    let db = session.db().await;
3373    let channel_id = ChannelId::from_proto(request.channel_id);
3374
3375    let (collaborators, epoch, version) = db
3376        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3377        .await?;
3378
3379    channel_buffer_updated(
3380        session.connection_id,
3381        collaborators.clone(),
3382        &proto::UpdateChannelBuffer {
3383            channel_id: channel_id.to_proto(),
3384            operations: request.operations,
3385        },
3386        &session.peer,
3387    );
3388
3389    let pool = &*session.connection_pool().await;
3390
3391    let non_collaborators =
3392        pool.channel_connection_ids(channel_id)
3393            .filter_map(|(connection_id, _)| {
3394                if collaborators.contains(&connection_id) {
3395                    None
3396                } else {
3397                    Some(connection_id)
3398                }
3399            });
3400
3401    broadcast(None, non_collaborators, |peer_id| {
3402        session.peer.send(
3403            peer_id,
3404            proto::UpdateChannels {
3405                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3406                    channel_id: channel_id.to_proto(),
3407                    epoch: epoch as u64,
3408                    version: version.clone(),
3409                }],
3410                ..Default::default()
3411            },
3412        )
3413    });
3414
3415    Ok(())
3416}
3417
3418/// Rejoin the channel notes after a connection blip
3419async fn rejoin_channel_buffers(
3420    request: proto::RejoinChannelBuffers,
3421    response: Response<proto::RejoinChannelBuffers>,
3422    session: MessageContext,
3423) -> Result<()> {
3424    let db = session.db().await;
3425    let buffers = db
3426        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3427        .await?;
3428
3429    for rejoined_buffer in &buffers {
3430        let collaborators_to_notify = rejoined_buffer
3431            .buffer
3432            .collaborators
3433            .iter()
3434            .filter_map(|c| Some(c.peer_id?.into()));
3435        channel_buffer_updated(
3436            session.connection_id,
3437            collaborators_to_notify,
3438            &proto::UpdateChannelBufferCollaborators {
3439                channel_id: rejoined_buffer.buffer.channel_id,
3440                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3441            },
3442            &session.peer,
3443        );
3444    }
3445
3446    response.send(proto::RejoinChannelBuffersResponse {
3447        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3448    })?;
3449
3450    Ok(())
3451}
3452
3453/// Stop editing the channel notes
3454async fn leave_channel_buffer(
3455    request: proto::LeaveChannelBuffer,
3456    response: Response<proto::LeaveChannelBuffer>,
3457    session: MessageContext,
3458) -> Result<()> {
3459    let db = session.db().await;
3460    let channel_id = ChannelId::from_proto(request.channel_id);
3461
3462    let left_buffer = db
3463        .leave_channel_buffer(channel_id, session.connection_id)
3464        .await?;
3465
3466    response.send(Ack {})?;
3467
3468    channel_buffer_updated(
3469        session.connection_id,
3470        left_buffer.connections,
3471        &proto::UpdateChannelBufferCollaborators {
3472            channel_id: channel_id.to_proto(),
3473            collaborators: left_buffer.collaborators,
3474        },
3475        &session.peer,
3476    );
3477
3478    Ok(())
3479}
3480
3481fn channel_buffer_updated<T: EnvelopedMessage>(
3482    sender_id: ConnectionId,
3483    collaborators: impl IntoIterator<Item = ConnectionId>,
3484    message: &T,
3485    peer: &Peer,
3486) {
3487    broadcast(Some(sender_id), collaborators, |peer_id| {
3488        peer.send(peer_id, message.clone())
3489    });
3490}
3491
3492fn send_notifications(
3493    connection_pool: &ConnectionPool,
3494    peer: &Peer,
3495    notifications: db::NotificationBatch,
3496) {
3497    for (user_id, notification) in notifications {
3498        for connection_id in connection_pool.user_connection_ids(user_id) {
3499            if let Err(error) = peer.send(
3500                connection_id,
3501                proto::AddNotification {
3502                    notification: Some(notification.clone()),
3503                },
3504            ) {
3505                tracing::error!(
3506                    "failed to send notification to {:?} {}",
3507                    connection_id,
3508                    error
3509                );
3510            }
3511        }
3512    }
3513}
3514
3515/// Send a message to the channel
3516async fn send_channel_message(
3517    _request: proto::SendChannelMessage,
3518    _response: Response<proto::SendChannelMessage>,
3519    _session: MessageContext,
3520) -> Result<()> {
3521    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3522}
3523
3524/// Delete a channel message
3525async fn remove_channel_message(
3526    _request: proto::RemoveChannelMessage,
3527    _response: Response<proto::RemoveChannelMessage>,
3528    _session: MessageContext,
3529) -> Result<()> {
3530    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3531}
3532
3533async fn update_channel_message(
3534    _request: proto::UpdateChannelMessage,
3535    _response: Response<proto::UpdateChannelMessage>,
3536    _session: MessageContext,
3537) -> Result<()> {
3538    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3539}
3540
3541/// Mark a channel message as read
3542async fn acknowledge_channel_message(
3543    _request: proto::AckChannelMessage,
3544    _session: MessageContext,
3545) -> Result<()> {
3546    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3547}
3548
3549/// Mark a buffer version as synced
3550async fn acknowledge_buffer_version(
3551    request: proto::AckBufferOperation,
3552    session: MessageContext,
3553) -> Result<()> {
3554    let buffer_id = BufferId::from_proto(request.buffer_id);
3555    session
3556        .db()
3557        .await
3558        .observe_buffer_version(
3559            buffer_id,
3560            session.user_id(),
3561            request.epoch as i32,
3562            &request.version,
3563        )
3564        .await?;
3565    Ok(())
3566}
3567
3568/// Start receiving chat updates for a channel
3569async fn join_channel_chat(
3570    _request: proto::JoinChannelChat,
3571    _response: Response<proto::JoinChannelChat>,
3572    _session: MessageContext,
3573) -> Result<()> {
3574    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3575}
3576
3577/// Stop receiving chat updates for a channel
3578async fn leave_channel_chat(
3579    _request: proto::LeaveChannelChat,
3580    _session: MessageContext,
3581) -> Result<()> {
3582    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3583}
3584
3585/// Retrieve the chat history for a channel
3586async fn get_channel_messages(
3587    _request: proto::GetChannelMessages,
3588    _response: Response<proto::GetChannelMessages>,
3589    _session: MessageContext,
3590) -> Result<()> {
3591    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3592}
3593
3594/// Retrieve specific chat messages
3595async fn get_channel_messages_by_id(
3596    _request: proto::GetChannelMessagesById,
3597    _response: Response<proto::GetChannelMessagesById>,
3598    _session: MessageContext,
3599) -> Result<()> {
3600    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3601}
3602
3603/// Retrieve the current users notifications
3604async fn get_notifications(
3605    request: proto::GetNotifications,
3606    response: Response<proto::GetNotifications>,
3607    session: MessageContext,
3608) -> Result<()> {
3609    let notifications = session
3610        .db()
3611        .await
3612        .get_notifications(
3613            session.user_id(),
3614            NOTIFICATION_COUNT_PER_PAGE,
3615            request.before_id.map(db::NotificationId::from_proto),
3616        )
3617        .await?;
3618    response.send(proto::GetNotificationsResponse {
3619        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3620        notifications,
3621    })?;
3622    Ok(())
3623}
3624
3625/// Mark notifications as read
3626async fn mark_notification_as_read(
3627    request: proto::MarkNotificationRead,
3628    response: Response<proto::MarkNotificationRead>,
3629    session: MessageContext,
3630) -> Result<()> {
3631    let database = &session.db().await;
3632    let notifications = database
3633        .mark_notification_as_read_by_id(
3634            session.user_id(),
3635            NotificationId::from_proto(request.notification_id),
3636        )
3637        .await?;
3638    send_notifications(
3639        &*session.connection_pool().await,
3640        &session.peer,
3641        notifications,
3642    );
3643    response.send(proto::Ack {})?;
3644    Ok(())
3645}
3646
3647fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3648    let message = match message {
3649        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3650        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3651        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3652        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3653        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3654            code: frame.code.into(),
3655            reason: frame.reason.as_str().to_owned().into(),
3656        })),
3657        // We should never receive a frame while reading the message, according
3658        // to the `tungstenite` maintainers:
3659        //
3660        // > It cannot occur when you read messages from the WebSocket, but it
3661        // > can be used when you want to send the raw frames (e.g. you want to
3662        // > send the frames to the WebSocket without composing the full message first).
3663        // >
3664        // > — https://github.com/snapview/tungstenite-rs/issues/268
3665        TungsteniteMessage::Frame(_) => {
3666            bail!("received an unexpected frame while reading the message")
3667        }
3668    };
3669
3670    Ok(message)
3671}
3672
3673fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3674    match message {
3675        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3676        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3677        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3678        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3679        AxumMessage::Close(frame) => {
3680            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3681                code: frame.code.into(),
3682                reason: frame.reason.as_ref().into(),
3683            }))
3684        }
3685    }
3686}
3687
3688fn notify_membership_updated(
3689    connection_pool: &mut ConnectionPool,
3690    result: MembershipUpdated,
3691    user_id: UserId,
3692    peer: &Peer,
3693) {
3694    for membership in &result.new_channels.channel_memberships {
3695        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3696    }
3697    for channel_id in &result.removed_channels {
3698        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3699    }
3700
3701    let user_channels_update = proto::UpdateUserChannels {
3702        channel_memberships: result
3703            .new_channels
3704            .channel_memberships
3705            .iter()
3706            .map(|cm| proto::ChannelMembership {
3707                channel_id: cm.channel_id.to_proto(),
3708                role: cm.role.into(),
3709            })
3710            .collect(),
3711        ..Default::default()
3712    };
3713
3714    let mut update = build_channels_update(result.new_channels);
3715    update.delete_channels = result
3716        .removed_channels
3717        .into_iter()
3718        .map(|id| id.to_proto())
3719        .collect();
3720    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3721
3722    for connection_id in connection_pool.user_connection_ids(user_id) {
3723        peer.send(connection_id, user_channels_update.clone())
3724            .trace_err();
3725        peer.send(connection_id, update.clone()).trace_err();
3726    }
3727}
3728
3729fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3730    proto::UpdateUserChannels {
3731        channel_memberships: channels
3732            .channel_memberships
3733            .iter()
3734            .map(|m| proto::ChannelMembership {
3735                channel_id: m.channel_id.to_proto(),
3736                role: m.role.into(),
3737            })
3738            .collect(),
3739        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3740    }
3741}
3742
3743fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3744    let mut update = proto::UpdateChannels::default();
3745
3746    for channel in channels.channels {
3747        update.channels.push(channel.to_proto());
3748    }
3749
3750    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3751
3752    for (channel_id, participants) in channels.channel_participants {
3753        update
3754            .channel_participants
3755            .push(proto::ChannelParticipants {
3756                channel_id: channel_id.to_proto(),
3757                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3758            });
3759    }
3760
3761    for channel in channels.invited_channels {
3762        update.channel_invitations.push(channel.to_proto());
3763    }
3764
3765    update
3766}
3767
3768fn build_initial_contacts_update(
3769    contacts: Vec<db::Contact>,
3770    pool: &ConnectionPool,
3771) -> proto::UpdateContacts {
3772    let mut update = proto::UpdateContacts::default();
3773
3774    for contact in contacts {
3775        match contact {
3776            db::Contact::Accepted { user_id, busy } => {
3777                update.contacts.push(contact_for_user(user_id, busy, pool));
3778            }
3779            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3780            db::Contact::Incoming { user_id } => {
3781                update
3782                    .incoming_requests
3783                    .push(proto::IncomingContactRequest {
3784                        requester_id: user_id.to_proto(),
3785                    })
3786            }
3787        }
3788    }
3789
3790    update
3791}
3792
3793fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3794    proto::Contact {
3795        user_id: user_id.to_proto(),
3796        online: pool.is_user_online(user_id),
3797        busy,
3798    }
3799}
3800
3801fn room_updated(room: &proto::Room, peer: &Peer) {
3802    broadcast(
3803        None,
3804        room.participants
3805            .iter()
3806            .filter_map(|participant| Some(participant.peer_id?.into())),
3807        |peer_id| {
3808            peer.send(
3809                peer_id,
3810                proto::RoomUpdated {
3811                    room: Some(room.clone()),
3812                },
3813            )
3814        },
3815    );
3816}
3817
3818fn channel_updated(
3819    channel: &db::channel::Model,
3820    room: &proto::Room,
3821    peer: &Peer,
3822    pool: &ConnectionPool,
3823) {
3824    let participants = room
3825        .participants
3826        .iter()
3827        .map(|p| p.user_id)
3828        .collect::<Vec<_>>();
3829
3830    broadcast(
3831        None,
3832        pool.channel_connection_ids(channel.root_id())
3833            .filter_map(|(channel_id, role)| {
3834                role.can_see_channel(channel.visibility)
3835                    .then_some(channel_id)
3836            }),
3837        |peer_id| {
3838            peer.send(
3839                peer_id,
3840                proto::UpdateChannels {
3841                    channel_participants: vec![proto::ChannelParticipants {
3842                        channel_id: channel.id.to_proto(),
3843                        participant_user_ids: participants.clone(),
3844                    }],
3845                    ..Default::default()
3846                },
3847            )
3848        },
3849    );
3850}
3851
3852async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3853    let db = session.db().await;
3854
3855    let contacts = db.get_contacts(user_id).await?;
3856    let busy = db.is_user_busy(user_id).await?;
3857
3858    let pool = session.connection_pool().await;
3859    let updated_contact = contact_for_user(user_id, busy, &pool);
3860    for contact in contacts {
3861        if let db::Contact::Accepted {
3862            user_id: contact_user_id,
3863            ..
3864        } = contact
3865        {
3866            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3867                session
3868                    .peer
3869                    .send(
3870                        contact_conn_id,
3871                        proto::UpdateContacts {
3872                            contacts: vec![updated_contact.clone()],
3873                            remove_contacts: Default::default(),
3874                            incoming_requests: Default::default(),
3875                            remove_incoming_requests: Default::default(),
3876                            outgoing_requests: Default::default(),
3877                            remove_outgoing_requests: Default::default(),
3878                        },
3879                    )
3880                    .trace_err();
3881            }
3882        }
3883    }
3884    Ok(())
3885}
3886
/// Removes `connection_id` from the room it is in (if any) and notifies everyone affected.
///
/// Returns `Ok(())` without doing anything else when the database reports that the
/// connection was not in a room. Otherwise, in order:
/// - notifies each left project's connections (`project_left`),
/// - broadcasts `RoomUpdated` to the room's participants,
/// - broadcasts the channel's participant list if the room belonged to a channel,
/// - sends `CallCanceled` to every connection of each user whose call was canceled,
/// - refreshes contact status for the session's user and for those users,
/// - if a LiveKit client is configured, removes the session's user from the LiveKit room and,
///   when the database marked the room as deleted, deletes the LiveKit room as well.
///
/// Message-send and LiveKit failures are only logged (`trace_err`); database and contact
/// update errors are returned.
///
/// NOTE(review): contacts and the LiveKit participant are keyed on `session.user_id()`, not
/// on the owner of `connection_id`; callers appear to pass a connection belonging to the
/// session's own user — confirm.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned in the `if let` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): taking `livekit_room` out first leaves `room.livekit_room` at its
        // default value, so the `RoomUpdated` broadcast below carries that default; confirm
        // clients do not rely on it.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before `update_user_contacts`,
    // which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3960
3961async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3962    let left_channel_buffers = session
3963        .db()
3964        .await
3965        .leave_channel_buffers(session.connection_id)
3966        .await?;
3967
3968    for left_buffer in left_channel_buffers {
3969        channel_buffer_updated(
3970            session.connection_id,
3971            left_buffer.connections,
3972            &proto::UpdateChannelBufferCollaborators {
3973                channel_id: left_buffer.channel_id.to_proto(),
3974                collaborators: left_buffer.collaborators,
3975            },
3976            &session.peer,
3977        );
3978    }
3979
3980    Ok(())
3981}
3982
3983fn project_left(project: &db::LeftProject, session: &Session) {
3984    for connection_id in &project.connection_ids {
3985        if project.should_unshare {
3986            session
3987                .peer
3988                .send(
3989                    *connection_id,
3990                    proto::UnshareProject {
3991                        project_id: project.id.to_proto(),
3992                    },
3993                )
3994                .trace_err();
3995        } else {
3996            session
3997                .peer
3998                .send(
3999                    *connection_id,
4000                    proto::RemoveProjectCollaborator {
4001                        project_id: project.id.to_proto(),
4002                        peer_id: Some(session.connection_id.into()),
4003                    },
4004                )
4005                .trace_err();
4006        }
4007    }
4008}
4009
4010async fn share_agent_thread(
4011    request: proto::ShareAgentThread,
4012    response: Response<proto::ShareAgentThread>,
4013    session: MessageContext,
4014) -> Result<()> {
4015    let user_id = session.user_id();
4016
4017    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4018        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4019
4020    session
4021        .db()
4022        .await
4023        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4024        .await?;
4025
4026    response.send(proto::Ack {})?;
4027
4028    Ok(())
4029}
4030
4031async fn get_shared_agent_thread(
4032    request: proto::GetSharedAgentThread,
4033    response: Response<proto::GetSharedAgentThread>,
4034    session: MessageContext,
4035) -> Result<()> {
4036    let share_id = SharedThreadId::from_proto(request.session_id)
4037        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4038
4039    let result = session.db().await.get_shared_thread(share_id).await?;
4040
4041    match result {
4042        Some((thread, username)) => {
4043            response.send(proto::GetSharedAgentThreadResponse {
4044                title: thread.title,
4045                thread_data: thread.data,
4046                sharer_username: username,
4047                created_at: thread.created_at.and_utc().to_rfc3339(),
4048            })?;
4049        }
4050        None => {
4051            return Err(anyhow!("Shared thread not found").into());
4052        }
4053    }
4054
4055    Ok(())
4056}
4057
/// Extension for results whose errors should be logged and then discarded.
pub trait ResultExt {
    /// The success type of the result.
    type Ok;

    /// Converts into an `Option` of the success value, logging the error if there was one.
    fn trace_err(self) -> Option<Self::Ok>;
}
4063
4064impl<T, E> ResultExt for Result<T, E>
4065where
4066    E: std::fmt::Debug,
4067{
4068    type Ok = T;
4069
4070    #[track_caller]
4071    fn trace_err(self) -> Option<T> {
4072        match self {
4073            Ok(value) => Some(value),
4074            Err(error) => {
4075                tracing::error!("{:?}", error);
4076                None
4077            }
4078        }
4079    }
4080}