rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
/// NOTE(review): presumably how long a disconnected client may take to reconnect before its
/// state is released; it is not referenced in this part of the file — confirm at its use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before the background task spawned by `Server::start` cleans up stale server resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up their old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// NOTE(review): presumably the page size used when fetching notifications; not referenced in
/// this part of the file — confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of `ConnectionGuard`s that `ConnectionGuard::try_acquire` hands out at once.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of connection slots currently reserved: incremented by
/// `ConnectionGuard::try_acquire`, decremented when a `ConnectionGuard` is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing span field names. The first three are recorded by `Server::add_handler` for every
// handled message; `HOST_WAITING_MS` is recorded by `MessageContext::forward_request`.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the envelope, the sender's `Session` and the message's tracing span, and
/// returns a boxed `Send` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
/// The collaboration RPC server: owns the peer, the connection pool and the registered
/// message handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of client connections; cleared by `teardown`.
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` publishes `true`, the test-only `reset` publishes `false`.
    teardown: watch::Sender<bool>,
}
 262
/// Lock guard over the connection pool, returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. It therefore cannot be held
/// across an `.await` inside a `Send` future such as the `BoxFuture`s that handlers return.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::TrashProjectEntry>)
 351            .add_request_handler(forward_mutating_project_request::<proto::RestoreProjectEntry>)
 352            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 353            .add_request_handler(
 354                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 355            )
 356            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 357            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 358            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 359            .add_request_handler(lsp_query)
 360            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 361            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 362            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 364            .add_message_handler(create_buffer_for_peer)
 365            .add_message_handler(create_image_for_peer)
 366            .add_request_handler(update_buffer)
 367            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 368            .add_message_handler(
 369                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 370            )
 371            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 374            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 375            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 376            .add_message_handler(
 377                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 378            )
 379            .add_request_handler(get_users)
 380            .add_request_handler(fuzzy_search_users)
 381            .add_request_handler(request_contact)
 382            .add_request_handler(remove_contact)
 383            .add_request_handler(respond_to_contact_request)
 384            .add_message_handler(subscribe_to_channels)
 385            .add_request_handler(create_channel)
 386            .add_request_handler(delete_channel)
 387            .add_request_handler(invite_channel_member)
 388            .add_request_handler(remove_channel_member)
 389            .add_request_handler(set_channel_member_role)
 390            .add_request_handler(set_channel_visibility)
 391            .add_request_handler(rename_channel)
 392            .add_request_handler(join_channel_buffer)
 393            .add_request_handler(leave_channel_buffer)
 394            .add_message_handler(update_channel_buffer)
 395            .add_request_handler(rejoin_channel_buffers)
 396            .add_request_handler(get_channel_members)
 397            .add_request_handler(respond_to_channel_invite)
 398            .add_request_handler(join_channel)
 399            .add_request_handler(join_channel_chat)
 400            .add_message_handler(leave_channel_chat)
 401            .add_request_handler(send_channel_message)
 402            .add_request_handler(remove_channel_message)
 403            .add_request_handler(update_channel_message)
 404            .add_request_handler(get_channel_messages)
 405            .add_request_handler(get_channel_messages_by_id)
 406            .add_request_handler(get_notifications)
 407            .add_request_handler(mark_notification_as_read)
 408            .add_request_handler(move_channel)
 409            .add_request_handler(reorder_channel)
 410            .add_request_handler(follow)
 411            .add_message_handler(unfollow)
 412            .add_message_handler(update_followers)
 413            .add_message_handler(acknowledge_channel_message)
 414            .add_message_handler(acknowledge_buffer_version)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 416            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 417            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 418            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 419            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 420            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 421            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 422            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 423            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 425            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 426            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 427            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 428            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 429            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 430            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 431            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 437            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 438            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 439            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 440            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
 441            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 442            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 443            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 444            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 445            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 446            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 447            .add_request_handler(share_agent_thread)
 448            .add_request_handler(get_shared_agent_thread)
 449            .add_request_handler(forward_project_search_chunk);
 450
 451        Arc::new(server)
 452    }
 453
    /// Spawns a detached background task that cleans up stale server resources, then returns
    /// `Ok(())` immediately.
    ///
    /// The task first awaits a sleep of [`CLEANUP_TIMEOUT`]. It then performs the steps below.
    /// Every step is best-effort: failures are only logged via `trace_err`, and the remaining
    /// steps still run.
    /// 1. Deletes stale channel chat participants.
    /// 2. For each stale channel buffer, clears its stale collaborators and sends the
    ///    refreshed collaborator list to the buffer's remaining connections.
    /// 3. For each stale room:
    ///    - clears its stale participants;
    ///    - broadcasts the room, and its channel if it has one;
    ///    - sends `CallCanceled` to users whose pending calls were canceled;
    ///    - pushes updated contact info to the accepted contacts of every affected user;
    ///    - if a LiveKit client is configured, deletes the LiveKit room once the room has no
    ///      participants left.
    /// 4. Deletes stale channel chat participants again, clears old worktree entries and
    ///    deletes stale servers.
    ///
    /// NOTE(review): what counts as "stale" is decided by the `db` queries, which are keyed on
    /// this server's id and `config.zed_environment`. Presumably these are resources still
    /// attributed to servers other than this one — confirm in the db layer.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell everyone still attached to the buffer who its collaborators are now.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing the room succeeds; otherwise the steps
                        // below see empty collections / `false` and do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scope the connection-pool lock to this block so the guard is
                        // released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after the database calls above have completed, so
                                // the guard is never held across an `.await`.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): repeats the call made at the start of the task — presumably to
                // catch participants that became stale during the room cleanup above; confirm
                // that both calls are intended.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 624
 625    pub fn teardown(&self) {
 626        self.peer.teardown();
 627        self.connection_pool.lock().reset();
 628        let _ = self.teardown.send(true);
 629    }
 630
 631    #[cfg(feature = "test-support")]
 632    pub fn reset(&self, id: ServerId) {
 633        self.teardown();
 634        *self.id.lock() = id;
 635        self.peer.reset(id.0 as u32);
 636        let _ = self.teardown.send(false);
 637    }
 638
 639    #[cfg(feature = "test-support")]
 640    pub fn id(&self) -> ServerId {
 641        *self.id.lock()
 642    }
 643
 644    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 645    where
 646        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 647        Fut: 'static + Send + Future<Output = Result<()>>,
 648        M: EnvelopedMessage,
 649    {
 650        let prev_handler = self.handlers.insert(
 651            TypeId::of::<M>(),
 652            Box::new(move |envelope, session, span| {
 653                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 654                let received_at = envelope.received_at;
 655                tracing::info!("message received");
 656                let start_time = Instant::now();
 657                let future = (handler)(
 658                    *envelope,
 659                    MessageContext {
 660                        session,
 661                        span: span.clone(),
 662                    },
 663                );
 664                async move {
 665                    let result = future.await;
 666                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 667                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 668                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 669                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 670                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 671                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 672                    match result {
 673                        Err(error) => {
 674                            tracing::error!(?error, "error handling message")
 675                        }
 676                        Ok(()) => tracing::info!("finished handling message"),
 677                    }
 678                }
 679                .boxed()
 680            }),
 681        );
 682        if prev_handler.is_some() {
 683            panic!("registered a handler for the same message twice");
 684        }
 685        self
 686    }
 687
 688    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 689    where
 690        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 691        Fut: 'static + Send + Future<Output = Result<()>>,
 692        M: EnvelopedMessage,
 693    {
 694        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 695        self
 696    }
 697
 698    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 699    where
 700        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 701        Fut: Send + Future<Output = Result<()>>,
 702        M: RequestMessage,
 703    {
 704        let handler = Arc::new(handler);
 705        self.add_handler(move |envelope, session| {
 706            let receipt = envelope.receipt();
 707            let handler = handler.clone();
 708            async move {
 709                let peer = session.peer.clone();
 710                let responded = Arc::new(AtomicBool::default());
 711                let response = Response {
 712                    peer: peer.clone(),
 713                    responded: responded.clone(),
 714                    receipt,
 715                };
 716                match (handler)(envelope.payload, response, session).await {
 717                    Ok(()) => {
 718                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 719                            Ok(())
 720                        } else {
 721                            Err(anyhow!("handler did not send a response"))?
 722                        }
 723                    }
 724                    Err(error) => {
 725                        let proto_err = match &error {
 726                            Error::Internal(err) => err.to_proto(),
 727                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 728                        };
 729                        peer.respond_with_error(receipt, proto_err)?;
 730                        Err(error)
 731                    }
 732                }
 733            }
 734        })
 735    }
 736
    /// Builds the future that serves a single client connection.
    ///
    /// Before the future is created, connection metadata (address, principal, user
    /// agent, release channel, GeoIP country) is recorded on a "handle connection"
    /// span, and the future is instrumented with that span.
    ///
    /// When run, the future does the following:
    /// - Returns immediately if the server is already tearing down.
    /// - Registers `connection` with the peer and builds a [`Session`].
    /// - Sends the initial client update (see `send_initial_client_update`).
    ///   `send_connection_id`, if given, receives the connection id during that step.
    /// - Releases `connection_guard`.
    /// - Dispatches incoming messages to the registered handlers. Background messages
    ///   are spawned on `executor`; foreground messages are driven by this future.
    ///
    /// The message loop ends when:
    /// - the server tears down: the future returns immediately, *without* running
    ///   `connection_lost`;
    /// - the I/O future completes: signs out via `connection_lost`;
    /// - the incoming message stream ends: signs out via `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly. The future reads the current value on start and waits on
        // `changed()` inside the message loop to stop on teardown.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this path skips `connection_lost`. The connection
                // registered with the peer above is therefore not disconnected here.
                // Also, `send_initial_client_update` may already have called
                // `pool.add_connection` before failing. Confirm whether that leaves a
                // stale entry in the connection pool.
                return;
            }
            // Release the slot reserved before the WebSocket upgrade. After this point
            // the guard no longer counts this connection.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each message takes a permit before it is read. The permit is dropped when its
            // handler (foreground or background) finishes, which caps in-flight handlers
            // for this connection at MAX_CONCURRENT_HANDLERS.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in priority order:
                // 1. teardown: return immediately, with no sign-out;
                // 2. I/O completion: leave the loop and sign out;
                // 3. progress on foreground handlers;
                // 4. the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler's synchronous part
                            // builds its future. The future itself is instrumented below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 907
 908    async fn send_initial_client_update(
 909        &self,
 910        connection_id: ConnectionId,
 911        zed_version: ZedVersion,
 912        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 913        session: &Session,
 914    ) -> Result<()> {
 915        self.peer.send(
 916            connection_id,
 917            proto::Hello {
 918                peer_id: Some(connection_id.into()),
 919            },
 920        )?;
 921        tracing::info!("sent hello message");
 922        if let Some(send_connection_id) = send_connection_id.take() {
 923            let _ = send_connection_id.send(connection_id);
 924        }
 925
 926        match &session.principal {
 927            Principal::User(user) => {
 928                if !user.connected_once {
 929                    self.peer.send(connection_id, proto::ShowContacts {})?;
 930                    self.app_state
 931                        .db
 932                        .set_user_connected_once(user.id, true)
 933                        .await?;
 934                }
 935
 936                let contacts = self.app_state.db.get_contacts(user.id).await?;
 937
 938                {
 939                    let mut pool = self.connection_pool.lock();
 940                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 941                    self.peer.send(
 942                        connection_id,
 943                        build_initial_contacts_update(contacts, &pool),
 944                    )?;
 945                }
 946
 947                if should_auto_subscribe_to_channels(&zed_version) {
 948                    subscribe_user_to_channels(user.id, session).await?;
 949                }
 950
 951                if let Some(incoming_call) =
 952                    self.app_state.db.incoming_call_for_user(user.id).await?
 953                {
 954                    self.peer.send(connection_id, incoming_call)?;
 955                }
 956
 957                update_user_contacts(user.id, session).await?;
 958            }
 959        }
 960
 961        Ok(())
 962    }
 963}
 964
 965impl Deref for ConnectionPoolGuard<'_> {
 966    type Target = ConnectionPool;
 967
 968    fn deref(&self) -> &Self::Target {
 969        &self.guard
 970    }
 971}
 972
 973impl DerefMut for ConnectionPoolGuard<'_> {
 974    fn deref_mut(&mut self) -> &mut Self::Target {
 975        &mut self.guard
 976    }
 977}
 978
 979impl Drop for ConnectionPoolGuard<'_> {
 980    fn drop(&mut self) {
 981        #[cfg(feature = "test-support")]
 982        self.check_invariants();
 983    }
 984}
 985
 986fn broadcast<F>(
 987    sender_id: Option<ConnectionId>,
 988    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 989    mut f: F,
 990) where
 991    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 992{
 993    for receiver_id in receiver_ids {
 994        if Some(receiver_id) != sender_id
 995            && let Err(error) = f(receiver_id)
 996        {
 997            tracing::error!("failed to send to {:?} {}", receiver_id, error);
 998        }
 999    }
1000}
1001
1002pub struct ProtocolVersion(u32);
1003
1004impl Header for ProtocolVersion {
1005    fn name() -> &'static HeaderName {
1006        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1007        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1008    }
1009
1010    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1011    where
1012        Self: Sized,
1013        I: Iterator<Item = &'i axum::http::HeaderValue>,
1014    {
1015        let version = values
1016            .next()
1017            .ok_or_else(axum::headers::Error::invalid)?
1018            .to_str()
1019            .map_err(|_| axum::headers::Error::invalid())?
1020            .parse()
1021            .map_err(|_| axum::headers::Error::invalid())?;
1022        Ok(Self(version))
1023    }
1024
1025    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1026        values.extend([self.0.to_string().parse().unwrap()]);
1027    }
1028}
1029
1030pub struct AppVersionHeader(Version);
1031impl Header for AppVersionHeader {
1032    fn name() -> &'static HeaderName {
1033        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1034        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1035    }
1036
1037    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1038    where
1039        Self: Sized,
1040        I: Iterator<Item = &'i axum::http::HeaderValue>,
1041    {
1042        let version = values
1043            .next()
1044            .ok_or_else(axum::headers::Error::invalid)?
1045            .to_str()
1046            .map_err(|_| axum::headers::Error::invalid())?
1047            .parse()
1048            .map_err(|_| axum::headers::Error::invalid())?;
1049        Ok(Self(version))
1050    }
1051
1052    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1053        values.extend([self.0.to_string().parse().unwrap()]);
1054    }
1055}
1056
1057#[derive(Debug)]
1058pub struct ReleaseChannelHeader(String);
1059
1060impl Header for ReleaseChannelHeader {
1061    fn name() -> &'static HeaderName {
1062        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1063        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1064    }
1065
1066    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1067    where
1068        Self: Sized,
1069        I: Iterator<Item = &'i axum::http::HeaderValue>,
1070    {
1071        Ok(Self(
1072            values
1073                .next()
1074                .ok_or_else(axum::headers::Error::invalid)?
1075                .to_str()
1076                .map_err(|_| axum::headers::Error::invalid())?
1077                .to_owned(),
1078        ))
1079    }
1080
1081    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1082        values.extend([self.0.parse().unwrap()]);
1083    }
1084}
1085
1086pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1087    Router::new()
1088        .route("/rpc", get(handle_websocket_request))
1089        .layer(
1090            ServiceBuilder::new()
1091                .layer(Extension(server.app_state.clone()))
1092                .layer(middleware::from_fn(auth::validate_header)),
1093        )
1094        .route("/metrics", get(handle_metrics))
1095        .layer(Extension(server))
1096}
1097
1098pub async fn handle_websocket_request(
1099    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1100    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1101    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1102    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1103    Extension(server): Extension<Arc<Server>>,
1104    Extension(principal): Extension<Principal>,
1105    user_agent: Option<TypedHeader<UserAgent>>,
1106    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1107    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1108    ws: WebSocketUpgrade,
1109) -> axum::response::Response {
1110    if protocol_version != rpc::PROTOCOL_VERSION {
1111        return (
1112            StatusCode::UPGRADE_REQUIRED,
1113            "client must be upgraded".to_string(),
1114        )
1115            .into_response();
1116    }
1117
1118    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1119        return (
1120            StatusCode::UPGRADE_REQUIRED,
1121            "no version header found".to_string(),
1122        )
1123            .into_response();
1124    };
1125
1126    let release_channel = release_channel_header.map(|header| header.0.0);
1127
1128    if !version.can_collaborate() {
1129        return (
1130            StatusCode::UPGRADE_REQUIRED,
1131            "client must be upgraded".to_string(),
1132        )
1133            .into_response();
1134    }
1135
1136    let socket_address = socket_address.to_string();
1137
1138    // Acquire connection guard before WebSocket upgrade
1139    let connection_guard = match ConnectionGuard::try_acquire() {
1140        Ok(guard) => guard,
1141        Err(()) => {
1142            return (
1143                StatusCode::SERVICE_UNAVAILABLE,
1144                "Too many concurrent connections",
1145            )
1146                .into_response();
1147        }
1148    };
1149
1150    ws.on_upgrade(move |socket| {
1151        let socket = socket
1152            .map_ok(to_tungstenite_message)
1153            .err_into()
1154            .with(|message| async move { to_axum_message(message) });
1155        let connection = Connection::new(Box::pin(socket));
1156        async move {
1157            server
1158                .handle_connection(
1159                    connection,
1160                    socket_address,
1161                    principal,
1162                    version,
1163                    release_channel,
1164                    user_agent.map(|header| header.to_string()),
1165                    country_code_header.map(|header| header.to_string()),
1166                    system_id_header.map(|header| header.to_string()),
1167                    None,
1168                    Executor::Production,
1169                    Some(connection_guard),
1170                )
1171                .await;
1172        }
1173    })
1174}
1175
1176pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1177    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1178    let connections_metric = CONNECTIONS_METRIC
1179        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1180
1181    let connections = server
1182        .connection_pool
1183        .lock()
1184        .connections()
1185        .filter(|connection| !connection.admin)
1186        .count();
1187    connections_metric.set(connections as _);
1188
1189    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1190    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1191        register_int_gauge!(
1192            "shared_projects",
1193            "number of open projects with one or more guests"
1194        )
1195        .unwrap()
1196    });
1197
1198    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1199    shared_projects_metric.set(shared_projects as _);
1200
1201    let encoder = prometheus::TextEncoder::new();
1202    let metric_families = prometheus::gather();
1203    let encoded_metrics = encoder
1204        .encode_to_string(&metric_families)
1205        .map_err(|err| anyhow!("{err}"))?;
1206    Ok(encoded_metrics)
1207}
1208
/// Cleans up after a client connection has ended.
///
/// Runs immediately:
/// 1. Disconnects the connection from the peer.
/// 2. Removes it from the connection pool. A failure here is returned as an error.
/// 3. Tells the database the connection was lost. Errors are only logged.
///
/// Then it waits for `RECONNECT_TIMEOUT`, unless the teardown channel changes
/// first, in which case nothing more is done. If the timeout elapses:
/// - the session leaves its room and channel buffers (errors logged);
/// - if the user has no other online connection, any pending call is declined
///   and the affected room is broadcast;
/// - the user's contacts are refreshed.
///
/// NOTE(review): the delay presumably gives the client a window to reconnect
/// (e.g. via `rejoin_room`) before its resources are released. Confirm.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: if the timeout and a teardown are both ready, the timeout
    // branch (full cleanup) wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1255
1256/// Acknowledges a ping from a client, used to keep the connection alive.
1257async fn ping(
1258    _: proto::Ping,
1259    response: Response<proto::Ping>,
1260    _session: MessageContext,
1261) -> Result<()> {
1262    response.send(proto::Ack {})?;
1263    Ok(())
1264}
1265
1266/// Creates a new room for calling (outside of channels)
1267async fn create_room(
1268    _request: proto::CreateRoom,
1269    response: Response<proto::CreateRoom>,
1270    session: MessageContext,
1271) -> Result<()> {
1272    let livekit_room = nanoid::nanoid!(30);
1273
1274    let live_kit_connection_info = util::maybe!(async {
1275        let live_kit = session.app_state.livekit_client.as_ref();
1276        let live_kit = live_kit?;
1277        let user_id = session.user_id().to_string();
1278
1279        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1280
1281        Some(proto::LiveKitConnectionInfo {
1282            server_url: live_kit.url().into(),
1283            token,
1284            can_publish: true,
1285        })
1286    })
1287    .await;
1288
1289    let room = session
1290        .db()
1291        .await
1292        .create_room(session.user_id(), session.connection_id, &livekit_room)
1293        .await?;
1294
1295    response.send(proto::CreateRoomResponse {
1296        room: Some(room.clone()),
1297        live_kit_connection_info,
1298    })?;
1299
1300    update_user_contacts(session.user_id(), &session).await?;
1301    Ok(())
1302}
1303
1304/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1305async fn join_room(
1306    request: proto::JoinRoom,
1307    response: Response<proto::JoinRoom>,
1308    session: MessageContext,
1309) -> Result<()> {
1310    let room_id = RoomId::from_proto(request.id);
1311
1312    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1313
1314    if let Some(channel_id) = channel_id {
1315        return join_channel_internal(channel_id, Box::new(response), session).await;
1316    }
1317
1318    let joined_room = {
1319        let room = session
1320            .db()
1321            .await
1322            .join_room(room_id, session.user_id(), session.connection_id)
1323            .await?;
1324        room_updated(&room.room, &session.peer);
1325        room.into_inner()
1326    };
1327
1328    for connection_id in session
1329        .connection_pool()
1330        .await
1331        .user_connection_ids(session.user_id())
1332    {
1333        session
1334            .peer
1335            .send(
1336                connection_id,
1337                proto::CallCanceled {
1338                    room_id: room_id.to_proto(),
1339                },
1340            )
1341            .trace_err();
1342    }
1343
1344    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1345    {
1346        live_kit
1347            .room_token(
1348                &joined_room.room.livekit_room,
1349                &session.user_id().to_string(),
1350            )
1351            .trace_err()
1352            .map(|token| proto::LiveKitConnectionInfo {
1353                server_url: live_kit.url().into(),
1354                token,
1355                can_publish: true,
1356            })
1357    } else {
1358        None
1359    };
1360
1361    response.send(proto::JoinRoomResponse {
1362        room: Some(joined_room.room),
1363        channel_id: None,
1364        live_kit_connection_info,
1365    })?;
1366
1367    update_user_contacts(session.user_id(), &session).await?;
1368    Ok(())
1369}
1370
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Rejoins the room in the database.
/// 2. Replies with the room, the reshared projects (with their collaborators), and
///    the rejoined projects.
/// 3. Broadcasts the updated room.
/// 4. For each reshared project, tells its collaborators that the caller's peer id
///    changed and forwards the project's worktrees to them.
/// 5. Streams rejoined-project state via `notify_rejoined_projects`.
/// 6. Broadcasts a channel update if the room belongs to a channel.
/// 7. Refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // `rejoined_room` is consumed by `into_inner()` inside this scope, so it does not
    // outlive the block. NOTE(review): presumably this releases a database-side guard
    // before `connection_pool()` is awaited below. Confirm.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Projects the caller hosts: re-point collaborators from the old connection
        // to the new one, then forward the project's worktrees to every collaborator
        // other than the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1462
1463fn notify_rejoined_projects(
1464    rejoined_projects: &mut Vec<RejoinedProject>,
1465    session: &Session,
1466) -> Result<()> {
1467    for project in rejoined_projects.iter() {
1468        for collaborator in &project.collaborators {
1469            session
1470                .peer
1471                .send(
1472                    collaborator.connection_id,
1473                    proto::UpdateProjectCollaborator {
1474                        project_id: project.id.to_proto(),
1475                        old_peer_id: Some(project.old_connection_id.into()),
1476                        new_peer_id: Some(session.connection_id.into()),
1477                    },
1478                )
1479                .trace_err();
1480        }
1481    }
1482
1483    for project in rejoined_projects {
1484        for worktree in mem::take(&mut project.worktrees) {
1485            // Stream this worktree's entries.
1486            let message = proto::UpdateWorktree {
1487                project_id: project.id.to_proto(),
1488                worktree_id: worktree.id,
1489                abs_path: worktree.abs_path.clone(),
1490                root_name: worktree.root_name,
1491                root_repo_common_dir: worktree.root_repo_common_dir,
1492                updated_entries: worktree.updated_entries,
1493                removed_entries: worktree.removed_entries,
1494                scan_id: worktree.scan_id,
1495                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1496                updated_repositories: worktree.updated_repositories,
1497                removed_repositories: worktree.removed_repositories,
1498            };
1499            for update in proto::split_worktree_update(message) {
1500                session.peer.send(session.connection_id, update)?;
1501            }
1502
1503            // Stream this worktree's diagnostics.
1504            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1505            if let Some(summary) = worktree_diagnostics.next() {
1506                let message = proto::UpdateDiagnosticSummary {
1507                    project_id: project.id.to_proto(),
1508                    worktree_id: worktree.id,
1509                    summary: Some(summary),
1510                    more_summaries: worktree_diagnostics.collect(),
1511                };
1512                session.peer.send(session.connection_id, message)?;
1513            }
1514
1515            for settings_file in worktree.settings_files {
1516                session.peer.send(
1517                    session.connection_id,
1518                    proto::UpdateWorktreeSettings {
1519                        project_id: project.id.to_proto(),
1520                        worktree_id: worktree.id,
1521                        path: settings_file.path,
1522                        content: Some(settings_file.content),
1523                        kind: Some(settings_file.kind.to_proto().into()),
1524                        outside_worktree: Some(settings_file.outside_worktree),
1525                    },
1526                )?;
1527            }
1528        }
1529
1530        for repository in mem::take(&mut project.updated_repositories) {
1531            for update in split_repository_update(repository) {
1532                session.peer.send(session.connection_id, update)?;
1533            }
1534        }
1535
1536        for id in mem::take(&mut project.removed_repositories) {
1537            session.peer.send(
1538                session.connection_id,
1539                proto::RemoveRepository {
1540                    project_id: project.id.to_proto(),
1541                    id,
1542                },
1543            )?;
1544        }
1545    }
1546
1547    Ok(())
1548}
1549
1550/// leave room disconnects from the room.
1551async fn leave_room(
1552    _: proto::LeaveRoom,
1553    response: Response<proto::LeaveRoom>,
1554    session: MessageContext,
1555) -> Result<()> {
1556    leave_room_for_session(&session, session.connection_id).await?;
1557    response.send(proto::Ack {})?;
1558    Ok(())
1559}
1560
1561/// Updates the permissions of someone else in the room.
1562async fn set_room_participant_role(
1563    request: proto::SetRoomParticipantRole,
1564    response: Response<proto::SetRoomParticipantRole>,
1565    session: MessageContext,
1566) -> Result<()> {
1567    let user_id = UserId::from_proto(request.user_id);
1568    let role = ChannelRole::from(request.role());
1569
1570    let (livekit_room, can_publish) = {
1571        let room = session
1572            .db()
1573            .await
1574            .set_room_participant_role(
1575                session.user_id(),
1576                RoomId::from_proto(request.room_id),
1577                user_id,
1578                role,
1579            )
1580            .await?;
1581
1582        let livekit_room = room.livekit_room.clone();
1583        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1584        room_updated(&room, &session.peer);
1585        (livekit_room, can_publish)
1586    };
1587
1588    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1589        live_kit
1590            .update_participant(
1591                livekit_room.clone(),
1592                request.user_id.to_string(),
1593                livekit_api::proto::ParticipantPermission {
1594                    can_subscribe: true,
1595                    can_publish,
1596                    can_publish_data: can_publish,
1597                    hidden: false,
1598                    recorder: false,
1599                },
1600            )
1601            .await
1602            .trace_err();
1603    }
1604
1605    response.send(proto::Ack {})?;
1606    Ok(())
1607}
1608
1609/// Call someone else into the current room
1610async fn call(
1611    request: proto::Call,
1612    response: Response<proto::Call>,
1613    session: MessageContext,
1614) -> Result<()> {
1615    let room_id = RoomId::from_proto(request.room_id);
1616    let calling_user_id = session.user_id();
1617    let calling_connection_id = session.connection_id;
1618    let called_user_id = UserId::from_proto(request.called_user_id);
1619    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1620    if !session
1621        .db()
1622        .await
1623        .has_contact(calling_user_id, called_user_id)
1624        .await?
1625    {
1626        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1627    }
1628
1629    let incoming_call = {
1630        let (room, incoming_call) = &mut *session
1631            .db()
1632            .await
1633            .call(
1634                room_id,
1635                calling_user_id,
1636                calling_connection_id,
1637                called_user_id,
1638                initial_project_id,
1639            )
1640            .await?;
1641        room_updated(room, &session.peer);
1642        mem::take(incoming_call)
1643    };
1644    update_user_contacts(called_user_id, &session).await?;
1645
1646    let mut calls = session
1647        .connection_pool()
1648        .await
1649        .user_connection_ids(called_user_id)
1650        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1651        .collect::<FuturesUnordered<_>>();
1652
1653    while let Some(call_response) = calls.next().await {
1654        match call_response.as_ref() {
1655            Ok(_) => {
1656                response.send(proto::Ack {})?;
1657                return Ok(());
1658            }
1659            Err(_) => {
1660                call_response.trace_err();
1661            }
1662        }
1663    }
1664
1665    {
1666        let room = session
1667            .db()
1668            .await
1669            .call_failed(room_id, called_user_id)
1670            .await?;
1671        room_updated(&room, &session.peer);
1672    }
1673    update_user_contacts(called_user_id, &session).await?;
1674
1675    Err(anyhow!("failed to ring user"))?
1676}
1677
1678/// Cancel an outgoing call.
1679async fn cancel_call(
1680    request: proto::CancelCall,
1681    response: Response<proto::CancelCall>,
1682    session: MessageContext,
1683) -> Result<()> {
1684    let called_user_id = UserId::from_proto(request.called_user_id);
1685    let room_id = RoomId::from_proto(request.room_id);
1686    {
1687        let room = session
1688            .db()
1689            .await
1690            .cancel_call(room_id, session.connection_id, called_user_id)
1691            .await?;
1692        room_updated(&room, &session.peer);
1693    }
1694
1695    for connection_id in session
1696        .connection_pool()
1697        .await
1698        .user_connection_ids(called_user_id)
1699    {
1700        session
1701            .peer
1702            .send(
1703                connection_id,
1704                proto::CallCanceled {
1705                    room_id: room_id.to_proto(),
1706                },
1707            )
1708            .trace_err();
1709    }
1710    response.send(proto::Ack {})?;
1711
1712    update_user_contacts(called_user_id, &session).await?;
1713    Ok(())
1714}
1715
1716/// Decline an incoming call.
1717async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1718    let room_id = RoomId::from_proto(message.room_id);
1719    {
1720        let room = session
1721            .db()
1722            .await
1723            .decline_call(Some(room_id), session.user_id())
1724            .await?
1725            .context("declining call")?;
1726        room_updated(&room, &session.peer);
1727    }
1728
1729    for connection_id in session
1730        .connection_pool()
1731        .await
1732        .user_connection_ids(session.user_id())
1733    {
1734        session
1735            .peer
1736            .send(
1737                connection_id,
1738                proto::CallCanceled {
1739                    room_id: room_id.to_proto(),
1740                },
1741            )
1742            .trace_err();
1743    }
1744    update_user_contacts(session.user_id(), &session).await?;
1745    Ok(())
1746}
1747
1748/// Updates other participants in the room with your current location.
1749async fn update_participant_location(
1750    request: proto::UpdateParticipantLocation,
1751    response: Response<proto::UpdateParticipantLocation>,
1752    session: MessageContext,
1753) -> Result<()> {
1754    let room_id = RoomId::from_proto(request.room_id);
1755    let location = request.location.context("invalid location")?;
1756
1757    let db = session.db().await;
1758    let room = db
1759        .update_room_participant_location(room_id, session.connection_id, location)
1760        .await?;
1761
1762    room_updated(&room, &session.peer);
1763    response.send(proto::Ack {})?;
1764    Ok(())
1765}
1766
1767/// Share a project into the room.
1768async fn share_project(
1769    request: proto::ShareProject,
1770    response: Response<proto::ShareProject>,
1771    session: MessageContext,
1772) -> Result<()> {
1773    let (project_id, room) = &*session
1774        .db()
1775        .await
1776        .share_project(
1777            RoomId::from_proto(request.room_id),
1778            session.connection_id,
1779            &request.worktrees,
1780            request.is_ssh_project,
1781            request.windows_paths.unwrap_or(false),
1782            &request.features,
1783        )
1784        .await?;
1785    response.send(proto::ShareProjectResponse {
1786        project_id: project_id.to_proto(),
1787    })?;
1788    room_updated(room, &session.peer);
1789
1790    Ok(())
1791}
1792
1793/// Unshare a project from the room.
1794async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1795    let project_id = ProjectId::from_proto(message.project_id);
1796    unshare_project_internal(project_id, session.connection_id, &session).await
1797}
1798
1799async fn unshare_project_internal(
1800    project_id: ProjectId,
1801    connection_id: ConnectionId,
1802    session: &Session,
1803) -> Result<()> {
1804    let delete = {
1805        let room_guard = session
1806            .db()
1807            .await
1808            .unshare_project(project_id, connection_id)
1809            .await?;
1810
1811        let (delete, room, guest_connection_ids) = &*room_guard;
1812
1813        let message = proto::UnshareProject {
1814            project_id: project_id.to_proto(),
1815        };
1816
1817        broadcast(
1818            Some(connection_id),
1819            guest_connection_ids.iter().copied(),
1820            |conn_id| session.peer.send(conn_id, message.clone()),
1821        );
1822        if let Some(room) = room {
1823            room_updated(room, &session.peer);
1824        }
1825
1826        *delete
1827    };
1828
1829    if delete {
1830        let db = session.db().await;
1831        db.delete_project(project_id).await?;
1832    }
1833
1834    Ok(())
1835}
1836
1837/// Join someone elses shared project.
1838async fn join_project(
1839    request: proto::JoinProject,
1840    response: Response<proto::JoinProject>,
1841    session: MessageContext,
1842) -> Result<()> {
1843    let project_id = ProjectId::from_proto(request.project_id);
1844
1845    tracing::info!(%project_id, "join project");
1846
1847    let db = session.db().await;
1848    let project_model = db.get_project(project_id).await?;
1849    let host_features: Vec<String> =
1850        serde_json::from_str(&project_model.features).unwrap_or_default();
1851    let guest_features: HashSet<_> = request.features.iter().collect();
1852    let host_features_set: HashSet<_> = host_features.iter().collect();
1853    if guest_features != host_features_set {
1854        let host_connection_id = project_model.host_connection()?;
1855        let mut pool = session.connection_pool().await;
1856        let host_version = pool
1857            .connection(host_connection_id)
1858            .map(|c| c.zed_version.to_string());
1859        let guest_version = pool
1860            .connection(session.connection_id)
1861            .map(|c| c.zed_version.to_string());
1862        drop(pool);
1863        Err(anyhow!(
1864            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1865            host_version.as_deref().unwrap_or("unknown"),
1866            guest_version.as_deref().unwrap_or("unknown"),
1867        ))?;
1868    }
1869
1870    let (project, replica_id) = &mut *db
1871        .join_project(
1872            project_id,
1873            session.connection_id,
1874            session.user_id(),
1875            request.committer_name.clone(),
1876            request.committer_email.clone(),
1877        )
1878        .await?;
1879    drop(db);
1880
1881    tracing::info!(%project_id, "join remote project");
1882    let collaborators = project
1883        .collaborators
1884        .iter()
1885        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1886        .map(|collaborator| collaborator.to_proto())
1887        .collect::<Vec<_>>();
1888    let project_id = project.id;
1889    let guest_user_id = session.user_id();
1890
1891    let worktrees = project
1892        .worktrees
1893        .iter()
1894        .map(|(id, worktree)| proto::WorktreeMetadata {
1895            id: *id,
1896            root_name: worktree.root_name.clone(),
1897            visible: worktree.visible,
1898            abs_path: worktree.abs_path.clone(),
1899            root_repo_common_dir: None,
1900        })
1901        .collect::<Vec<_>>();
1902
1903    let add_project_collaborator = proto::AddProjectCollaborator {
1904        project_id: project_id.to_proto(),
1905        collaborator: Some(proto::Collaborator {
1906            peer_id: Some(session.connection_id.into()),
1907            replica_id: replica_id.0 as u32,
1908            user_id: guest_user_id.to_proto(),
1909            is_host: false,
1910            committer_name: request.committer_name.clone(),
1911            committer_email: request.committer_email.clone(),
1912        }),
1913    };
1914
1915    for collaborator in &collaborators {
1916        session
1917            .peer
1918            .send(
1919                collaborator.peer_id.unwrap().into(),
1920                add_project_collaborator.clone(),
1921            )
1922            .trace_err();
1923    }
1924
1925    // First, we send the metadata associated with each worktree.
1926    let (language_servers, language_server_capabilities) = project
1927        .language_servers
1928        .clone()
1929        .into_iter()
1930        .map(|server| (server.server, server.capabilities))
1931        .unzip();
1932    response.send(proto::JoinProjectResponse {
1933        project_id: project.id.0 as u64,
1934        worktrees,
1935        replica_id: replica_id.0 as u32,
1936        collaborators,
1937        language_servers,
1938        language_server_capabilities,
1939        role: project.role.into(),
1940        windows_paths: project.path_style == PathStyle::Windows,
1941        features: project.features.clone(),
1942    })?;
1943
1944    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1945        // Stream this worktree's entries.
1946        let message = proto::UpdateWorktree {
1947            project_id: project_id.to_proto(),
1948            worktree_id,
1949            abs_path: worktree.abs_path.clone(),
1950            root_name: worktree.root_name,
1951            root_repo_common_dir: worktree.root_repo_common_dir,
1952            updated_entries: worktree.entries,
1953            removed_entries: Default::default(),
1954            scan_id: worktree.scan_id,
1955            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1956            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1957            removed_repositories: Default::default(),
1958        };
1959        for update in proto::split_worktree_update(message) {
1960            session.peer.send(session.connection_id, update.clone())?;
1961        }
1962
1963        // Stream this worktree's diagnostics.
1964        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1965        if let Some(summary) = worktree_diagnostics.next() {
1966            let message = proto::UpdateDiagnosticSummary {
1967                project_id: project.id.to_proto(),
1968                worktree_id: worktree.id,
1969                summary: Some(summary),
1970                more_summaries: worktree_diagnostics.collect(),
1971            };
1972            session.peer.send(session.connection_id, message)?;
1973        }
1974
1975        for settings_file in worktree.settings_files {
1976            session.peer.send(
1977                session.connection_id,
1978                proto::UpdateWorktreeSettings {
1979                    project_id: project_id.to_proto(),
1980                    worktree_id: worktree.id,
1981                    path: settings_file.path,
1982                    content: Some(settings_file.content),
1983                    kind: Some(settings_file.kind.to_proto() as i32),
1984                    outside_worktree: Some(settings_file.outside_worktree),
1985                },
1986            )?;
1987        }
1988    }
1989
1990    for repository in mem::take(&mut project.repositories) {
1991        for update in split_repository_update(repository) {
1992            session.peer.send(session.connection_id, update)?;
1993        }
1994    }
1995
1996    for language_server in &project.language_servers {
1997        session.peer.send(
1998            session.connection_id,
1999            proto::UpdateLanguageServer {
2000                project_id: project_id.to_proto(),
2001                server_name: Some(language_server.server.name.clone()),
2002                language_server_id: language_server.server.id,
2003                variant: Some(
2004                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2005                        proto::LspDiskBasedDiagnosticsUpdated {},
2006                    ),
2007                ),
2008            },
2009        )?;
2010    }
2011
2012    Ok(())
2013}
2014
2015/// Leave someone elses shared project.
2016async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2017    let sender_id = session.connection_id;
2018    let project_id = ProjectId::from_proto(request.project_id);
2019    let db = session.db().await;
2020
2021    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2022    tracing::info!(
2023        %project_id,
2024        "leave project"
2025    );
2026
2027    project_left(project, &session);
2028    if let Some(room) = room {
2029        room_updated(room, &session.peer);
2030    }
2031
2032    Ok(())
2033}
2034
2035/// Updates other participants with changes to the project
2036async fn update_project(
2037    request: proto::UpdateProject,
2038    response: Response<proto::UpdateProject>,
2039    session: MessageContext,
2040) -> Result<()> {
2041    let project_id = ProjectId::from_proto(request.project_id);
2042    let (room, guest_connection_ids) = &*session
2043        .db()
2044        .await
2045        .update_project(project_id, session.connection_id, &request.worktrees)
2046        .await?;
2047    broadcast(
2048        Some(session.connection_id),
2049        guest_connection_ids.iter().copied(),
2050        |connection_id| {
2051            session
2052                .peer
2053                .forward_send(session.connection_id, connection_id, request.clone())
2054        },
2055    );
2056    if let Some(room) = room {
2057        room_updated(room, &session.peer);
2058    }
2059    response.send(proto::Ack {})?;
2060
2061    Ok(())
2062}
2063
2064/// Updates other participants with changes to the worktree
2065async fn update_worktree(
2066    request: proto::UpdateWorktree,
2067    response: Response<proto::UpdateWorktree>,
2068    session: MessageContext,
2069) -> Result<()> {
2070    let guest_connection_ids = session
2071        .db()
2072        .await
2073        .update_worktree(&request, session.connection_id)
2074        .await?;
2075
2076    broadcast(
2077        Some(session.connection_id),
2078        guest_connection_ids.iter().copied(),
2079        |connection_id| {
2080            session
2081                .peer
2082                .forward_send(session.connection_id, connection_id, request.clone())
2083        },
2084    );
2085    response.send(proto::Ack {})?;
2086    Ok(())
2087}
2088
2089async fn update_repository(
2090    request: proto::UpdateRepository,
2091    response: Response<proto::UpdateRepository>,
2092    session: MessageContext,
2093) -> Result<()> {
2094    let guest_connection_ids = session
2095        .db()
2096        .await
2097        .update_repository(&request, session.connection_id)
2098        .await?;
2099
2100    broadcast(
2101        Some(session.connection_id),
2102        guest_connection_ids.iter().copied(),
2103        |connection_id| {
2104            session
2105                .peer
2106                .forward_send(session.connection_id, connection_id, request.clone())
2107        },
2108    );
2109    response.send(proto::Ack {})?;
2110    Ok(())
2111}
2112
2113async fn remove_repository(
2114    request: proto::RemoveRepository,
2115    response: Response<proto::RemoveRepository>,
2116    session: MessageContext,
2117) -> Result<()> {
2118    let guest_connection_ids = session
2119        .db()
2120        .await
2121        .remove_repository(&request, session.connection_id)
2122        .await?;
2123
2124    broadcast(
2125        Some(session.connection_id),
2126        guest_connection_ids.iter().copied(),
2127        |connection_id| {
2128            session
2129                .peer
2130                .forward_send(session.connection_id, connection_id, request.clone())
2131        },
2132    );
2133    response.send(proto::Ack {})?;
2134    Ok(())
2135}
2136
2137/// Updates other participants with changes to the diagnostics
2138async fn update_diagnostic_summary(
2139    message: proto::UpdateDiagnosticSummary,
2140    session: MessageContext,
2141) -> Result<()> {
2142    let guest_connection_ids = session
2143        .db()
2144        .await
2145        .update_diagnostic_summary(&message, session.connection_id)
2146        .await?;
2147
2148    broadcast(
2149        Some(session.connection_id),
2150        guest_connection_ids.iter().copied(),
2151        |connection_id| {
2152            session
2153                .peer
2154                .forward_send(session.connection_id, connection_id, message.clone())
2155        },
2156    );
2157
2158    Ok(())
2159}
2160
2161/// Updates other participants with changes to the worktree settings
2162async fn update_worktree_settings(
2163    message: proto::UpdateWorktreeSettings,
2164    session: MessageContext,
2165) -> Result<()> {
2166    let guest_connection_ids = session
2167        .db()
2168        .await
2169        .update_worktree_settings(&message, session.connection_id)
2170        .await?;
2171
2172    broadcast(
2173        Some(session.connection_id),
2174        guest_connection_ids.iter().copied(),
2175        |connection_id| {
2176            session
2177                .peer
2178                .forward_send(session.connection_id, connection_id, message.clone())
2179        },
2180    );
2181
2182    Ok(())
2183}
2184
2185/// Notify other participants that a language server has started.
2186async fn start_language_server(
2187    request: proto::StartLanguageServer,
2188    session: MessageContext,
2189) -> Result<()> {
2190    let guest_connection_ids = session
2191        .db()
2192        .await
2193        .start_language_server(&request, session.connection_id)
2194        .await?;
2195
2196    broadcast(
2197        Some(session.connection_id),
2198        guest_connection_ids.iter().copied(),
2199        |connection_id| {
2200            session
2201                .peer
2202                .forward_send(session.connection_id, connection_id, request.clone())
2203        },
2204    );
2205    Ok(())
2206}
2207
2208/// Notify other participants that a language server has changed.
2209async fn update_language_server(
2210    request: proto::UpdateLanguageServer,
2211    session: MessageContext,
2212) -> Result<()> {
2213    let project_id = ProjectId::from_proto(request.project_id);
2214    let db = session.db().await;
2215
2216    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2217        && let Some(capabilities) = update.capabilities.clone()
2218    {
2219        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2220            .await?;
2221    }
2222
2223    let project_connection_ids = db
2224        .project_connection_ids(project_id, session.connection_id, true)
2225        .await?;
2226    broadcast(
2227        Some(session.connection_id),
2228        project_connection_ids.iter().copied(),
2229        |connection_id| {
2230            session
2231                .peer
2232                .forward_send(session.connection_id, connection_id, request.clone())
2233        },
2234    );
2235    Ok(())
2236}
2237
2238/// forward a project request to the host. These requests should be read only
2239/// as guests are allowed to send them.
2240async fn forward_read_only_project_request<T>(
2241    request: T,
2242    response: Response<T>,
2243    session: MessageContext,
2244) -> Result<()>
2245where
2246    T: EntityMessage + RequestMessage,
2247{
2248    let project_id = ProjectId::from_proto(request.remote_entity_id());
2249    let host_connection_id = session
2250        .db()
2251        .await
2252        .host_for_read_only_project_request(project_id, session.connection_id)
2253        .await?;
2254    let payload = session.forward_request(host_connection_id, request).await?;
2255    response.send(payload)?;
2256    Ok(())
2257}
2258
2259/// forward a project request to the host. These requests are disallowed
2260/// for guests.
2261async fn forward_mutating_project_request<T>(
2262    request: T,
2263    response: Response<T>,
2264    session: MessageContext,
2265) -> Result<()>
2266where
2267    T: EntityMessage + RequestMessage,
2268{
2269    let project_id = ProjectId::from_proto(request.remote_entity_id());
2270
2271    let host_connection_id = session
2272        .db()
2273        .await
2274        .host_for_mutating_project_request(project_id, session.connection_id)
2275        .await?;
2276    let payload = session.forward_request(host_connection_id, request).await?;
2277    response.send(payload)?;
2278    Ok(())
2279}
2280
2281async fn disallow_guest_request<T>(
2282    _request: T,
2283    response: Response<T>,
2284    _session: MessageContext,
2285) -> Result<()>
2286where
2287    T: RequestMessage,
2288{
2289    response.peer.respond_with_error(
2290        response.receipt,
2291        ErrorCode::Forbidden
2292            .message("request is not allowed for guests".to_string())
2293            .to_proto(),
2294    )?;
2295    response.responded.store(true, SeqCst);
2296    Ok(())
2297}
2298
2299async fn lsp_query(
2300    request: proto::LspQuery,
2301    response: Response<proto::LspQuery>,
2302    session: MessageContext,
2303) -> Result<()> {
2304    let (name, should_write) = request.query_name_and_write_permissions();
2305    tracing::Span::current().record("lsp_query_request", name);
2306    tracing::info!("lsp_query message received");
2307    if should_write {
2308        forward_mutating_project_request(request, response, session).await
2309    } else {
2310        forward_read_only_project_request(request, response, session).await
2311    }
2312}
2313
2314/// Notify other participants that a new buffer has been created
2315async fn create_buffer_for_peer(
2316    request: proto::CreateBufferForPeer,
2317    session: MessageContext,
2318) -> Result<()> {
2319    session
2320        .db()
2321        .await
2322        .check_user_is_project_host(
2323            ProjectId::from_proto(request.project_id),
2324            session.connection_id,
2325        )
2326        .await?;
2327    let peer_id = request.peer_id.context("invalid peer id")?;
2328    session
2329        .peer
2330        .forward_send(session.connection_id, peer_id.into(), request)?;
2331    Ok(())
2332}
2333
2334/// Notify other participants that a new image has been created
2335async fn create_image_for_peer(
2336    request: proto::CreateImageForPeer,
2337    session: MessageContext,
2338) -> Result<()> {
2339    session
2340        .db()
2341        .await
2342        .check_user_is_project_host(
2343            ProjectId::from_proto(request.project_id),
2344            session.connection_id,
2345        )
2346        .await?;
2347    let peer_id = request.peer_id.context("invalid peer id")?;
2348    session
2349        .peer
2350        .forward_send(session.connection_id, peer_id.into(), request)?;
2351    Ok(())
2352}
2353
2354/// Notify other participants that a buffer has been updated. This is
2355/// allowed for guests as long as the update is limited to selections.
2356async fn update_buffer(
2357    request: proto::UpdateBuffer,
2358    response: Response<proto::UpdateBuffer>,
2359    session: MessageContext,
2360) -> Result<()> {
2361    let project_id = ProjectId::from_proto(request.project_id);
2362    let mut capability = Capability::ReadOnly;
2363
2364    for op in request.operations.iter() {
2365        match op.variant {
2366            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2367            Some(_) => capability = Capability::ReadWrite,
2368        }
2369    }
2370
2371    let host = {
2372        let guard = session
2373            .db()
2374            .await
2375            .connections_for_buffer_update(project_id, session.connection_id, capability)
2376            .await?;
2377
2378        let (host, guests) = &*guard;
2379
2380        broadcast(
2381            Some(session.connection_id),
2382            guests.clone(),
2383            |connection_id| {
2384                session
2385                    .peer
2386                    .forward_send(session.connection_id, connection_id, request.clone())
2387            },
2388        );
2389
2390        *host
2391    };
2392
2393    if host != session.connection_id {
2394        session.forward_request(host, request.clone()).await?;
2395    }
2396
2397    response.send(proto::Ack {})?;
2398    Ok(())
2399}
2400
2401async fn forward_project_search_chunk(
2402    message: proto::FindSearchCandidatesChunk,
2403    response: Response<proto::FindSearchCandidatesChunk>,
2404    session: MessageContext,
2405) -> Result<()> {
2406    let peer_id = message.peer_id.context("missing peer_id")?;
2407    let payload = session
2408        .peer
2409        .forward_request(session.connection_id, peer_id.into(), message)
2410        .await?;
2411    response.send(payload)?;
2412    Ok(())
2413}
2414
2415/// Notify other participants that a project has been updated.
2416async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2417    request: T,
2418    session: MessageContext,
2419) -> Result<()> {
2420    let project_id = ProjectId::from_proto(request.remote_entity_id());
2421    let project_connection_ids = session
2422        .db()
2423        .await
2424        .project_connection_ids(project_id, session.connection_id, false)
2425        .await?;
2426
2427    broadcast(
2428        Some(session.connection_id),
2429        project_connection_ids.iter().copied(),
2430        |connection_id| {
2431            session
2432                .peer
2433                .forward_send(session.connection_id, connection_id, request.clone())
2434        },
2435    );
2436    Ok(())
2437}
2438
2439/// Start following another user in a call.
2440async fn follow(
2441    request: proto::Follow,
2442    response: Response<proto::Follow>,
2443    session: MessageContext,
2444) -> Result<()> {
2445    let room_id = RoomId::from_proto(request.room_id);
2446    let project_id = request.project_id.map(ProjectId::from_proto);
2447    let leader_id = request.leader_id.context("invalid leader id")?.into();
2448    let follower_id = session.connection_id;
2449
2450    session
2451        .db()
2452        .await
2453        .check_room_participants(room_id, leader_id, session.connection_id)
2454        .await?;
2455
2456    let response_payload = session.forward_request(leader_id, request).await?;
2457    response.send(response_payload)?;
2458
2459    if let Some(project_id) = project_id {
2460        let room = session
2461            .db()
2462            .await
2463            .follow(room_id, project_id, leader_id, follower_id)
2464            .await?;
2465        room_updated(&room, &session.peer);
2466    }
2467
2468    Ok(())
2469}
2470
2471/// Stop following another user in a call.
2472async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2473    let room_id = RoomId::from_proto(request.room_id);
2474    let project_id = request.project_id.map(ProjectId::from_proto);
2475    let leader_id = request.leader_id.context("invalid leader id")?.into();
2476    let follower_id = session.connection_id;
2477
2478    session
2479        .db()
2480        .await
2481        .check_room_participants(room_id, leader_id, session.connection_id)
2482        .await?;
2483
2484    session
2485        .peer
2486        .forward_send(session.connection_id, leader_id, request)?;
2487
2488    if let Some(project_id) = project_id {
2489        let room = session
2490            .db()
2491            .await
2492            .unfollow(room_id, project_id, leader_id, follower_id)
2493            .await?;
2494        room_updated(&room, &session.peer);
2495    }
2496
2497    Ok(())
2498}
2499
2500/// Notify everyone following you of your current location.
2501async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2502    let room_id = RoomId::from_proto(request.room_id);
2503    let database = session.db.lock().await;
2504
2505    let connection_ids = if let Some(project_id) = request.project_id {
2506        let project_id = ProjectId::from_proto(project_id);
2507        database
2508            .project_connection_ids(project_id, session.connection_id, true)
2509            .await?
2510    } else {
2511        database
2512            .room_connection_ids(room_id, session.connection_id)
2513            .await?
2514    };
2515
2516    // For now, don't send view update messages back to that view's current leader.
2517    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2518        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2519        _ => None,
2520    });
2521
2522    for connection_id in connection_ids.iter().cloned() {
2523        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2524            session
2525                .peer
2526                .forward_send(session.connection_id, connection_id, request.clone())?;
2527        }
2528    }
2529    Ok(())
2530}
2531
2532/// Get public data about users.
2533async fn get_users(
2534    request: proto::GetUsers,
2535    response: Response<proto::GetUsers>,
2536    session: MessageContext,
2537) -> Result<()> {
2538    let user_ids = request
2539        .user_ids
2540        .into_iter()
2541        .map(UserId::from_proto)
2542        .collect();
2543    let users = session
2544        .db()
2545        .await
2546        .get_users_by_ids(user_ids)
2547        .await?
2548        .into_iter()
2549        .map(|user| proto::User {
2550            id: user.id.to_proto(),
2551            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2552            github_login: user.github_login,
2553            name: user.name,
2554        })
2555        .collect();
2556    response.send(proto::UsersResponse { users })?;
2557    Ok(())
2558}
2559
2560/// Search for users (to invite) buy Github login
2561async fn fuzzy_search_users(
2562    request: proto::FuzzySearchUsers,
2563    response: Response<proto::FuzzySearchUsers>,
2564    session: MessageContext,
2565) -> Result<()> {
2566    let query = request.query;
2567    let users = match query.len() {
2568        0 => vec![],
2569        1 | 2 => session
2570            .db()
2571            .await
2572            .get_user_by_github_login(&query)
2573            .await?
2574            .into_iter()
2575            .collect(),
2576        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2577    };
2578    let users = users
2579        .into_iter()
2580        .filter(|user| user.id != session.user_id())
2581        .map(|user| proto::User {
2582            id: user.id.to_proto(),
2583            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2584            github_login: user.github_login,
2585            name: user.name,
2586        })
2587        .collect();
2588    response.send(proto::UsersResponse { users })?;
2589    Ok(())
2590}
2591
2592/// Send a contact request to another user.
2593async fn request_contact(
2594    request: proto::RequestContact,
2595    response: Response<proto::RequestContact>,
2596    session: MessageContext,
2597) -> Result<()> {
2598    let requester_id = session.user_id();
2599    let responder_id = UserId::from_proto(request.responder_id);
2600    if requester_id == responder_id {
2601        return Err(anyhow!("cannot add yourself as a contact"))?;
2602    }
2603
2604    let notifications = session
2605        .db()
2606        .await
2607        .send_contact_request(requester_id, responder_id)
2608        .await?;
2609
2610    // Update outgoing contact requests of requester
2611    let mut update = proto::UpdateContacts::default();
2612    update.outgoing_requests.push(responder_id.to_proto());
2613    for connection_id in session
2614        .connection_pool()
2615        .await
2616        .user_connection_ids(requester_id)
2617    {
2618        session.peer.send(connection_id, update.clone())?;
2619    }
2620
2621    // Update incoming contact requests of responder
2622    let mut update = proto::UpdateContacts::default();
2623    update
2624        .incoming_requests
2625        .push(proto::IncomingContactRequest {
2626            requester_id: requester_id.to_proto(),
2627        });
2628    let connection_pool = session.connection_pool().await;
2629    for connection_id in connection_pool.user_connection_ids(responder_id) {
2630        session.peer.send(connection_id, update.clone())?;
2631    }
2632
2633    send_notifications(&connection_pool, &session.peer, notifications);
2634
2635    response.send(proto::Ack {})?;
2636    Ok(())
2637}
2638
2639/// Accept or decline a contact request
2640async fn respond_to_contact_request(
2641    request: proto::RespondToContactRequest,
2642    response: Response<proto::RespondToContactRequest>,
2643    session: MessageContext,
2644) -> Result<()> {
2645    let responder_id = session.user_id();
2646    let requester_id = UserId::from_proto(request.requester_id);
2647    let db = session.db().await;
2648    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2649        db.dismiss_contact_notification(responder_id, requester_id)
2650            .await?;
2651    } else {
2652        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2653
2654        let notifications = db
2655            .respond_to_contact_request(responder_id, requester_id, accept)
2656            .await?;
2657        let requester_busy = db.is_user_busy(requester_id).await?;
2658        let responder_busy = db.is_user_busy(responder_id).await?;
2659
2660        let pool = session.connection_pool().await;
2661        // Update responder with new contact
2662        let mut update = proto::UpdateContacts::default();
2663        if accept {
2664            update
2665                .contacts
2666                .push(contact_for_user(requester_id, requester_busy, &pool));
2667        }
2668        update
2669            .remove_incoming_requests
2670            .push(requester_id.to_proto());
2671        for connection_id in pool.user_connection_ids(responder_id) {
2672            session.peer.send(connection_id, update.clone())?;
2673        }
2674
2675        // Update requester with new contact
2676        let mut update = proto::UpdateContacts::default();
2677        if accept {
2678            update
2679                .contacts
2680                .push(contact_for_user(responder_id, responder_busy, &pool));
2681        }
2682        update
2683            .remove_outgoing_requests
2684            .push(responder_id.to_proto());
2685
2686        for connection_id in pool.user_connection_ids(requester_id) {
2687            session.peer.send(connection_id, update.clone())?;
2688        }
2689
2690        send_notifications(&pool, &session.peer, notifications);
2691    }
2692
2693    response.send(proto::Ack {})?;
2694    Ok(())
2695}
2696
2697/// Remove a contact.
2698async fn remove_contact(
2699    request: proto::RemoveContact,
2700    response: Response<proto::RemoveContact>,
2701    session: MessageContext,
2702) -> Result<()> {
2703    let requester_id = session.user_id();
2704    let responder_id = UserId::from_proto(request.user_id);
2705    let db = session.db().await;
2706    let (contact_accepted, deleted_notification_id) =
2707        db.remove_contact(requester_id, responder_id).await?;
2708
2709    let pool = session.connection_pool().await;
2710    // Update outgoing contact requests of requester
2711    let mut update = proto::UpdateContacts::default();
2712    if contact_accepted {
2713        update.remove_contacts.push(responder_id.to_proto());
2714    } else {
2715        update
2716            .remove_outgoing_requests
2717            .push(responder_id.to_proto());
2718    }
2719    for connection_id in pool.user_connection_ids(requester_id) {
2720        session.peer.send(connection_id, update.clone())?;
2721    }
2722
2723    // Update incoming contact requests of responder
2724    let mut update = proto::UpdateContacts::default();
2725    if contact_accepted {
2726        update.remove_contacts.push(requester_id.to_proto());
2727    } else {
2728        update
2729            .remove_incoming_requests
2730            .push(requester_id.to_proto());
2731    }
2732    for connection_id in pool.user_connection_ids(responder_id) {
2733        session.peer.send(connection_id, update.clone())?;
2734        if let Some(notification_id) = deleted_notification_id {
2735            session.peer.send(
2736                connection_id,
2737                proto::DeleteNotification {
2738                    notification_id: notification_id.to_proto(),
2739                },
2740            )?;
2741        }
2742    }
2743
2744    response.send(proto::Ack {})?;
2745    Ok(())
2746}
2747
2748fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2749    version.0.minor < 139
2750}
2751
2752async fn subscribe_to_channels(
2753    _: proto::SubscribeToChannels,
2754    session: MessageContext,
2755) -> Result<()> {
2756    subscribe_user_to_channels(session.user_id(), &session).await?;
2757    Ok(())
2758}
2759
2760async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2761    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2762    let mut pool = session.connection_pool().await;
2763    for membership in &channels_for_user.channel_memberships {
2764        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2765    }
2766    session.peer.send(
2767        session.connection_id,
2768        build_update_user_channels(&channels_for_user),
2769    )?;
2770    session.peer.send(
2771        session.connection_id,
2772        build_channels_update(channels_for_user),
2773    )?;
2774    Ok(())
2775}
2776
2777/// Creates a new channel.
2778async fn create_channel(
2779    request: proto::CreateChannel,
2780    response: Response<proto::CreateChannel>,
2781    session: MessageContext,
2782) -> Result<()> {
2783    let db = session.db().await;
2784
2785    let parent_id = request.parent_id.map(ChannelId::from_proto);
2786    let (channel, membership) = db
2787        .create_channel(&request.name, parent_id, session.user_id())
2788        .await?;
2789
2790    let root_id = channel.root_id();
2791    let channel = Channel::from_model(channel);
2792
2793    response.send(proto::CreateChannelResponse {
2794        channel: Some(channel.to_proto()),
2795        parent_id: request.parent_id,
2796    })?;
2797
2798    let mut connection_pool = session.connection_pool().await;
2799    if let Some(membership) = membership {
2800        connection_pool.subscribe_to_channel(
2801            membership.user_id,
2802            membership.channel_id,
2803            membership.role,
2804        );
2805        let update = proto::UpdateUserChannels {
2806            channel_memberships: vec![proto::ChannelMembership {
2807                channel_id: membership.channel_id.to_proto(),
2808                role: membership.role.into(),
2809            }],
2810            ..Default::default()
2811        };
2812        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2813            session.peer.send(connection_id, update.clone())?;
2814        }
2815    }
2816
2817    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2818        if !role.can_see_channel(channel.visibility) {
2819            continue;
2820        }
2821
2822        let update = proto::UpdateChannels {
2823            channels: vec![channel.to_proto()],
2824            ..Default::default()
2825        };
2826        session.peer.send(connection_id, update.clone())?;
2827    }
2828
2829    Ok(())
2830}
2831
2832/// Delete a channel
2833async fn delete_channel(
2834    request: proto::DeleteChannel,
2835    response: Response<proto::DeleteChannel>,
2836    session: MessageContext,
2837) -> Result<()> {
2838    let db = session.db().await;
2839
2840    let channel_id = request.channel_id;
2841    let (root_channel, removed_channels) = db
2842        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2843        .await?;
2844    response.send(proto::Ack {})?;
2845
2846    // Notify members of removed channels
2847    let mut update = proto::UpdateChannels::default();
2848    update
2849        .delete_channels
2850        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2851
2852    let connection_pool = session.connection_pool().await;
2853    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2854        session.peer.send(connection_id, update.clone())?;
2855    }
2856
2857    Ok(())
2858}
2859
2860/// Invite someone to join a channel.
2861async fn invite_channel_member(
2862    request: proto::InviteChannelMember,
2863    response: Response<proto::InviteChannelMember>,
2864    session: MessageContext,
2865) -> Result<()> {
2866    let db = session.db().await;
2867    let channel_id = ChannelId::from_proto(request.channel_id);
2868    let invitee_id = UserId::from_proto(request.user_id);
2869    let InviteMemberResult {
2870        channel,
2871        notifications,
2872    } = db
2873        .invite_channel_member(
2874            channel_id,
2875            invitee_id,
2876            session.user_id(),
2877            request.role().into(),
2878        )
2879        .await?;
2880
2881    let update = proto::UpdateChannels {
2882        channel_invitations: vec![channel.to_proto()],
2883        ..Default::default()
2884    };
2885
2886    let connection_pool = session.connection_pool().await;
2887    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2888        session.peer.send(connection_id, update.clone())?;
2889    }
2890
2891    send_notifications(&connection_pool, &session.peer, notifications);
2892
2893    response.send(proto::Ack {})?;
2894    Ok(())
2895}
2896
2897/// remove someone from a channel
2898async fn remove_channel_member(
2899    request: proto::RemoveChannelMember,
2900    response: Response<proto::RemoveChannelMember>,
2901    session: MessageContext,
2902) -> Result<()> {
2903    let db = session.db().await;
2904    let channel_id = ChannelId::from_proto(request.channel_id);
2905    let member_id = UserId::from_proto(request.user_id);
2906
2907    let RemoveChannelMemberResult {
2908        membership_update,
2909        notification_id,
2910    } = db
2911        .remove_channel_member(channel_id, member_id, session.user_id())
2912        .await?;
2913
2914    let mut connection_pool = session.connection_pool().await;
2915    notify_membership_updated(
2916        &mut connection_pool,
2917        membership_update,
2918        member_id,
2919        &session.peer,
2920    );
2921    for connection_id in connection_pool.user_connection_ids(member_id) {
2922        if let Some(notification_id) = notification_id {
2923            session
2924                .peer
2925                .send(
2926                    connection_id,
2927                    proto::DeleteNotification {
2928                        notification_id: notification_id.to_proto(),
2929                    },
2930                )
2931                .trace_err();
2932        }
2933    }
2934
2935    response.send(proto::Ack {})?;
2936    Ok(())
2937}
2938
2939/// Toggle the channel between public and private.
2940/// Care is taken to maintain the invariant that public channels only descend from public channels,
2941/// (though members-only channels can appear at any point in the hierarchy).
2942async fn set_channel_visibility(
2943    request: proto::SetChannelVisibility,
2944    response: Response<proto::SetChannelVisibility>,
2945    session: MessageContext,
2946) -> Result<()> {
2947    let db = session.db().await;
2948    let channel_id = ChannelId::from_proto(request.channel_id);
2949    let visibility = request.visibility().into();
2950
2951    let channel_model = db
2952        .set_channel_visibility(channel_id, visibility, session.user_id())
2953        .await?;
2954    let root_id = channel_model.root_id();
2955    let channel = Channel::from_model(channel_model);
2956
2957    let mut connection_pool = session.connection_pool().await;
2958    for (user_id, role) in connection_pool
2959        .channel_user_ids(root_id)
2960        .collect::<Vec<_>>()
2961        .into_iter()
2962    {
2963        let update = if role.can_see_channel(channel.visibility) {
2964            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2965            proto::UpdateChannels {
2966                channels: vec![channel.to_proto()],
2967                ..Default::default()
2968            }
2969        } else {
2970            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2971            proto::UpdateChannels {
2972                delete_channels: vec![channel.id.to_proto()],
2973                ..Default::default()
2974            }
2975        };
2976
2977        for connection_id in connection_pool.user_connection_ids(user_id) {
2978            session.peer.send(connection_id, update.clone())?;
2979        }
2980    }
2981
2982    response.send(proto::Ack {})?;
2983    Ok(())
2984}
2985
2986/// Alter the role for a user in the channel.
2987async fn set_channel_member_role(
2988    request: proto::SetChannelMemberRole,
2989    response: Response<proto::SetChannelMemberRole>,
2990    session: MessageContext,
2991) -> Result<()> {
2992    let db = session.db().await;
2993    let channel_id = ChannelId::from_proto(request.channel_id);
2994    let member_id = UserId::from_proto(request.user_id);
2995    let result = db
2996        .set_channel_member_role(
2997            channel_id,
2998            session.user_id(),
2999            member_id,
3000            request.role().into(),
3001        )
3002        .await?;
3003
3004    match result {
3005        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3006            let mut connection_pool = session.connection_pool().await;
3007            notify_membership_updated(
3008                &mut connection_pool,
3009                membership_update,
3010                member_id,
3011                &session.peer,
3012            )
3013        }
3014        db::SetMemberRoleResult::InviteUpdated(channel) => {
3015            let update = proto::UpdateChannels {
3016                channel_invitations: vec![channel.to_proto()],
3017                ..Default::default()
3018            };
3019
3020            for connection_id in session
3021                .connection_pool()
3022                .await
3023                .user_connection_ids(member_id)
3024            {
3025                session.peer.send(connection_id, update.clone())?;
3026            }
3027        }
3028    }
3029
3030    response.send(proto::Ack {})?;
3031    Ok(())
3032}
3033
3034/// Change the name of a channel
3035async fn rename_channel(
3036    request: proto::RenameChannel,
3037    response: Response<proto::RenameChannel>,
3038    session: MessageContext,
3039) -> Result<()> {
3040    let db = session.db().await;
3041    let channel_id = ChannelId::from_proto(request.channel_id);
3042    let channel_model = db
3043        .rename_channel(channel_id, session.user_id(), &request.name)
3044        .await?;
3045    let root_id = channel_model.root_id();
3046    let channel = Channel::from_model(channel_model);
3047
3048    response.send(proto::RenameChannelResponse {
3049        channel: Some(channel.to_proto()),
3050    })?;
3051
3052    let connection_pool = session.connection_pool().await;
3053    let update = proto::UpdateChannels {
3054        channels: vec![channel.to_proto()],
3055        ..Default::default()
3056    };
3057    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3058        if role.can_see_channel(channel.visibility) {
3059            session.peer.send(connection_id, update.clone())?;
3060        }
3061    }
3062
3063    Ok(())
3064}
3065
3066/// Move a channel to a new parent.
3067async fn move_channel(
3068    request: proto::MoveChannel,
3069    response: Response<proto::MoveChannel>,
3070    session: MessageContext,
3071) -> Result<()> {
3072    let channel_id = ChannelId::from_proto(request.channel_id);
3073    let to = ChannelId::from_proto(request.to);
3074
3075    let (root_id, channels) = session
3076        .db()
3077        .await
3078        .move_channel(channel_id, to, session.user_id())
3079        .await?;
3080
3081    let connection_pool = session.connection_pool().await;
3082    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3083        let channels = channels
3084            .iter()
3085            .filter_map(|channel| {
3086                if role.can_see_channel(channel.visibility) {
3087                    Some(channel.to_proto())
3088                } else {
3089                    None
3090                }
3091            })
3092            .collect::<Vec<_>>();
3093        if channels.is_empty() {
3094            continue;
3095        }
3096
3097        let update = proto::UpdateChannels {
3098            channels,
3099            ..Default::default()
3100        };
3101
3102        session.peer.send(connection_id, update.clone())?;
3103    }
3104
3105    response.send(Ack {})?;
3106    Ok(())
3107}
3108
3109async fn reorder_channel(
3110    request: proto::ReorderChannel,
3111    response: Response<proto::ReorderChannel>,
3112    session: MessageContext,
3113) -> Result<()> {
3114    let channel_id = ChannelId::from_proto(request.channel_id);
3115    let direction = request.direction();
3116
3117    let updated_channels = session
3118        .db()
3119        .await
3120        .reorder_channel(channel_id, direction, session.user_id())
3121        .await?;
3122
3123    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3124        let connection_pool = session.connection_pool().await;
3125        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3126            let channels = updated_channels
3127                .iter()
3128                .filter_map(|channel| {
3129                    if role.can_see_channel(channel.visibility) {
3130                        Some(channel.to_proto())
3131                    } else {
3132                        None
3133                    }
3134                })
3135                .collect::<Vec<_>>();
3136
3137            if channels.is_empty() {
3138                continue;
3139            }
3140
3141            let update = proto::UpdateChannels {
3142                channels,
3143                ..Default::default()
3144            };
3145
3146            session.peer.send(connection_id, update.clone())?;
3147        }
3148    }
3149
3150    response.send(Ack {})?;
3151    Ok(())
3152}
3153
3154/// Get the list of channel members
3155async fn get_channel_members(
3156    request: proto::GetChannelMembers,
3157    response: Response<proto::GetChannelMembers>,
3158    session: MessageContext,
3159) -> Result<()> {
3160    let db = session.db().await;
3161    let channel_id = ChannelId::from_proto(request.channel_id);
3162    let limit = if request.limit == 0 {
3163        u16::MAX as u64
3164    } else {
3165        request.limit
3166    };
3167    let (members, users) = db
3168        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3169        .await?;
3170    response.send(proto::GetChannelMembersResponse { members, users })?;
3171    Ok(())
3172}
3173
3174/// Accept or decline a channel invitation.
3175async fn respond_to_channel_invite(
3176    request: proto::RespondToChannelInvite,
3177    response: Response<proto::RespondToChannelInvite>,
3178    session: MessageContext,
3179) -> Result<()> {
3180    let db = session.db().await;
3181    let channel_id = ChannelId::from_proto(request.channel_id);
3182    let RespondToChannelInvite {
3183        membership_update,
3184        notifications,
3185    } = db
3186        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3187        .await?;
3188
3189    let mut connection_pool = session.connection_pool().await;
3190    if let Some(membership_update) = membership_update {
3191        notify_membership_updated(
3192            &mut connection_pool,
3193            membership_update,
3194            session.user_id(),
3195            &session.peer,
3196        );
3197    } else {
3198        let update = proto::UpdateChannels {
3199            remove_channel_invitations: vec![channel_id.to_proto()],
3200            ..Default::default()
3201        };
3202
3203        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3204            session.peer.send(connection_id, update.clone())?;
3205        }
3206    };
3207
3208    send_notifications(&connection_pool, &session.peer, notifications);
3209
3210    response.send(proto::Ack {})?;
3211
3212    Ok(())
3213}
3214
3215/// Join the channels' room
3216async fn join_channel(
3217    request: proto::JoinChannel,
3218    response: Response<proto::JoinChannel>,
3219    session: MessageContext,
3220) -> Result<()> {
3221    let channel_id = ChannelId::from_proto(request.channel_id);
3222    join_channel_internal(channel_id, Box::new(response), session).await
3223}
3224
3225trait JoinChannelInternalResponse {
3226    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3227}
3228impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3229    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3230        Response::<proto::JoinChannel>::send(self, result)
3231    }
3232}
3233impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3234    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3235        Response::<proto::JoinRoom>::send(self, result)
3236    }
3237}
3238
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel's room in the database.
/// 3. Build LiveKit credentials when a LiveKit client is configured.
/// 4. Reply with a `proto::JoinRoomResponse`.
/// 5. Propagate any membership change and broadcast the updated room.
/// 6. Broadcast the channel update, then refresh the user's contacts.
///
/// * `channel_id` – the channel whose room is being joined.
/// * `response` – the reply handle for whichever request triggered the join.
/// * `session` – the joining user's session.
///
/// Errors from the database, from leaving a stale room, from sending the
/// response, from a missing channel in the join result, or from
/// `update_user_contacts` are propagated.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Scoped block: the db guard and the connection-pool guard acquired below
    // are released before `channel_updated` takes the pool lock again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard before `leave_room_for_session` runs,
            // then re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token with `can_publish = false`; every other
        // role gets a room token with `can_publish = true`. A token-generation
        // failure goes through `trace_err()?`, which yields `None`, so the join
        // still succeeds without LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Runs after the block above has released its guards. Errors if the join
    // result carried no channel.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3332
3333/// Start editing the channel notes
3334async fn join_channel_buffer(
3335    request: proto::JoinChannelBuffer,
3336    response: Response<proto::JoinChannelBuffer>,
3337    session: MessageContext,
3338) -> Result<()> {
3339    let db = session.db().await;
3340    let channel_id = ChannelId::from_proto(request.channel_id);
3341
3342    let open_response = db
3343        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3344        .await?;
3345
3346    let collaborators = open_response.collaborators.clone();
3347    response.send(open_response)?;
3348
3349    let update = UpdateChannelBufferCollaborators {
3350        channel_id: channel_id.to_proto(),
3351        collaborators: collaborators.clone(),
3352    };
3353    channel_buffer_updated(
3354        session.connection_id,
3355        collaborators
3356            .iter()
3357            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3358        &update,
3359        &session.peer,
3360    );
3361
3362    Ok(())
3363}
3364
3365/// Edit the channel notes
3366async fn update_channel_buffer(
3367    request: proto::UpdateChannelBuffer,
3368    session: MessageContext,
3369) -> Result<()> {
3370    let db = session.db().await;
3371    let channel_id = ChannelId::from_proto(request.channel_id);
3372
3373    let (collaborators, epoch, version) = db
3374        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3375        .await?;
3376
3377    channel_buffer_updated(
3378        session.connection_id,
3379        collaborators.clone(),
3380        &proto::UpdateChannelBuffer {
3381            channel_id: channel_id.to_proto(),
3382            operations: request.operations,
3383        },
3384        &session.peer,
3385    );
3386
3387    let pool = &*session.connection_pool().await;
3388
3389    let non_collaborators =
3390        pool.channel_connection_ids(channel_id)
3391            .filter_map(|(connection_id, _)| {
3392                if collaborators.contains(&connection_id) {
3393                    None
3394                } else {
3395                    Some(connection_id)
3396                }
3397            });
3398
3399    broadcast(None, non_collaborators, |peer_id| {
3400        session.peer.send(
3401            peer_id,
3402            proto::UpdateChannels {
3403                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3404                    channel_id: channel_id.to_proto(),
3405                    epoch: epoch as u64,
3406                    version: version.clone(),
3407                }],
3408                ..Default::default()
3409            },
3410        )
3411    });
3412
3413    Ok(())
3414}
3415
3416/// Rejoin the channel notes after a connection blip
3417async fn rejoin_channel_buffers(
3418    request: proto::RejoinChannelBuffers,
3419    response: Response<proto::RejoinChannelBuffers>,
3420    session: MessageContext,
3421) -> Result<()> {
3422    let db = session.db().await;
3423    let buffers = db
3424        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3425        .await?;
3426
3427    for rejoined_buffer in &buffers {
3428        let collaborators_to_notify = rejoined_buffer
3429            .buffer
3430            .collaborators
3431            .iter()
3432            .filter_map(|c| Some(c.peer_id?.into()));
3433        channel_buffer_updated(
3434            session.connection_id,
3435            collaborators_to_notify,
3436            &proto::UpdateChannelBufferCollaborators {
3437                channel_id: rejoined_buffer.buffer.channel_id,
3438                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3439            },
3440            &session.peer,
3441        );
3442    }
3443
3444    response.send(proto::RejoinChannelBuffersResponse {
3445        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3446    })?;
3447
3448    Ok(())
3449}
3450
3451/// Stop editing the channel notes
3452async fn leave_channel_buffer(
3453    request: proto::LeaveChannelBuffer,
3454    response: Response<proto::LeaveChannelBuffer>,
3455    session: MessageContext,
3456) -> Result<()> {
3457    let db = session.db().await;
3458    let channel_id = ChannelId::from_proto(request.channel_id);
3459
3460    let left_buffer = db
3461        .leave_channel_buffer(channel_id, session.connection_id)
3462        .await?;
3463
3464    response.send(Ack {})?;
3465
3466    channel_buffer_updated(
3467        session.connection_id,
3468        left_buffer.connections,
3469        &proto::UpdateChannelBufferCollaborators {
3470            channel_id: channel_id.to_proto(),
3471            collaborators: left_buffer.collaborators,
3472        },
3473        &session.peer,
3474    );
3475
3476    Ok(())
3477}
3478
3479fn channel_buffer_updated<T: EnvelopedMessage>(
3480    sender_id: ConnectionId,
3481    collaborators: impl IntoIterator<Item = ConnectionId>,
3482    message: &T,
3483    peer: &Peer,
3484) {
3485    broadcast(Some(sender_id), collaborators, |peer_id| {
3486        peer.send(peer_id, message.clone())
3487    });
3488}
3489
3490fn send_notifications(
3491    connection_pool: &ConnectionPool,
3492    peer: &Peer,
3493    notifications: db::NotificationBatch,
3494) {
3495    for (user_id, notification) in notifications {
3496        for connection_id in connection_pool.user_connection_ids(user_id) {
3497            if let Err(error) = peer.send(
3498                connection_id,
3499                proto::AddNotification {
3500                    notification: Some(notification.clone()),
3501                },
3502            ) {
3503                tracing::error!(
3504                    "failed to send notification to {:?} {}",
3505                    connection_id,
3506                    error
3507                );
3508            }
3509        }
3510    }
3511}
3512
3513/// Send a message to the channel
3514async fn send_channel_message(
3515    _request: proto::SendChannelMessage,
3516    _response: Response<proto::SendChannelMessage>,
3517    _session: MessageContext,
3518) -> Result<()> {
3519    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3520}
3521
3522/// Delete a channel message
3523async fn remove_channel_message(
3524    _request: proto::RemoveChannelMessage,
3525    _response: Response<proto::RemoveChannelMessage>,
3526    _session: MessageContext,
3527) -> Result<()> {
3528    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3529}
3530
3531async fn update_channel_message(
3532    _request: proto::UpdateChannelMessage,
3533    _response: Response<proto::UpdateChannelMessage>,
3534    _session: MessageContext,
3535) -> Result<()> {
3536    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3537}
3538
3539/// Mark a channel message as read
3540async fn acknowledge_channel_message(
3541    _request: proto::AckChannelMessage,
3542    _session: MessageContext,
3543) -> Result<()> {
3544    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3545}
3546
3547/// Mark a buffer version as synced
3548async fn acknowledge_buffer_version(
3549    request: proto::AckBufferOperation,
3550    session: MessageContext,
3551) -> Result<()> {
3552    let buffer_id = BufferId::from_proto(request.buffer_id);
3553    session
3554        .db()
3555        .await
3556        .observe_buffer_version(
3557            buffer_id,
3558            session.user_id(),
3559            request.epoch as i32,
3560            &request.version,
3561        )
3562        .await?;
3563    Ok(())
3564}
3565
3566/// Start receiving chat updates for a channel
3567async fn join_channel_chat(
3568    _request: proto::JoinChannelChat,
3569    _response: Response<proto::JoinChannelChat>,
3570    _session: MessageContext,
3571) -> Result<()> {
3572    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3573}
3574
3575/// Stop receiving chat updates for a channel
3576async fn leave_channel_chat(
3577    _request: proto::LeaveChannelChat,
3578    _session: MessageContext,
3579) -> Result<()> {
3580    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3581}
3582
3583/// Retrieve the chat history for a channel
3584async fn get_channel_messages(
3585    _request: proto::GetChannelMessages,
3586    _response: Response<proto::GetChannelMessages>,
3587    _session: MessageContext,
3588) -> Result<()> {
3589    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3590}
3591
3592/// Retrieve specific chat messages
3593async fn get_channel_messages_by_id(
3594    _request: proto::GetChannelMessagesById,
3595    _response: Response<proto::GetChannelMessagesById>,
3596    _session: MessageContext,
3597) -> Result<()> {
3598    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3599}
3600
3601/// Retrieve the current users notifications
3602async fn get_notifications(
3603    request: proto::GetNotifications,
3604    response: Response<proto::GetNotifications>,
3605    session: MessageContext,
3606) -> Result<()> {
3607    let notifications = session
3608        .db()
3609        .await
3610        .get_notifications(
3611            session.user_id(),
3612            NOTIFICATION_COUNT_PER_PAGE,
3613            request.before_id.map(db::NotificationId::from_proto),
3614        )
3615        .await?;
3616    response.send(proto::GetNotificationsResponse {
3617        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3618        notifications,
3619    })?;
3620    Ok(())
3621}
3622
3623/// Mark notifications as read
3624async fn mark_notification_as_read(
3625    request: proto::MarkNotificationRead,
3626    response: Response<proto::MarkNotificationRead>,
3627    session: MessageContext,
3628) -> Result<()> {
3629    let database = &session.db().await;
3630    let notifications = database
3631        .mark_notification_as_read_by_id(
3632            session.user_id(),
3633            NotificationId::from_proto(request.notification_id),
3634        )
3635        .await?;
3636    send_notifications(
3637        &*session.connection_pool().await,
3638        &session.peer,
3639        notifications,
3640    );
3641    response.send(proto::Ack {})?;
3642    Ok(())
3643}
3644
3645fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3646    let message = match message {
3647        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3648        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3649        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3650        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3651        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3652            code: frame.code.into(),
3653            reason: frame.reason.as_str().to_owned().into(),
3654        })),
3655        // We should never receive a frame while reading the message, according
3656        // to the `tungstenite` maintainers:
3657        //
3658        // > It cannot occur when you read messages from the WebSocket, but it
3659        // > can be used when you want to send the raw frames (e.g. you want to
3660        // > send the frames to the WebSocket without composing the full message first).
3661        // >
3662        // > — https://github.com/snapview/tungstenite-rs/issues/268
3663        TungsteniteMessage::Frame(_) => {
3664            bail!("received an unexpected frame while reading the message")
3665        }
3666    };
3667
3668    Ok(message)
3669}
3670
3671fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3672    match message {
3673        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3674        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3675        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3676        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3677        AxumMessage::Close(frame) => {
3678            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3679                code: frame.code.into(),
3680                reason: frame.reason.as_ref().into(),
3681            }))
3682        }
3683    }
3684}
3685
3686fn notify_membership_updated(
3687    connection_pool: &mut ConnectionPool,
3688    result: MembershipUpdated,
3689    user_id: UserId,
3690    peer: &Peer,
3691) {
3692    for membership in &result.new_channels.channel_memberships {
3693        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3694    }
3695    for channel_id in &result.removed_channels {
3696        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3697    }
3698
3699    let user_channels_update = proto::UpdateUserChannels {
3700        channel_memberships: result
3701            .new_channels
3702            .channel_memberships
3703            .iter()
3704            .map(|cm| proto::ChannelMembership {
3705                channel_id: cm.channel_id.to_proto(),
3706                role: cm.role.into(),
3707            })
3708            .collect(),
3709        ..Default::default()
3710    };
3711
3712    let mut update = build_channels_update(result.new_channels);
3713    update.delete_channels = result
3714        .removed_channels
3715        .into_iter()
3716        .map(|id| id.to_proto())
3717        .collect();
3718    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3719
3720    for connection_id in connection_pool.user_connection_ids(user_id) {
3721        peer.send(connection_id, user_channels_update.clone())
3722            .trace_err();
3723        peer.send(connection_id, update.clone()).trace_err();
3724    }
3725}
3726
3727fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3728    proto::UpdateUserChannels {
3729        channel_memberships: channels
3730            .channel_memberships
3731            .iter()
3732            .map(|m| proto::ChannelMembership {
3733                channel_id: m.channel_id.to_proto(),
3734                role: m.role.into(),
3735            })
3736            .collect(),
3737        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3738    }
3739}
3740
3741fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3742    let mut update = proto::UpdateChannels::default();
3743
3744    for channel in channels.channels {
3745        update.channels.push(channel.to_proto());
3746    }
3747
3748    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3749
3750    for (channel_id, participants) in channels.channel_participants {
3751        update
3752            .channel_participants
3753            .push(proto::ChannelParticipants {
3754                channel_id: channel_id.to_proto(),
3755                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3756            });
3757    }
3758
3759    for channel in channels.invited_channels {
3760        update.channel_invitations.push(channel.to_proto());
3761    }
3762
3763    update
3764}
3765
3766fn build_initial_contacts_update(
3767    contacts: Vec<db::Contact>,
3768    pool: &ConnectionPool,
3769) -> proto::UpdateContacts {
3770    let mut update = proto::UpdateContacts::default();
3771
3772    for contact in contacts {
3773        match contact {
3774            db::Contact::Accepted { user_id, busy } => {
3775                update.contacts.push(contact_for_user(user_id, busy, pool));
3776            }
3777            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3778            db::Contact::Incoming { user_id } => {
3779                update
3780                    .incoming_requests
3781                    .push(proto::IncomingContactRequest {
3782                        requester_id: user_id.to_proto(),
3783                    })
3784            }
3785        }
3786    }
3787
3788    update
3789}
3790
3791fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3792    proto::Contact {
3793        user_id: user_id.to_proto(),
3794        online: pool.is_user_online(user_id),
3795        busy,
3796    }
3797}
3798
3799fn room_updated(room: &proto::Room, peer: &Peer) {
3800    broadcast(
3801        None,
3802        room.participants
3803            .iter()
3804            .filter_map(|participant| Some(participant.peer_id?.into())),
3805        |peer_id| {
3806            peer.send(
3807                peer_id,
3808                proto::RoomUpdated {
3809                    room: Some(room.clone()),
3810                },
3811            )
3812        },
3813    );
3814}
3815
3816fn channel_updated(
3817    channel: &db::channel::Model,
3818    room: &proto::Room,
3819    peer: &Peer,
3820    pool: &ConnectionPool,
3821) {
3822    let participants = room
3823        .participants
3824        .iter()
3825        .map(|p| p.user_id)
3826        .collect::<Vec<_>>();
3827
3828    broadcast(
3829        None,
3830        pool.channel_connection_ids(channel.root_id())
3831            .filter_map(|(channel_id, role)| {
3832                role.can_see_channel(channel.visibility)
3833                    .then_some(channel_id)
3834            }),
3835        |peer_id| {
3836            peer.send(
3837                peer_id,
3838                proto::UpdateChannels {
3839                    channel_participants: vec![proto::ChannelParticipants {
3840                        channel_id: channel.id.to_proto(),
3841                        participant_user_ids: participants.clone(),
3842                    }],
3843                    ..Default::default()
3844                },
3845            )
3846        },
3847    );
3848}
3849
3850async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3851    let db = session.db().await;
3852
3853    let contacts = db.get_contacts(user_id).await?;
3854    let busy = db.is_user_busy(user_id).await?;
3855
3856    let pool = session.connection_pool().await;
3857    let updated_contact = contact_for_user(user_id, busy, &pool);
3858    for contact in contacts {
3859        if let db::Contact::Accepted {
3860            user_id: contact_user_id,
3861            ..
3862        } = contact
3863        {
3864            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3865                session
3866                    .peer
3867                    .send(
3868                        contact_conn_id,
3869                        proto::UpdateContacts {
3870                            contacts: vec![updated_contact.clone()],
3871                            remove_contacts: Default::default(),
3872                            incoming_requests: Default::default(),
3873                            remove_incoming_requests: Default::default(),
3874                            outgoing_requests: Default::default(),
3875                            remove_outgoing_requests: Default::default(),
3876                        },
3877                    )
3878                    .trace_err();
3879            }
3880        }
3881    }
3882    Ok(())
3883}
3884
3885async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
3886    let mut contacts_to_update = HashSet::default();
3887
3888    let room_id;
3889    let canceled_calls_to_user_ids;
3890    let livekit_room;
3891    let delete_livekit_room;
3892    let room;
3893    let channel;
3894
3895    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
3896        contacts_to_update.insert(session.user_id());
3897
3898        for project in left_room.left_projects.values() {
3899            project_left(project, session);
3900        }
3901
3902        room_id = RoomId::from_proto(left_room.room.id);
3903        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3904        livekit_room = mem::take(&mut left_room.room.livekit_room);
3905        delete_livekit_room = left_room.deleted;
3906        room = mem::take(&mut left_room.room);
3907        channel = mem::take(&mut left_room.channel);
3908
3909        room_updated(&room, &session.peer);
3910    } else {
3911        return Ok(());
3912    }
3913
3914    if let Some(channel) = channel {
3915        channel_updated(
3916            &channel,
3917            &room,
3918            &session.peer,
3919            &*session.connection_pool().await,
3920        );
3921    }
3922
3923    {
3924        let pool = session.connection_pool().await;
3925        for canceled_user_id in canceled_calls_to_user_ids {
3926            for connection_id in pool.user_connection_ids(canceled_user_id) {
3927                session
3928                    .peer
3929                    .send(
3930                        connection_id,
3931                        proto::CallCanceled {
3932                            room_id: room_id.to_proto(),
3933                        },
3934                    )
3935                    .trace_err();
3936            }
3937            contacts_to_update.insert(canceled_user_id);
3938        }
3939    }
3940
3941    for contact_user_id in contacts_to_update {
3942        update_user_contacts(contact_user_id, session).await?;
3943    }
3944
3945    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
3946        live_kit
3947            .remove_participant(livekit_room.clone(), session.user_id().to_string())
3948            .await
3949            .trace_err();
3950
3951        if delete_livekit_room {
3952            live_kit.delete_room(livekit_room).await.trace_err();
3953        }
3954    }
3955
3956    Ok(())
3957}
3958
3959async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3960    let left_channel_buffers = session
3961        .db()
3962        .await
3963        .leave_channel_buffers(session.connection_id)
3964        .await?;
3965
3966    for left_buffer in left_channel_buffers {
3967        channel_buffer_updated(
3968            session.connection_id,
3969            left_buffer.connections,
3970            &proto::UpdateChannelBufferCollaborators {
3971                channel_id: left_buffer.channel_id.to_proto(),
3972                collaborators: left_buffer.collaborators,
3973            },
3974            &session.peer,
3975        );
3976    }
3977
3978    Ok(())
3979}
3980
3981fn project_left(project: &db::LeftProject, session: &Session) {
3982    for connection_id in &project.connection_ids {
3983        if project.should_unshare {
3984            session
3985                .peer
3986                .send(
3987                    *connection_id,
3988                    proto::UnshareProject {
3989                        project_id: project.id.to_proto(),
3990                    },
3991                )
3992                .trace_err();
3993        } else {
3994            session
3995                .peer
3996                .send(
3997                    *connection_id,
3998                    proto::RemoveProjectCollaborator {
3999                        project_id: project.id.to_proto(),
4000                        peer_id: Some(session.connection_id.into()),
4001                    },
4002                )
4003                .trace_err();
4004        }
4005    }
4006}
4007
4008async fn share_agent_thread(
4009    request: proto::ShareAgentThread,
4010    response: Response<proto::ShareAgentThread>,
4011    session: MessageContext,
4012) -> Result<()> {
4013    let user_id = session.user_id();
4014
4015    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4016        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4017
4018    session
4019        .db()
4020        .await
4021        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4022        .await?;
4023
4024    response.send(proto::Ack {})?;
4025
4026    Ok(())
4027}
4028
4029async fn get_shared_agent_thread(
4030    request: proto::GetSharedAgentThread,
4031    response: Response<proto::GetSharedAgentThread>,
4032    session: MessageContext,
4033) -> Result<()> {
4034    let share_id = SharedThreadId::from_proto(request.session_id)
4035        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4036
4037    let result = session.db().await.get_shared_thread(share_id).await?;
4038
4039    match result {
4040        Some((thread, username)) => {
4041            response.send(proto::GetSharedAgentThreadResponse {
4042                title: thread.title,
4043                thread_data: thread.data,
4044                sharer_username: username,
4045                created_at: thread.created_at.and_utc().to_rfc3339(),
4046            })?;
4047        }
4048        None => {
4049            return Err(anyhow!("Shared thread not found").into());
4050        }
4051    }
4052
4053    Ok(())
4054}
4055
4056pub trait ResultExt {
4057    type Ok;
4058
4059    fn trace_err(self) -> Option<Self::Ok>;
4060}
4061
4062impl<T, E> ResultExt for Result<T, E>
4063where
4064    E: std::fmt::Debug,
4065{
4066    type Ok = T;
4067
4068    #[track_caller]
4069    fn trace_err(self) -> Option<T> {
4070        match self {
4071            Ok(value) => Some(value),
4072            Err(error) => {
4073                tracing::error!("{:?}", error);
4074                None
4075            }
4076        }
4077    }
4078}