rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
  75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  76
  77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  79
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  82
  83static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  84
  85const TOTAL_DURATION_MS: &str = "total_duration_ms";
  86const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  87const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  88const HOST_WAITING_MS: &str = "host_waiting_ms";
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
 254pub struct Server {
 255    id: parking_lot::Mutex<ServerId>,
 256    peer: Arc<Peer>,
 257    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 258    app_state: Arc<AppState>,
 259    handlers: HashMap<TypeId, MessageHandler>,
 260    teardown: watch::Sender<bool>,
 261}
 262
 263struct ConnectionPoolGuard<'a> {
 264    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 265    _not_send: PhantomData<Rc<()>>,
 266}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 416            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 417            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 419            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 437            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
 439            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 440            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 441            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 442            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 443            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 444            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 445            .add_request_handler(share_agent_thread)
 446            .add_request_handler(get_shared_agent_thread)
 447            .add_request_handler(forward_project_search_chunk);
 448
 449        Arc::new(server)
 450    }
 451
 452    pub async fn start(&self) -> Result<()> {
 453        let server_id = *self.id.lock();
 454        let app_state = self.app_state.clone();
 455        let peer = self.peer.clone();
 456        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 457        let pool = self.connection_pool.clone();
 458        let livekit_client = self.app_state.livekit_client.clone();
 459
 460        let span = info_span!("start server");
 461        self.app_state.executor.spawn_detached(
 462            async move {
 463                tracing::info!("waiting for cleanup timeout");
 464                timeout.await;
 465                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 466
 467                app_state
 468                    .db
 469                    .delete_stale_channel_chat_participants(
 470                        &app_state.config.zed_environment,
 471                        server_id,
 472                    )
 473                    .await
 474                    .trace_err();
 475
 476                if let Some((room_ids, channel_ids)) = app_state
 477                    .db
 478                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 479                    .await
 480                    .trace_err()
 481                {
 482                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 483                    tracing::info!(
 484                        stale_channel_buffer_count = channel_ids.len(),
 485                        "retrieved stale channel buffers"
 486                    );
 487
 488                    for channel_id in channel_ids {
 489                        if let Some(refreshed_channel_buffer) = app_state
 490                            .db
 491                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 492                            .await
 493                            .trace_err()
 494                        {
 495                            for connection_id in refreshed_channel_buffer.connection_ids {
 496                                peer.send(
 497                                    connection_id,
 498                                    proto::UpdateChannelBufferCollaborators {
 499                                        channel_id: channel_id.to_proto(),
 500                                        collaborators: refreshed_channel_buffer
 501                                            .collaborators
 502                                            .clone(),
 503                                    },
 504                                )
 505                                .trace_err();
 506                            }
 507                        }
 508                    }
 509
 510                    for room_id in room_ids {
 511                        let mut contacts_to_update = HashSet::default();
 512                        let mut canceled_calls_to_user_ids = Vec::new();
 513                        let mut livekit_room = String::new();
 514                        let mut delete_livekit_room = false;
 515
 516                        if let Some(mut refreshed_room) = app_state
 517                            .db
 518                            .clear_stale_room_participants(room_id, server_id)
 519                            .await
 520                            .trace_err()
 521                        {
 522                            tracing::info!(
 523                                room_id = room_id.0,
 524                                new_participant_count = refreshed_room.room.participants.len(),
 525                                "refreshed room"
 526                            );
 527                            room_updated(&refreshed_room.room, &peer);
 528                            if let Some(channel) = refreshed_room.channel.as_ref() {
 529                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 530                            }
 531                            contacts_to_update
 532                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 533                            contacts_to_update
 534                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 535                            canceled_calls_to_user_ids =
 536                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 537                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 538                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 539                        }
 540
 541                        {
 542                            let pool = pool.lock();
 543                            for canceled_user_id in canceled_calls_to_user_ids {
 544                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 545                                    peer.send(
 546                                        connection_id,
 547                                        proto::CallCanceled {
 548                                            room_id: room_id.to_proto(),
 549                                        },
 550                                    )
 551                                    .trace_err();
 552                                }
 553                            }
 554                        }
 555
 556                        for user_id in contacts_to_update {
 557                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 558                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 559                            if let Some((busy, contacts)) = busy.zip(contacts) {
 560                                let pool = pool.lock();
 561                                let updated_contact = contact_for_user(user_id, busy, &pool);
 562                                for contact in contacts {
 563                                    if let db::Contact::Accepted {
 564                                        user_id: contact_user_id,
 565                                        ..
 566                                    } = contact
 567                                    {
 568                                        for contact_conn_id in
 569                                            pool.user_connection_ids(contact_user_id)
 570                                        {
 571                                            peer.send(
 572                                                contact_conn_id,
 573                                                proto::UpdateContacts {
 574                                                    contacts: vec![updated_contact.clone()],
 575                                                    remove_contacts: Default::default(),
 576                                                    incoming_requests: Default::default(),
 577                                                    remove_incoming_requests: Default::default(),
 578                                                    outgoing_requests: Default::default(),
 579                                                    remove_outgoing_requests: Default::default(),
 580                                                },
 581                                            )
 582                                            .trace_err();
 583                                        }
 584                                    }
 585                                }
 586                            }
 587                        }
 588
 589                        if let Some(live_kit) = livekit_client.as_ref()
 590                            && delete_livekit_room
 591                        {
 592                            live_kit.delete_room(livekit_room).await.trace_err();
 593                        }
 594                    }
 595                }
 596
 597                app_state
 598                    .db
 599                    .delete_stale_channel_chat_participants(
 600                        &app_state.config.zed_environment,
 601                        server_id,
 602                    )
 603                    .await
 604                    .trace_err();
 605
 606                app_state
 607                    .db
 608                    .clear_old_worktree_entries(server_id)
 609                    .await
 610                    .trace_err();
 611
 612                app_state
 613                    .db
 614                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 615                    .await
 616                    .trace_err();
 617            }
 618            .instrument(span),
 619        );
 620        Ok(())
 621    }
 622
 623    pub fn teardown(&self) {
 624        self.peer.teardown();
 625        self.connection_pool.lock().reset();
 626        let _ = self.teardown.send(true);
 627    }
 628
 629    #[cfg(feature = "test-support")]
 630    pub fn reset(&self, id: ServerId) {
 631        self.teardown();
 632        *self.id.lock() = id;
 633        self.peer.reset(id.0 as u32);
 634        let _ = self.teardown.send(false);
 635    }
 636
 637    #[cfg(feature = "test-support")]
 638    pub fn id(&self) -> ServerId {
 639        *self.id.lock()
 640    }
 641
 642    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 643    where
 644        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 645        Fut: 'static + Send + Future<Output = Result<()>>,
 646        M: EnvelopedMessage,
 647    {
 648        let prev_handler = self.handlers.insert(
 649            TypeId::of::<M>(),
 650            Box::new(move |envelope, session, span| {
 651                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 652                let received_at = envelope.received_at;
 653                tracing::info!("message received");
 654                let start_time = Instant::now();
 655                let future = (handler)(
 656                    *envelope,
 657                    MessageContext {
 658                        session,
 659                        span: span.clone(),
 660                    },
 661                );
 662                async move {
 663                    let result = future.await;
 664                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 665                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 666                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 667                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 668                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 669                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 670                    match result {
 671                        Err(error) => {
 672                            tracing::error!(?error, "error handling message")
 673                        }
 674                        Ok(()) => tracing::info!("finished handling message"),
 675                    }
 676                }
 677                .boxed()
 678            }),
 679        );
 680        if prev_handler.is_some() {
 681            panic!("registered a handler for the same message twice");
 682        }
 683        self
 684    }
 685
 686    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 687    where
 688        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 689        Fut: 'static + Send + Future<Output = Result<()>>,
 690        M: EnvelopedMessage,
 691    {
 692        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 693        self
 694    }
 695
 696    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 697    where
 698        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 699        Fut: Send + Future<Output = Result<()>>,
 700        M: RequestMessage,
 701    {
 702        let handler = Arc::new(handler);
 703        self.add_handler(move |envelope, session| {
 704            let receipt = envelope.receipt();
 705            let handler = handler.clone();
 706            async move {
 707                let peer = session.peer.clone();
 708                let responded = Arc::new(AtomicBool::default());
 709                let response = Response {
 710                    peer: peer.clone(),
 711                    responded: responded.clone(),
 712                    receipt,
 713                };
 714                match (handler)(envelope.payload, response, session).await {
 715                    Ok(()) => {
 716                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 717                            Ok(())
 718                        } else {
 719                            Err(anyhow!("handler did not send a response"))?
 720                        }
 721                    }
 722                    Err(error) => {
 723                        let proto_err = match &error {
 724                            Error::Internal(err) => err.to_proto(),
 725                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 726                        };
 727                        peer.respond_with_error(receipt, proto_err)?;
 728                        Err(error)
 729                    }
 730                }
 731            }
 732        })
 733    }
 734
    /// Returns the future that services one client connection from handshake to sign-out.
    ///
    /// The future is instrumented with a "handle connection" span. It returns immediately if
    /// the teardown flag is already set. Otherwise it registers `connection` with the peer,
    /// builds a `Session`, and runs `send_initial_client_update`. It then dispatches incoming
    /// messages to the registered handlers until the I/O future finishes or the incoming
    /// stream ends. At that point it drops any in-flight foreground handlers and runs
    /// `connection_lost`. If teardown is signalled while the loop is running, it returns
    /// without running `connection_lost`.
    ///
    /// `connection_guard` is held until the initial client update has succeeded.
    /// `send_connection_id`, if provided, is passed to `send_initial_client_update`, which
    /// sends the assigned `ConnectionId` through it.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to service the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this early return skips `connection_lost`, so the connection
                // registered with `this.peer` above (and any connection-pool entry that
                // `send_initial_client_update` may already have added) is not removed here —
                // confirm that is intended.
                return;
            }
            // The initial update succeeded; release the connection guard.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently checked out, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired before the next message is read and is released only
                // when that message's handler future completes (or immediately if no handler
                // exists), which caps in-flight handlers at MAX_CONCURRENT_HANDLERS.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order listed: teardown, I/O, in-flight
                // foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Server teardown: exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is being
                            // created; the future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor; foreground
                                // handlers are polled by this loop via the select above.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 905
    /// Performs the handshake that follows opening a new connection.
    ///
    /// Sends `Hello` carrying the client's peer id, then hands the `ConnectionId` to
    /// `send_connection_id` if one was provided. For a user principal it then:
    /// - sends `ShowContacts` and sets `connected_once` if the user had not connected before;
    /// - adds the connection to the connection pool and sends the initial contacts update;
    /// - subscribes the user to channels when `should_auto_subscribe_to_channels` returns
    ///   true for `zed_version`;
    /// - forwards any pending incoming call;
    /// - calls `update_user_contacts`.
    ///
    /// # Errors
    ///
    /// Returns the first send or database error encountered; later steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; delivery is best-effort.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // The pool lock is held while the contacts update is built, so the update is
                // computed from a pool that already includes this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
 961}
 962
 963impl Deref for ConnectionPoolGuard<'_> {
 964    type Target = ConnectionPool;
 965
 966    fn deref(&self) -> &Self::Target {
 967        &self.guard
 968    }
 969}
 970
 971impl DerefMut for ConnectionPoolGuard<'_> {
 972    fn deref_mut(&mut self) -> &mut Self::Target {
 973        &mut self.guard
 974    }
 975}
 976
/// When the `test-support` feature is enabled, releasing the guard calls
/// `check_invariants()` on the pool. In other builds this impl does no work beyond the
/// normal drop of the guard's fields.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled out entirely unless `test-support` is enabled.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
 983
 984fn broadcast<F>(
 985    sender_id: Option<ConnectionId>,
 986    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 987    mut f: F,
 988) where
 989    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 990{
 991    for receiver_id in receiver_ids {
 992        if Some(receiver_id) != sender_id
 993            && let Err(error) = f(receiver_id)
 994        {
 995            tracing::error!("failed to send to {:?} {}", receiver_id, error);
 996        }
 997    }
 998}
 999
1000pub struct ProtocolVersion(u32);
1001
1002impl Header for ProtocolVersion {
1003    fn name() -> &'static HeaderName {
1004        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1005        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1006    }
1007
1008    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1009    where
1010        Self: Sized,
1011        I: Iterator<Item = &'i axum::http::HeaderValue>,
1012    {
1013        let version = values
1014            .next()
1015            .ok_or_else(axum::headers::Error::invalid)?
1016            .to_str()
1017            .map_err(|_| axum::headers::Error::invalid())?
1018            .parse()
1019            .map_err(|_| axum::headers::Error::invalid())?;
1020        Ok(Self(version))
1021    }
1022
1023    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1024        values.extend([self.0.to_string().parse().unwrap()]);
1025    }
1026}
1027
1028pub struct AppVersionHeader(Version);
1029impl Header for AppVersionHeader {
1030    fn name() -> &'static HeaderName {
1031        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1033    }
1034
1035    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036    where
1037        Self: Sized,
1038        I: Iterator<Item = &'i axum::http::HeaderValue>,
1039    {
1040        let version = values
1041            .next()
1042            .ok_or_else(axum::headers::Error::invalid)?
1043            .to_str()
1044            .map_err(|_| axum::headers::Error::invalid())?
1045            .parse()
1046            .map_err(|_| axum::headers::Error::invalid())?;
1047        Ok(Self(version))
1048    }
1049
1050    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051        values.extend([self.0.to_string().parse().unwrap()]);
1052    }
1053}
1054
1055#[derive(Debug)]
1056pub struct ReleaseChannelHeader(String);
1057
1058impl Header for ReleaseChannelHeader {
1059    fn name() -> &'static HeaderName {
1060        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1061        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1062    }
1063
1064    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065    where
1066        Self: Sized,
1067        I: Iterator<Item = &'i axum::http::HeaderValue>,
1068    {
1069        Ok(Self(
1070            values
1071                .next()
1072                .ok_or_else(axum::headers::Error::invalid)?
1073                .to_str()
1074                .map_err(|_| axum::headers::Error::invalid())?
1075                .to_owned(),
1076        ))
1077    }
1078
1079    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080        values.extend([self.0.parse().unwrap()]);
1081    }
1082}
1083
1084pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1085    Router::new()
1086        .route("/rpc", get(handle_websocket_request))
1087        .layer(
1088            ServiceBuilder::new()
1089                .layer(Extension(server.app_state.clone()))
1090                .layer(middleware::from_fn(auth::validate_header)),
1091        )
1092        .route("/metrics", get(handle_metrics))
1093        .layer(Extension(server))
1094}
1095
1096pub async fn handle_websocket_request(
1097    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1098    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1099    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1100    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1101    Extension(server): Extension<Arc<Server>>,
1102    Extension(principal): Extension<Principal>,
1103    user_agent: Option<TypedHeader<UserAgent>>,
1104    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1105    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1106    ws: WebSocketUpgrade,
1107) -> axum::response::Response {
1108    if protocol_version != rpc::PROTOCOL_VERSION {
1109        return (
1110            StatusCode::UPGRADE_REQUIRED,
1111            "client must be upgraded".to_string(),
1112        )
1113            .into_response();
1114    }
1115
1116    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1117        return (
1118            StatusCode::UPGRADE_REQUIRED,
1119            "no version header found".to_string(),
1120        )
1121            .into_response();
1122    };
1123
1124    let release_channel = release_channel_header.map(|header| header.0.0);
1125
1126    if !version.can_collaborate() {
1127        return (
1128            StatusCode::UPGRADE_REQUIRED,
1129            "client must be upgraded".to_string(),
1130        )
1131            .into_response();
1132    }
1133
1134    let socket_address = socket_address.to_string();
1135
1136    // Acquire connection guard before WebSocket upgrade
1137    let connection_guard = match ConnectionGuard::try_acquire() {
1138        Ok(guard) => guard,
1139        Err(()) => {
1140            return (
1141                StatusCode::SERVICE_UNAVAILABLE,
1142                "Too many concurrent connections",
1143            )
1144                .into_response();
1145        }
1146    };
1147
1148    ws.on_upgrade(move |socket| {
1149        let socket = socket
1150            .map_ok(to_tungstenite_message)
1151            .err_into()
1152            .with(|message| async move { to_axum_message(message) });
1153        let connection = Connection::new(Box::pin(socket));
1154        async move {
1155            server
1156                .handle_connection(
1157                    connection,
1158                    socket_address,
1159                    principal,
1160                    version,
1161                    release_channel,
1162                    user_agent.map(|header| header.to_string()),
1163                    country_code_header.map(|header| header.to_string()),
1164                    system_id_header.map(|header| header.to_string()),
1165                    None,
1166                    Executor::Production,
1167                    Some(connection_guard),
1168                )
1169                .await;
1170        }
1171    })
1172}
1173
1174pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1175    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1176    let connections_metric = CONNECTIONS_METRIC
1177        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1178
1179    let connections = server
1180        .connection_pool
1181        .lock()
1182        .connections()
1183        .filter(|connection| !connection.admin)
1184        .count();
1185    connections_metric.set(connections as _);
1186
1187    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1188    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1189        register_int_gauge!(
1190            "shared_projects",
1191            "number of open projects with one or more guests"
1192        )
1193        .unwrap()
1194    });
1195
1196    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1197    shared_projects_metric.set(shared_projects as _);
1198
1199    let encoder = prometheus::TextEncoder::new();
1200    let metric_families = prometheus::gather();
1201    let encoded_metrics = encoder
1202        .encode_to_string(&metric_families)
1203        .map_err(|err| anyhow!("{err}"))?;
1204    Ok(encoded_metrics)
1205}
1206
1207#[instrument(err, skip(executor))]
1208async fn connection_lost(
1209    session: Session,
1210    mut teardown: watch::Receiver<bool>,
1211    executor: Executor,
1212) -> Result<()> {
1213    session.peer.disconnect(session.connection_id);
1214    session
1215        .connection_pool()
1216        .await
1217        .remove_connection(session.connection_id)?;
1218
1219    session
1220        .db()
1221        .await
1222        .connection_lost(session.connection_id)
1223        .await
1224        .trace_err();
1225
1226    futures::select_biased! {
1227        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1228
1229            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1230            leave_room_for_session(&session, session.connection_id).await.trace_err();
1231            leave_channel_buffers_for_session(&session)
1232                .await
1233                .trace_err();
1234
1235            if !session
1236                .connection_pool()
1237                .await
1238                .is_user_online(session.user_id())
1239            {
1240                let db = session.db().await;
1241                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1242                    room_updated(&room, &session.peer);
1243                }
1244            }
1245
1246            update_user_contacts(session.user_id(), &session).await?;
1247        },
1248        _ = teardown.changed().fuse() => {}
1249    }
1250
1251    Ok(())
1252}
1253
1254/// Acknowledges a ping from a client, used to keep the connection alive.
1255async fn ping(
1256    _: proto::Ping,
1257    response: Response<proto::Ping>,
1258    _session: MessageContext,
1259) -> Result<()> {
1260    response.send(proto::Ack {})?;
1261    Ok(())
1262}
1263
1264/// Creates a new room for calling (outside of channels)
1265async fn create_room(
1266    _request: proto::CreateRoom,
1267    response: Response<proto::CreateRoom>,
1268    session: MessageContext,
1269) -> Result<()> {
1270    let livekit_room = nanoid::nanoid!(30);
1271
1272    let live_kit_connection_info = util::maybe!(async {
1273        let live_kit = session.app_state.livekit_client.as_ref();
1274        let live_kit = live_kit?;
1275        let user_id = session.user_id().to_string();
1276
1277        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1278
1279        Some(proto::LiveKitConnectionInfo {
1280            server_url: live_kit.url().into(),
1281            token,
1282            can_publish: true,
1283        })
1284    })
1285    .await;
1286
1287    let room = session
1288        .db()
1289        .await
1290        .create_room(session.user_id(), session.connection_id, &livekit_room)
1291        .await?;
1292
1293    response.send(proto::CreateRoomResponse {
1294        room: Some(room.clone()),
1295        live_kit_connection_info,
1296    })?;
1297
1298    update_user_contacts(session.user_id(), &session).await?;
1299    Ok(())
1300}
1301
1302/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1303async fn join_room(
1304    request: proto::JoinRoom,
1305    response: Response<proto::JoinRoom>,
1306    session: MessageContext,
1307) -> Result<()> {
1308    let room_id = RoomId::from_proto(request.id);
1309
1310    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1311
1312    if let Some(channel_id) = channel_id {
1313        return join_channel_internal(channel_id, Box::new(response), session).await;
1314    }
1315
1316    let joined_room = {
1317        let room = session
1318            .db()
1319            .await
1320            .join_room(room_id, session.user_id(), session.connection_id)
1321            .await?;
1322        room_updated(&room.room, &session.peer);
1323        room.into_inner()
1324    };
1325
1326    for connection_id in session
1327        .connection_pool()
1328        .await
1329        .user_connection_ids(session.user_id())
1330    {
1331        session
1332            .peer
1333            .send(
1334                connection_id,
1335                proto::CallCanceled {
1336                    room_id: room_id.to_proto(),
1337                },
1338            )
1339            .trace_err();
1340    }
1341
1342    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1343    {
1344        live_kit
1345            .room_token(
1346                &joined_room.room.livekit_room,
1347                &session.user_id().to_string(),
1348            )
1349            .trace_err()
1350            .map(|token| proto::LiveKitConnectionInfo {
1351                server_url: live_kit.url().into(),
1352                token,
1353                can_publish: true,
1354            })
1355    } else {
1356        None
1357    };
1358
1359    response.send(proto::JoinRoomResponse {
1360        room: Some(joined_room.room),
1361        channel_id: None,
1362        live_kit_connection_info,
1363    })?;
1364
1365    update_user_contacts(session.user_id(), &session).await?;
1366    Ok(())
1367}
1368
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's room state through the database's `rejoin_room`, replies with the
/// room and the reshared/rejoined project state, and then:
/// - broadcasts the room update;
/// - for each reshared project, tells its collaborators that the caller's peer id changed
///   from the project's `old_connection_id` to the current connection, and forwards the
///   project's worktree list to them (excluding the caller);
/// - re-sends rejoined projects' state via `notify_rejoined_projects`;
/// - sends a channel update if the rejoin result carries a channel;
/// - refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the full rejoin result happens inside this block; only `room`
    // and `channel` escape it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at the caller's new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1460
1461fn notify_rejoined_projects(
1462    rejoined_projects: &mut Vec<RejoinedProject>,
1463    session: &Session,
1464) -> Result<()> {
1465    for project in rejoined_projects.iter() {
1466        for collaborator in &project.collaborators {
1467            session
1468                .peer
1469                .send(
1470                    collaborator.connection_id,
1471                    proto::UpdateProjectCollaborator {
1472                        project_id: project.id.to_proto(),
1473                        old_peer_id: Some(project.old_connection_id.into()),
1474                        new_peer_id: Some(session.connection_id.into()),
1475                    },
1476                )
1477                .trace_err();
1478        }
1479    }
1480
1481    for project in rejoined_projects {
1482        for worktree in mem::take(&mut project.worktrees) {
1483            // Stream this worktree's entries.
1484            let message = proto::UpdateWorktree {
1485                project_id: project.id.to_proto(),
1486                worktree_id: worktree.id,
1487                abs_path: worktree.abs_path.clone(),
1488                root_name: worktree.root_name,
1489                root_repo_common_dir: worktree.root_repo_common_dir,
1490                updated_entries: worktree.updated_entries,
1491                removed_entries: worktree.removed_entries,
1492                scan_id: worktree.scan_id,
1493                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1494                updated_repositories: worktree.updated_repositories,
1495                removed_repositories: worktree.removed_repositories,
1496            };
1497            for update in proto::split_worktree_update(message) {
1498                session.peer.send(session.connection_id, update)?;
1499            }
1500
1501            // Stream this worktree's diagnostics.
1502            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1503            if let Some(summary) = worktree_diagnostics.next() {
1504                let message = proto::UpdateDiagnosticSummary {
1505                    project_id: project.id.to_proto(),
1506                    worktree_id: worktree.id,
1507                    summary: Some(summary),
1508                    more_summaries: worktree_diagnostics.collect(),
1509                };
1510                session.peer.send(session.connection_id, message)?;
1511            }
1512
1513            for settings_file in worktree.settings_files {
1514                session.peer.send(
1515                    session.connection_id,
1516                    proto::UpdateWorktreeSettings {
1517                        project_id: project.id.to_proto(),
1518                        worktree_id: worktree.id,
1519                        path: settings_file.path,
1520                        content: Some(settings_file.content),
1521                        kind: Some(settings_file.kind.to_proto().into()),
1522                        outside_worktree: Some(settings_file.outside_worktree),
1523                    },
1524                )?;
1525            }
1526        }
1527
1528        for repository in mem::take(&mut project.updated_repositories) {
1529            for update in split_repository_update(repository) {
1530                session.peer.send(session.connection_id, update)?;
1531            }
1532        }
1533
1534        for id in mem::take(&mut project.removed_repositories) {
1535            session.peer.send(
1536                session.connection_id,
1537                proto::RemoveRepository {
1538                    project_id: project.id.to_proto(),
1539                    id,
1540                },
1541            )?;
1542        }
1543    }
1544
1545    Ok(())
1546}
1547
1548/// leave room disconnects from the room.
1549async fn leave_room(
1550    _: proto::LeaveRoom,
1551    response: Response<proto::LeaveRoom>,
1552    session: MessageContext,
1553) -> Result<()> {
1554    leave_room_for_session(&session, session.connection_id).await?;
1555    response.send(proto::Ack {})?;
1556    Ok(())
1557}
1558
1559/// Updates the permissions of someone else in the room.
1560async fn set_room_participant_role(
1561    request: proto::SetRoomParticipantRole,
1562    response: Response<proto::SetRoomParticipantRole>,
1563    session: MessageContext,
1564) -> Result<()> {
1565    let user_id = UserId::from_proto(request.user_id);
1566    let role = ChannelRole::from(request.role());
1567
1568    let (livekit_room, can_publish) = {
1569        let room = session
1570            .db()
1571            .await
1572            .set_room_participant_role(
1573                session.user_id(),
1574                RoomId::from_proto(request.room_id),
1575                user_id,
1576                role,
1577            )
1578            .await?;
1579
1580        let livekit_room = room.livekit_room.clone();
1581        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1582        room_updated(&room, &session.peer);
1583        (livekit_room, can_publish)
1584    };
1585
1586    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1587        live_kit
1588            .update_participant(
1589                livekit_room.clone(),
1590                request.user_id.to_string(),
1591                livekit_api::proto::ParticipantPermission {
1592                    can_subscribe: true,
1593                    can_publish,
1594                    can_publish_data: can_publish,
1595                    hidden: false,
1596                    recorder: false,
1597                },
1598            )
1599            .await
1600            .trace_err();
1601    }
1602
1603    response.send(proto::Ack {})?;
1604    Ok(())
1605}
1606
/// Call someone else into the current room.
///
/// The callee must be a contact of the caller. The call is recorded in the
/// database, `room_updated` is invoked with the resulting room, and the
/// callee's contacts are refreshed. Every connection the callee currently has
/// is then sent the incoming call concurrently; the first successful response
/// acknowledges this request. If none succeeds, the call is marked as failed
/// in the database, `room_updated` and `update_user_contacts` run again, and
/// a "failed to ring user" error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the transaction guard returned by `call` lives only until the
    // end of this block, i.e. it is released before `update_user_contacts`
    // runs. The incoming-call message is moved out of the guard with
    // `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently. The connection-pool
    // guard is a temporary of this statement, so it is released before any of
    // the requests are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful answer acknowledges the request; failures are
    // logged and the remaining connections are still awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: record the failure (guard dropped
    // at the end of this block) and refresh the callee's contacts before
    // reporting the error.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1675
1676/// Cancel an outgoing call.
1677async fn cancel_call(
1678    request: proto::CancelCall,
1679    response: Response<proto::CancelCall>,
1680    session: MessageContext,
1681) -> Result<()> {
1682    let called_user_id = UserId::from_proto(request.called_user_id);
1683    let room_id = RoomId::from_proto(request.room_id);
1684    {
1685        let room = session
1686            .db()
1687            .await
1688            .cancel_call(room_id, session.connection_id, called_user_id)
1689            .await?;
1690        room_updated(&room, &session.peer);
1691    }
1692
1693    for connection_id in session
1694        .connection_pool()
1695        .await
1696        .user_connection_ids(called_user_id)
1697    {
1698        session
1699            .peer
1700            .send(
1701                connection_id,
1702                proto::CallCanceled {
1703                    room_id: room_id.to_proto(),
1704                },
1705            )
1706            .trace_err();
1707    }
1708    response.send(proto::Ack {})?;
1709
1710    update_user_contacts(called_user_id, &session).await?;
1711    Ok(())
1712}
1713
1714/// Decline an incoming call.
1715async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1716    let room_id = RoomId::from_proto(message.room_id);
1717    {
1718        let room = session
1719            .db()
1720            .await
1721            .decline_call(Some(room_id), session.user_id())
1722            .await?
1723            .context("declining call")?;
1724        room_updated(&room, &session.peer);
1725    }
1726
1727    for connection_id in session
1728        .connection_pool()
1729        .await
1730        .user_connection_ids(session.user_id())
1731    {
1732        session
1733            .peer
1734            .send(
1735                connection_id,
1736                proto::CallCanceled {
1737                    room_id: room_id.to_proto(),
1738                },
1739            )
1740            .trace_err();
1741    }
1742    update_user_contacts(session.user_id(), &session).await?;
1743    Ok(())
1744}
1745
1746/// Updates other participants in the room with your current location.
1747async fn update_participant_location(
1748    request: proto::UpdateParticipantLocation,
1749    response: Response<proto::UpdateParticipantLocation>,
1750    session: MessageContext,
1751) -> Result<()> {
1752    let room_id = RoomId::from_proto(request.room_id);
1753    let location = request.location.context("invalid location")?;
1754
1755    let db = session.db().await;
1756    let room = db
1757        .update_room_participant_location(room_id, session.connection_id, location)
1758        .await?;
1759
1760    room_updated(&room, &session.peer);
1761    response.send(proto::Ack {})?;
1762    Ok(())
1763}
1764
1765/// Share a project into the room.
1766async fn share_project(
1767    request: proto::ShareProject,
1768    response: Response<proto::ShareProject>,
1769    session: MessageContext,
1770) -> Result<()> {
1771    let (project_id, room) = &*session
1772        .db()
1773        .await
1774        .share_project(
1775            RoomId::from_proto(request.room_id),
1776            session.connection_id,
1777            &request.worktrees,
1778            request.is_ssh_project,
1779            request.windows_paths.unwrap_or(false),
1780            &request.features,
1781        )
1782        .await?;
1783    response.send(proto::ShareProjectResponse {
1784        project_id: project_id.to_proto(),
1785    })?;
1786    room_updated(room, &session.peer);
1787
1788    Ok(())
1789}
1790
1791/// Unshare a project from the room.
1792async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1793    let project_id = ProjectId::from_proto(message.project_id);
1794    unshare_project_internal(project_id, session.connection_id, &session).await
1795}
1796
1797async fn unshare_project_internal(
1798    project_id: ProjectId,
1799    connection_id: ConnectionId,
1800    session: &Session,
1801) -> Result<()> {
1802    let delete = {
1803        let room_guard = session
1804            .db()
1805            .await
1806            .unshare_project(project_id, connection_id)
1807            .await?;
1808
1809        let (delete, room, guest_connection_ids) = &*room_guard;
1810
1811        let message = proto::UnshareProject {
1812            project_id: project_id.to_proto(),
1813        };
1814
1815        broadcast(
1816            Some(connection_id),
1817            guest_connection_ids.iter().copied(),
1818            |conn_id| session.peer.send(conn_id, message.clone()),
1819        );
1820        if let Some(room) = room {
1821            room_updated(room, &session.peer);
1822        }
1823
1824        *delete
1825    };
1826
1827    if delete {
1828        let db = session.db().await;
1829        db.delete_project(project_id).await?;
1830    }
1831
1832    Ok(())
1833}
1834
1835/// Join someone elses shared project.
1836async fn join_project(
1837    request: proto::JoinProject,
1838    response: Response<proto::JoinProject>,
1839    session: MessageContext,
1840) -> Result<()> {
1841    let project_id = ProjectId::from_proto(request.project_id);
1842
1843    tracing::info!(%project_id, "join project");
1844
1845    let db = session.db().await;
1846    let project_model = db.get_project(project_id).await?;
1847    let host_features: Vec<String> =
1848        serde_json::from_str(&project_model.features).unwrap_or_default();
1849    let guest_features: HashSet<_> = request.features.iter().collect();
1850    let host_features_set: HashSet<_> = host_features.iter().collect();
1851    if guest_features != host_features_set {
1852        let host_connection_id = project_model.host_connection()?;
1853        let mut pool = session.connection_pool().await;
1854        let host_version = pool
1855            .connection(host_connection_id)
1856            .map(|c| c.zed_version.to_string());
1857        let guest_version = pool
1858            .connection(session.connection_id)
1859            .map(|c| c.zed_version.to_string());
1860        drop(pool);
1861        Err(anyhow!(
1862            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1863            host_version.as_deref().unwrap_or("unknown"),
1864            guest_version.as_deref().unwrap_or("unknown"),
1865        ))?;
1866    }
1867
1868    let (project, replica_id) = &mut *db
1869        .join_project(
1870            project_id,
1871            session.connection_id,
1872            session.user_id(),
1873            request.committer_name.clone(),
1874            request.committer_email.clone(),
1875        )
1876        .await?;
1877    drop(db);
1878
1879    tracing::info!(%project_id, "join remote project");
1880    let collaborators = project
1881        .collaborators
1882        .iter()
1883        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1884        .map(|collaborator| collaborator.to_proto())
1885        .collect::<Vec<_>>();
1886    let project_id = project.id;
1887    let guest_user_id = session.user_id();
1888
1889    let worktrees = project
1890        .worktrees
1891        .iter()
1892        .map(|(id, worktree)| proto::WorktreeMetadata {
1893            id: *id,
1894            root_name: worktree.root_name.clone(),
1895            visible: worktree.visible,
1896            abs_path: worktree.abs_path.clone(),
1897        })
1898        .collect::<Vec<_>>();
1899
1900    let add_project_collaborator = proto::AddProjectCollaborator {
1901        project_id: project_id.to_proto(),
1902        collaborator: Some(proto::Collaborator {
1903            peer_id: Some(session.connection_id.into()),
1904            replica_id: replica_id.0 as u32,
1905            user_id: guest_user_id.to_proto(),
1906            is_host: false,
1907            committer_name: request.committer_name.clone(),
1908            committer_email: request.committer_email.clone(),
1909        }),
1910    };
1911
1912    for collaborator in &collaborators {
1913        session
1914            .peer
1915            .send(
1916                collaborator.peer_id.unwrap().into(),
1917                add_project_collaborator.clone(),
1918            )
1919            .trace_err();
1920    }
1921
1922    // First, we send the metadata associated with each worktree.
1923    let (language_servers, language_server_capabilities) = project
1924        .language_servers
1925        .clone()
1926        .into_iter()
1927        .map(|server| (server.server, server.capabilities))
1928        .unzip();
1929    response.send(proto::JoinProjectResponse {
1930        project_id: project.id.0 as u64,
1931        worktrees,
1932        replica_id: replica_id.0 as u32,
1933        collaborators,
1934        language_servers,
1935        language_server_capabilities,
1936        role: project.role.into(),
1937        windows_paths: project.path_style == PathStyle::Windows,
1938        features: project.features.clone(),
1939    })?;
1940
1941    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1942        // Stream this worktree's entries.
1943        let message = proto::UpdateWorktree {
1944            project_id: project_id.to_proto(),
1945            worktree_id,
1946            abs_path: worktree.abs_path.clone(),
1947            root_name: worktree.root_name,
1948            root_repo_common_dir: worktree.root_repo_common_dir,
1949            updated_entries: worktree.entries,
1950            removed_entries: Default::default(),
1951            scan_id: worktree.scan_id,
1952            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1953            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1954            removed_repositories: Default::default(),
1955        };
1956        for update in proto::split_worktree_update(message) {
1957            session.peer.send(session.connection_id, update.clone())?;
1958        }
1959
1960        // Stream this worktree's diagnostics.
1961        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1962        if let Some(summary) = worktree_diagnostics.next() {
1963            let message = proto::UpdateDiagnosticSummary {
1964                project_id: project.id.to_proto(),
1965                worktree_id: worktree.id,
1966                summary: Some(summary),
1967                more_summaries: worktree_diagnostics.collect(),
1968            };
1969            session.peer.send(session.connection_id, message)?;
1970        }
1971
1972        for settings_file in worktree.settings_files {
1973            session.peer.send(
1974                session.connection_id,
1975                proto::UpdateWorktreeSettings {
1976                    project_id: project_id.to_proto(),
1977                    worktree_id: worktree.id,
1978                    path: settings_file.path,
1979                    content: Some(settings_file.content),
1980                    kind: Some(settings_file.kind.to_proto() as i32),
1981                    outside_worktree: Some(settings_file.outside_worktree),
1982                },
1983            )?;
1984        }
1985    }
1986
1987    for repository in mem::take(&mut project.repositories) {
1988        for update in split_repository_update(repository) {
1989            session.peer.send(session.connection_id, update)?;
1990        }
1991    }
1992
1993    for language_server in &project.language_servers {
1994        session.peer.send(
1995            session.connection_id,
1996            proto::UpdateLanguageServer {
1997                project_id: project_id.to_proto(),
1998                server_name: Some(language_server.server.name.clone()),
1999                language_server_id: language_server.server.id,
2000                variant: Some(
2001                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2002                        proto::LspDiskBasedDiagnosticsUpdated {},
2003                    ),
2004                ),
2005            },
2006        )?;
2007    }
2008
2009    Ok(())
2010}
2011
2012/// Leave someone elses shared project.
2013async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2014    let sender_id = session.connection_id;
2015    let project_id = ProjectId::from_proto(request.project_id);
2016    let db = session.db().await;
2017
2018    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2019    tracing::info!(
2020        %project_id,
2021        "leave project"
2022    );
2023
2024    project_left(project, &session);
2025    if let Some(room) = room {
2026        room_updated(room, &session.peer);
2027    }
2028
2029    Ok(())
2030}
2031
2032/// Updates other participants with changes to the project
2033async fn update_project(
2034    request: proto::UpdateProject,
2035    response: Response<proto::UpdateProject>,
2036    session: MessageContext,
2037) -> Result<()> {
2038    let project_id = ProjectId::from_proto(request.project_id);
2039    let (room, guest_connection_ids) = &*session
2040        .db()
2041        .await
2042        .update_project(project_id, session.connection_id, &request.worktrees)
2043        .await?;
2044    broadcast(
2045        Some(session.connection_id),
2046        guest_connection_ids.iter().copied(),
2047        |connection_id| {
2048            session
2049                .peer
2050                .forward_send(session.connection_id, connection_id, request.clone())
2051        },
2052    );
2053    if let Some(room) = room {
2054        room_updated(room, &session.peer);
2055    }
2056    response.send(proto::Ack {})?;
2057
2058    Ok(())
2059}
2060
2061/// Updates other participants with changes to the worktree
2062async fn update_worktree(
2063    request: proto::UpdateWorktree,
2064    response: Response<proto::UpdateWorktree>,
2065    session: MessageContext,
2066) -> Result<()> {
2067    let guest_connection_ids = session
2068        .db()
2069        .await
2070        .update_worktree(&request, session.connection_id)
2071        .await?;
2072
2073    broadcast(
2074        Some(session.connection_id),
2075        guest_connection_ids.iter().copied(),
2076        |connection_id| {
2077            session
2078                .peer
2079                .forward_send(session.connection_id, connection_id, request.clone())
2080        },
2081    );
2082    response.send(proto::Ack {})?;
2083    Ok(())
2084}
2085
2086async fn update_repository(
2087    request: proto::UpdateRepository,
2088    response: Response<proto::UpdateRepository>,
2089    session: MessageContext,
2090) -> Result<()> {
2091    let guest_connection_ids = session
2092        .db()
2093        .await
2094        .update_repository(&request, session.connection_id)
2095        .await?;
2096
2097    broadcast(
2098        Some(session.connection_id),
2099        guest_connection_ids.iter().copied(),
2100        |connection_id| {
2101            session
2102                .peer
2103                .forward_send(session.connection_id, connection_id, request.clone())
2104        },
2105    );
2106    response.send(proto::Ack {})?;
2107    Ok(())
2108}
2109
2110async fn remove_repository(
2111    request: proto::RemoveRepository,
2112    response: Response<proto::RemoveRepository>,
2113    session: MessageContext,
2114) -> Result<()> {
2115    let guest_connection_ids = session
2116        .db()
2117        .await
2118        .remove_repository(&request, session.connection_id)
2119        .await?;
2120
2121    broadcast(
2122        Some(session.connection_id),
2123        guest_connection_ids.iter().copied(),
2124        |connection_id| {
2125            session
2126                .peer
2127                .forward_send(session.connection_id, connection_id, request.clone())
2128        },
2129    );
2130    response.send(proto::Ack {})?;
2131    Ok(())
2132}
2133
2134/// Updates other participants with changes to the diagnostics
2135async fn update_diagnostic_summary(
2136    message: proto::UpdateDiagnosticSummary,
2137    session: MessageContext,
2138) -> Result<()> {
2139    let guest_connection_ids = session
2140        .db()
2141        .await
2142        .update_diagnostic_summary(&message, session.connection_id)
2143        .await?;
2144
2145    broadcast(
2146        Some(session.connection_id),
2147        guest_connection_ids.iter().copied(),
2148        |connection_id| {
2149            session
2150                .peer
2151                .forward_send(session.connection_id, connection_id, message.clone())
2152        },
2153    );
2154
2155    Ok(())
2156}
2157
2158/// Updates other participants with changes to the worktree settings
2159async fn update_worktree_settings(
2160    message: proto::UpdateWorktreeSettings,
2161    session: MessageContext,
2162) -> Result<()> {
2163    let guest_connection_ids = session
2164        .db()
2165        .await
2166        .update_worktree_settings(&message, session.connection_id)
2167        .await?;
2168
2169    broadcast(
2170        Some(session.connection_id),
2171        guest_connection_ids.iter().copied(),
2172        |connection_id| {
2173            session
2174                .peer
2175                .forward_send(session.connection_id, connection_id, message.clone())
2176        },
2177    );
2178
2179    Ok(())
2180}
2181
2182/// Notify other participants that a language server has started.
2183async fn start_language_server(
2184    request: proto::StartLanguageServer,
2185    session: MessageContext,
2186) -> Result<()> {
2187    let guest_connection_ids = session
2188        .db()
2189        .await
2190        .start_language_server(&request, session.connection_id)
2191        .await?;
2192
2193    broadcast(
2194        Some(session.connection_id),
2195        guest_connection_ids.iter().copied(),
2196        |connection_id| {
2197            session
2198                .peer
2199                .forward_send(session.connection_id, connection_id, request.clone())
2200        },
2201    );
2202    Ok(())
2203}
2204
2205/// Notify other participants that a language server has changed.
2206async fn update_language_server(
2207    request: proto::UpdateLanguageServer,
2208    session: MessageContext,
2209) -> Result<()> {
2210    let project_id = ProjectId::from_proto(request.project_id);
2211    let db = session.db().await;
2212
2213    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2214        && let Some(capabilities) = update.capabilities.clone()
2215    {
2216        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2217            .await?;
2218    }
2219
2220    let project_connection_ids = db
2221        .project_connection_ids(project_id, session.connection_id, true)
2222        .await?;
2223    broadcast(
2224        Some(session.connection_id),
2225        project_connection_ids.iter().copied(),
2226        |connection_id| {
2227            session
2228                .peer
2229                .forward_send(session.connection_id, connection_id, request.clone())
2230        },
2231    );
2232    Ok(())
2233}
2234
2235/// forward a project request to the host. These requests should be read only
2236/// as guests are allowed to send them.
2237async fn forward_read_only_project_request<T>(
2238    request: T,
2239    response: Response<T>,
2240    session: MessageContext,
2241) -> Result<()>
2242where
2243    T: EntityMessage + RequestMessage,
2244{
2245    let project_id = ProjectId::from_proto(request.remote_entity_id());
2246    let host_connection_id = session
2247        .db()
2248        .await
2249        .host_for_read_only_project_request(project_id, session.connection_id)
2250        .await?;
2251    let payload = session.forward_request(host_connection_id, request).await?;
2252    response.send(payload)?;
2253    Ok(())
2254}
2255
2256/// forward a project request to the host. These requests are disallowed
2257/// for guests.
2258async fn forward_mutating_project_request<T>(
2259    request: T,
2260    response: Response<T>,
2261    session: MessageContext,
2262) -> Result<()>
2263where
2264    T: EntityMessage + RequestMessage,
2265{
2266    let project_id = ProjectId::from_proto(request.remote_entity_id());
2267
2268    let host_connection_id = session
2269        .db()
2270        .await
2271        .host_for_mutating_project_request(project_id, session.connection_id)
2272        .await?;
2273    let payload = session.forward_request(host_connection_id, request).await?;
2274    response.send(payload)?;
2275    Ok(())
2276}
2277
2278async fn disallow_guest_request<T>(
2279    _request: T,
2280    response: Response<T>,
2281    _session: MessageContext,
2282) -> Result<()>
2283where
2284    T: RequestMessage,
2285{
2286    response.peer.respond_with_error(
2287        response.receipt,
2288        ErrorCode::Forbidden
2289            .message("request is not allowed for guests".to_string())
2290            .to_proto(),
2291    )?;
2292    response.responded.store(true, SeqCst);
2293    Ok(())
2294}
2295
2296async fn lsp_query(
2297    request: proto::LspQuery,
2298    response: Response<proto::LspQuery>,
2299    session: MessageContext,
2300) -> Result<()> {
2301    let (name, should_write) = request.query_name_and_write_permissions();
2302    tracing::Span::current().record("lsp_query_request", name);
2303    tracing::info!("lsp_query message received");
2304    if should_write {
2305        forward_mutating_project_request(request, response, session).await
2306    } else {
2307        forward_read_only_project_request(request, response, session).await
2308    }
2309}
2310
2311/// Notify other participants that a new buffer has been created
2312async fn create_buffer_for_peer(
2313    request: proto::CreateBufferForPeer,
2314    session: MessageContext,
2315) -> Result<()> {
2316    session
2317        .db()
2318        .await
2319        .check_user_is_project_host(
2320            ProjectId::from_proto(request.project_id),
2321            session.connection_id,
2322        )
2323        .await?;
2324    let peer_id = request.peer_id.context("invalid peer id")?;
2325    session
2326        .peer
2327        .forward_send(session.connection_id, peer_id.into(), request)?;
2328    Ok(())
2329}
2330
2331/// Notify other participants that a new image has been created
2332async fn create_image_for_peer(
2333    request: proto::CreateImageForPeer,
2334    session: MessageContext,
2335) -> Result<()> {
2336    session
2337        .db()
2338        .await
2339        .check_user_is_project_host(
2340            ProjectId::from_proto(request.project_id),
2341            session.connection_id,
2342        )
2343        .await?;
2344    let peer_id = request.peer_id.context("invalid peer id")?;
2345    session
2346        .peer
2347        .forward_send(session.connection_id, peer_id.into(), request)?;
2348    Ok(())
2349}
2350
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The capability passed to the database is derived from the operations:
/// if every operation is empty or a selection update, `ReadOnly` is
/// requested; any other operation escalates it to `ReadWrite` (presumably
/// `connections_for_buffer_update` rejects senders lacking that capability —
/// confirm in the db layer). The update is broadcast to the returned guest
/// connections and, unless the sender is the host, forwarded to the host as
/// a request whose reply is awaited before the sender is acknowledged.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start from the weakest requirement; escalate if any operation edits content.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scoped so that `guard` is dropped before the host request is awaited below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fire-and-forget delivery to guests; the sender itself is excluded.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host gets the update as a request, so any error it returns is propagated.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2397
2398async fn forward_project_search_chunk(
2399    message: proto::FindSearchCandidatesChunk,
2400    response: Response<proto::FindSearchCandidatesChunk>,
2401    session: MessageContext,
2402) -> Result<()> {
2403    let peer_id = message.peer_id.context("missing peer_id")?;
2404    let payload = session
2405        .peer
2406        .forward_request(session.connection_id, peer_id.into(), message)
2407        .await?;
2408    response.send(payload)?;
2409    Ok(())
2410}
2411
2412/// Notify other participants that a project has been updated.
2413async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2414    request: T,
2415    session: MessageContext,
2416) -> Result<()> {
2417    let project_id = ProjectId::from_proto(request.remote_entity_id());
2418    let project_connection_ids = session
2419        .db()
2420        .await
2421        .project_connection_ids(project_id, session.connection_id, false)
2422        .await?;
2423
2424    broadcast(
2425        Some(session.connection_id),
2426        project_connection_ids.iter().copied(),
2427        |connection_id| {
2428            session
2429                .peer
2430                .forward_send(session.connection_id, connection_id, request.clone())
2431        },
2432    );
2433    Ok(())
2434}
2435
2436/// Start following another user in a call.
2437async fn follow(
2438    request: proto::Follow,
2439    response: Response<proto::Follow>,
2440    session: MessageContext,
2441) -> Result<()> {
2442    let room_id = RoomId::from_proto(request.room_id);
2443    let project_id = request.project_id.map(ProjectId::from_proto);
2444    let leader_id = request.leader_id.context("invalid leader id")?.into();
2445    let follower_id = session.connection_id;
2446
2447    session
2448        .db()
2449        .await
2450        .check_room_participants(room_id, leader_id, session.connection_id)
2451        .await?;
2452
2453    let response_payload = session.forward_request(leader_id, request).await?;
2454    response.send(response_payload)?;
2455
2456    if let Some(project_id) = project_id {
2457        let room = session
2458            .db()
2459            .await
2460            .follow(room_id, project_id, leader_id, follower_id)
2461            .await?;
2462        room_updated(&room, &session.peer);
2463    }
2464
2465    Ok(())
2466}
2467
2468/// Stop following another user in a call.
2469async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2470    let room_id = RoomId::from_proto(request.room_id);
2471    let project_id = request.project_id.map(ProjectId::from_proto);
2472    let leader_id = request.leader_id.context("invalid leader id")?.into();
2473    let follower_id = session.connection_id;
2474
2475    session
2476        .db()
2477        .await
2478        .check_room_participants(room_id, leader_id, session.connection_id)
2479        .await?;
2480
2481    session
2482        .peer
2483        .forward_send(session.connection_id, leader_id, request)?;
2484
2485    if let Some(project_id) = project_id {
2486        let room = session
2487            .db()
2488            .await
2489            .unfollow(room_id, project_id, leader_id, follower_id)
2490            .await?;
2491        room_updated(&room, &session.peer);
2492    }
2493
2494    Ok(())
2495}
2496
2497/// Notify everyone following you of your current location.
2498async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2499    let room_id = RoomId::from_proto(request.room_id);
2500    let database = session.db.lock().await;
2501
2502    let connection_ids = if let Some(project_id) = request.project_id {
2503        let project_id = ProjectId::from_proto(project_id);
2504        database
2505            .project_connection_ids(project_id, session.connection_id, true)
2506            .await?
2507    } else {
2508        database
2509            .room_connection_ids(room_id, session.connection_id)
2510            .await?
2511    };
2512
2513    // For now, don't send view update messages back to that view's current leader.
2514    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2515        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2516        _ => None,
2517    });
2518
2519    for connection_id in connection_ids.iter().cloned() {
2520        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2521            session
2522                .peer
2523                .forward_send(session.connection_id, connection_id, request.clone())?;
2524        }
2525    }
2526    Ok(())
2527}
2528
2529/// Get public data about users.
2530async fn get_users(
2531    request: proto::GetUsers,
2532    response: Response<proto::GetUsers>,
2533    session: MessageContext,
2534) -> Result<()> {
2535    let user_ids = request
2536        .user_ids
2537        .into_iter()
2538        .map(UserId::from_proto)
2539        .collect();
2540    let users = session
2541        .db()
2542        .await
2543        .get_users_by_ids(user_ids)
2544        .await?
2545        .into_iter()
2546        .map(|user| proto::User {
2547            id: user.id.to_proto(),
2548            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2549            github_login: user.github_login,
2550            name: user.name,
2551        })
2552        .collect();
2553    response.send(proto::UsersResponse { users })?;
2554    Ok(())
2555}
2556
2557/// Search for users (to invite) buy Github login
2558async fn fuzzy_search_users(
2559    request: proto::FuzzySearchUsers,
2560    response: Response<proto::FuzzySearchUsers>,
2561    session: MessageContext,
2562) -> Result<()> {
2563    let query = request.query;
2564    let users = match query.len() {
2565        0 => vec![],
2566        1 | 2 => session
2567            .db()
2568            .await
2569            .get_user_by_github_login(&query)
2570            .await?
2571            .into_iter()
2572            .collect(),
2573        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2574    };
2575    let users = users
2576        .into_iter()
2577        .filter(|user| user.id != session.user_id())
2578        .map(|user| proto::User {
2579            id: user.id.to_proto(),
2580            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2581            github_login: user.github_login,
2582            name: user.name,
2583        })
2584        .collect();
2585    response.send(proto::UsersResponse { users })?;
2586    Ok(())
2587}
2588
2589/// Send a contact request to another user.
2590async fn request_contact(
2591    request: proto::RequestContact,
2592    response: Response<proto::RequestContact>,
2593    session: MessageContext,
2594) -> Result<()> {
2595    let requester_id = session.user_id();
2596    let responder_id = UserId::from_proto(request.responder_id);
2597    if requester_id == responder_id {
2598        return Err(anyhow!("cannot add yourself as a contact"))?;
2599    }
2600
2601    let notifications = session
2602        .db()
2603        .await
2604        .send_contact_request(requester_id, responder_id)
2605        .await?;
2606
2607    // Update outgoing contact requests of requester
2608    let mut update = proto::UpdateContacts::default();
2609    update.outgoing_requests.push(responder_id.to_proto());
2610    for connection_id in session
2611        .connection_pool()
2612        .await
2613        .user_connection_ids(requester_id)
2614    {
2615        session.peer.send(connection_id, update.clone())?;
2616    }
2617
2618    // Update incoming contact requests of responder
2619    let mut update = proto::UpdateContacts::default();
2620    update
2621        .incoming_requests
2622        .push(proto::IncomingContactRequest {
2623            requester_id: requester_id.to_proto(),
2624        });
2625    let connection_pool = session.connection_pool().await;
2626    for connection_id in connection_pool.user_connection_ids(responder_id) {
2627        session.peer.send(connection_id, update.clone())?;
2628    }
2629
2630    send_notifications(&connection_pool, &session.peer, notifications);
2631
2632    response.send(proto::Ack {})?;
2633    Ok(())
2634}
2635
/// Accept or decline a contact request
///
/// A `Dismiss` response only dismisses the contact notification between the
/// two users. Any other response is an acceptance if it equals `Accept` and
/// a decline otherwise. In that case the database is updated and both users'
/// connections are told to drop the pending request. On acceptance, each
/// side is also sent the other user as a new contact. Finally, the
/// notifications produced by the database are delivered. The caller is
/// always acknowledged at the end.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // This database handle stays alive until the end of the handler.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Anything other than an explicit `Accept` counts as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags feed into the `contact_for_user` entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2693
2694/// Remove a contact.
2695async fn remove_contact(
2696    request: proto::RemoveContact,
2697    response: Response<proto::RemoveContact>,
2698    session: MessageContext,
2699) -> Result<()> {
2700    let requester_id = session.user_id();
2701    let responder_id = UserId::from_proto(request.user_id);
2702    let db = session.db().await;
2703    let (contact_accepted, deleted_notification_id) =
2704        db.remove_contact(requester_id, responder_id).await?;
2705
2706    let pool = session.connection_pool().await;
2707    // Update outgoing contact requests of requester
2708    let mut update = proto::UpdateContacts::default();
2709    if contact_accepted {
2710        update.remove_contacts.push(responder_id.to_proto());
2711    } else {
2712        update
2713            .remove_outgoing_requests
2714            .push(responder_id.to_proto());
2715    }
2716    for connection_id in pool.user_connection_ids(requester_id) {
2717        session.peer.send(connection_id, update.clone())?;
2718    }
2719
2720    // Update incoming contact requests of responder
2721    let mut update = proto::UpdateContacts::default();
2722    if contact_accepted {
2723        update.remove_contacts.push(requester_id.to_proto());
2724    } else {
2725        update
2726            .remove_incoming_requests
2727            .push(requester_id.to_proto());
2728    }
2729    for connection_id in pool.user_connection_ids(responder_id) {
2730        session.peer.send(connection_id, update.clone())?;
2731        if let Some(notification_id) = deleted_notification_id {
2732            session.peer.send(
2733                connection_id,
2734                proto::DeleteNotification {
2735                    notification_id: notification_id.to_proto(),
2736                },
2737            )?;
2738        }
2739    }
2740
2741    response.send(proto::Ack {})?;
2742    Ok(())
2743}
2744
2745fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2746    version.0.minor < 139
2747}
2748
2749async fn subscribe_to_channels(
2750    _: proto::SubscribeToChannels,
2751    session: MessageContext,
2752) -> Result<()> {
2753    subscribe_user_to_channels(session.user_id(), &session).await?;
2754    Ok(())
2755}
2756
2757async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2758    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2759    let mut pool = session.connection_pool().await;
2760    for membership in &channels_for_user.channel_memberships {
2761        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2762    }
2763    session.peer.send(
2764        session.connection_id,
2765        build_update_user_channels(&channels_for_user),
2766    )?;
2767    session.peer.send(
2768        session.connection_id,
2769        build_channels_update(channels_for_user),
2770    )?;
2771    Ok(())
2772}
2773
2774/// Creates a new channel.
2775async fn create_channel(
2776    request: proto::CreateChannel,
2777    response: Response<proto::CreateChannel>,
2778    session: MessageContext,
2779) -> Result<()> {
2780    let db = session.db().await;
2781
2782    let parent_id = request.parent_id.map(ChannelId::from_proto);
2783    let (channel, membership) = db
2784        .create_channel(&request.name, parent_id, session.user_id())
2785        .await?;
2786
2787    let root_id = channel.root_id();
2788    let channel = Channel::from_model(channel);
2789
2790    response.send(proto::CreateChannelResponse {
2791        channel: Some(channel.to_proto()),
2792        parent_id: request.parent_id,
2793    })?;
2794
2795    let mut connection_pool = session.connection_pool().await;
2796    if let Some(membership) = membership {
2797        connection_pool.subscribe_to_channel(
2798            membership.user_id,
2799            membership.channel_id,
2800            membership.role,
2801        );
2802        let update = proto::UpdateUserChannels {
2803            channel_memberships: vec![proto::ChannelMembership {
2804                channel_id: membership.channel_id.to_proto(),
2805                role: membership.role.into(),
2806            }],
2807            ..Default::default()
2808        };
2809        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2810            session.peer.send(connection_id, update.clone())?;
2811        }
2812    }
2813
2814    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2815        if !role.can_see_channel(channel.visibility) {
2816            continue;
2817        }
2818
2819        let update = proto::UpdateChannels {
2820            channels: vec![channel.to_proto()],
2821            ..Default::default()
2822        };
2823        session.peer.send(connection_id, update.clone())?;
2824    }
2825
2826    Ok(())
2827}
2828
2829/// Delete a channel
2830async fn delete_channel(
2831    request: proto::DeleteChannel,
2832    response: Response<proto::DeleteChannel>,
2833    session: MessageContext,
2834) -> Result<()> {
2835    let db = session.db().await;
2836
2837    let channel_id = request.channel_id;
2838    let (root_channel, removed_channels) = db
2839        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2840        .await?;
2841    response.send(proto::Ack {})?;
2842
2843    // Notify members of removed channels
2844    let mut update = proto::UpdateChannels::default();
2845    update
2846        .delete_channels
2847        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2848
2849    let connection_pool = session.connection_pool().await;
2850    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2851        session.peer.send(connection_id, update.clone())?;
2852    }
2853
2854    Ok(())
2855}
2856
2857/// Invite someone to join a channel.
2858async fn invite_channel_member(
2859    request: proto::InviteChannelMember,
2860    response: Response<proto::InviteChannelMember>,
2861    session: MessageContext,
2862) -> Result<()> {
2863    let db = session.db().await;
2864    let channel_id = ChannelId::from_proto(request.channel_id);
2865    let invitee_id = UserId::from_proto(request.user_id);
2866    let InviteMemberResult {
2867        channel,
2868        notifications,
2869    } = db
2870        .invite_channel_member(
2871            channel_id,
2872            invitee_id,
2873            session.user_id(),
2874            request.role().into(),
2875        )
2876        .await?;
2877
2878    let update = proto::UpdateChannels {
2879        channel_invitations: vec![channel.to_proto()],
2880        ..Default::default()
2881    };
2882
2883    let connection_pool = session.connection_pool().await;
2884    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2885        session.peer.send(connection_id, update.clone())?;
2886    }
2887
2888    send_notifications(&connection_pool, &session.peer, notifications);
2889
2890    response.send(proto::Ack {})?;
2891    Ok(())
2892}
2893
2894/// remove someone from a channel
2895async fn remove_channel_member(
2896    request: proto::RemoveChannelMember,
2897    response: Response<proto::RemoveChannelMember>,
2898    session: MessageContext,
2899) -> Result<()> {
2900    let db = session.db().await;
2901    let channel_id = ChannelId::from_proto(request.channel_id);
2902    let member_id = UserId::from_proto(request.user_id);
2903
2904    let RemoveChannelMemberResult {
2905        membership_update,
2906        notification_id,
2907    } = db
2908        .remove_channel_member(channel_id, member_id, session.user_id())
2909        .await?;
2910
2911    let mut connection_pool = session.connection_pool().await;
2912    notify_membership_updated(
2913        &mut connection_pool,
2914        membership_update,
2915        member_id,
2916        &session.peer,
2917    );
2918    for connection_id in connection_pool.user_connection_ids(member_id) {
2919        if let Some(notification_id) = notification_id {
2920            session
2921                .peer
2922                .send(
2923                    connection_id,
2924                    proto::DeleteNotification {
2925                        notification_id: notification_id.to_proto(),
2926                    },
2927                )
2928                .trace_err();
2929        }
2930    }
2931
2932    response.send(proto::Ack {})?;
2933    Ok(())
2934}
2935
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy). That invariant is
/// presumably enforced inside `db.set_channel_visibility`; it is not visible here.
///
/// After the database update, every user subscribed to the channel's root is re-evaluated.
/// Users whose role can see the new visibility are subscribed to the channel and sent it.
/// The rest are unsubscribed and told to delete it. The caller is acknowledged last.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collected up front: the loop body mutates `connection_pool`, so it cannot
    // keep borrowing it through the `channel_user_ids` iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2982
2983/// Alter the role for a user in the channel.
2984async fn set_channel_member_role(
2985    request: proto::SetChannelMemberRole,
2986    response: Response<proto::SetChannelMemberRole>,
2987    session: MessageContext,
2988) -> Result<()> {
2989    let db = session.db().await;
2990    let channel_id = ChannelId::from_proto(request.channel_id);
2991    let member_id = UserId::from_proto(request.user_id);
2992    let result = db
2993        .set_channel_member_role(
2994            channel_id,
2995            session.user_id(),
2996            member_id,
2997            request.role().into(),
2998        )
2999        .await?;
3000
3001    match result {
3002        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3003            let mut connection_pool = session.connection_pool().await;
3004            notify_membership_updated(
3005                &mut connection_pool,
3006                membership_update,
3007                member_id,
3008                &session.peer,
3009            )
3010        }
3011        db::SetMemberRoleResult::InviteUpdated(channel) => {
3012            let update = proto::UpdateChannels {
3013                channel_invitations: vec![channel.to_proto()],
3014                ..Default::default()
3015            };
3016
3017            for connection_id in session
3018                .connection_pool()
3019                .await
3020                .user_connection_ids(member_id)
3021            {
3022                session.peer.send(connection_id, update.clone())?;
3023            }
3024        }
3025    }
3026
3027    response.send(proto::Ack {})?;
3028    Ok(())
3029}
3030
3031/// Change the name of a channel
3032async fn rename_channel(
3033    request: proto::RenameChannel,
3034    response: Response<proto::RenameChannel>,
3035    session: MessageContext,
3036) -> Result<()> {
3037    let db = session.db().await;
3038    let channel_id = ChannelId::from_proto(request.channel_id);
3039    let channel_model = db
3040        .rename_channel(channel_id, session.user_id(), &request.name)
3041        .await?;
3042    let root_id = channel_model.root_id();
3043    let channel = Channel::from_model(channel_model);
3044
3045    response.send(proto::RenameChannelResponse {
3046        channel: Some(channel.to_proto()),
3047    })?;
3048
3049    let connection_pool = session.connection_pool().await;
3050    let update = proto::UpdateChannels {
3051        channels: vec![channel.to_proto()],
3052        ..Default::default()
3053    };
3054    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3055        if role.can_see_channel(channel.visibility) {
3056            session.peer.send(connection_id, update.clone())?;
3057        }
3058    }
3059
3060    Ok(())
3061}
3062
3063/// Move a channel to a new parent.
3064async fn move_channel(
3065    request: proto::MoveChannel,
3066    response: Response<proto::MoveChannel>,
3067    session: MessageContext,
3068) -> Result<()> {
3069    let channel_id = ChannelId::from_proto(request.channel_id);
3070    let to = ChannelId::from_proto(request.to);
3071
3072    let (root_id, channels) = session
3073        .db()
3074        .await
3075        .move_channel(channel_id, to, session.user_id())
3076        .await?;
3077
3078    let connection_pool = session.connection_pool().await;
3079    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3080        let channels = channels
3081            .iter()
3082            .filter_map(|channel| {
3083                if role.can_see_channel(channel.visibility) {
3084                    Some(channel.to_proto())
3085                } else {
3086                    None
3087                }
3088            })
3089            .collect::<Vec<_>>();
3090        if channels.is_empty() {
3091            continue;
3092        }
3093
3094        let update = proto::UpdateChannels {
3095            channels,
3096            ..Default::default()
3097        };
3098
3099        session.peer.send(connection_id, update.clone())?;
3100    }
3101
3102    response.send(Ack {})?;
3103    Ok(())
3104}
3105
3106async fn reorder_channel(
3107    request: proto::ReorderChannel,
3108    response: Response<proto::ReorderChannel>,
3109    session: MessageContext,
3110) -> Result<()> {
3111    let channel_id = ChannelId::from_proto(request.channel_id);
3112    let direction = request.direction();
3113
3114    let updated_channels = session
3115        .db()
3116        .await
3117        .reorder_channel(channel_id, direction, session.user_id())
3118        .await?;
3119
3120    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3121        let connection_pool = session.connection_pool().await;
3122        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3123            let channels = updated_channels
3124                .iter()
3125                .filter_map(|channel| {
3126                    if role.can_see_channel(channel.visibility) {
3127                        Some(channel.to_proto())
3128                    } else {
3129                        None
3130                    }
3131                })
3132                .collect::<Vec<_>>();
3133
3134            if channels.is_empty() {
3135                continue;
3136            }
3137
3138            let update = proto::UpdateChannels {
3139                channels,
3140                ..Default::default()
3141            };
3142
3143            session.peer.send(connection_id, update.clone())?;
3144        }
3145    }
3146
3147    response.send(Ack {})?;
3148    Ok(())
3149}
3150
3151/// Get the list of channel members
3152async fn get_channel_members(
3153    request: proto::GetChannelMembers,
3154    response: Response<proto::GetChannelMembers>,
3155    session: MessageContext,
3156) -> Result<()> {
3157    let db = session.db().await;
3158    let channel_id = ChannelId::from_proto(request.channel_id);
3159    let limit = if request.limit == 0 {
3160        u16::MAX as u64
3161    } else {
3162        request.limit
3163    };
3164    let (members, users) = db
3165        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3166        .await?;
3167    response.send(proto::GetChannelMembersResponse { members, users })?;
3168    Ok(())
3169}
3170
3171/// Accept or decline a channel invitation.
3172async fn respond_to_channel_invite(
3173    request: proto::RespondToChannelInvite,
3174    response: Response<proto::RespondToChannelInvite>,
3175    session: MessageContext,
3176) -> Result<()> {
3177    let db = session.db().await;
3178    let channel_id = ChannelId::from_proto(request.channel_id);
3179    let RespondToChannelInvite {
3180        membership_update,
3181        notifications,
3182    } = db
3183        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3184        .await?;
3185
3186    let mut connection_pool = session.connection_pool().await;
3187    if let Some(membership_update) = membership_update {
3188        notify_membership_updated(
3189            &mut connection_pool,
3190            membership_update,
3191            session.user_id(),
3192            &session.peer,
3193        );
3194    } else {
3195        let update = proto::UpdateChannels {
3196            remove_channel_invitations: vec![channel_id.to_proto()],
3197            ..Default::default()
3198        };
3199
3200        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3201            session.peer.send(connection_id, update.clone())?;
3202        }
3203    };
3204
3205    send_notifications(&connection_pool, &session.peer, notifications);
3206
3207    response.send(proto::Ack {})?;
3208
3209    Ok(())
3210}
3211
3212/// Join the channels' room
3213async fn join_channel(
3214    request: proto::JoinChannel,
3215    response: Response<proto::JoinChannel>,
3216    session: MessageContext,
3217) -> Result<()> {
3218    let channel_id = ChannelId::from_proto(request.channel_id);
3219    join_channel_internal(channel_id, Box::new(response), session).await
3220}
3221
/// Abstracts over the response handles that `join_channel_internal` can
/// reply through, so the same logic can answer both `JoinChannel` and
/// `JoinRoom` with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates via the fully-qualified path, presumably resolving to the inherent
// `Response::send` (inherent methods take priority over this trait's `send`).
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same delegation for `JoinRoom` requests.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3235
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. Leave any stale room connection the user still has.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel id, and (when a LiveKit client is
///    configured and token generation succeeds) LiveKit connection info.
///    Guests get a guest token with `can_publish: false`; other roles get a
///    room token with `can_publish: true`.
/// 4. Apply any membership change and broadcast the updated room.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// Errors with "channel not returned" if the joined room carries no channel;
/// by then the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Block scope: the db handle and connection-pool guard acquired inside are
    // dropped before the pool is acquired again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle while leaving the old room, presumably because
            // `leave_room_for_session` acquires it itself; then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token generation
        // fails (the error goes through `trace_err` and `?` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3329
3330/// Start editing the channel notes
3331async fn join_channel_buffer(
3332    request: proto::JoinChannelBuffer,
3333    response: Response<proto::JoinChannelBuffer>,
3334    session: MessageContext,
3335) -> Result<()> {
3336    let db = session.db().await;
3337    let channel_id = ChannelId::from_proto(request.channel_id);
3338
3339    let open_response = db
3340        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3341        .await?;
3342
3343    let collaborators = open_response.collaborators.clone();
3344    response.send(open_response)?;
3345
3346    let update = UpdateChannelBufferCollaborators {
3347        channel_id: channel_id.to_proto(),
3348        collaborators: collaborators.clone(),
3349    };
3350    channel_buffer_updated(
3351        session.connection_id,
3352        collaborators
3353            .iter()
3354            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3355        &update,
3356        &session.peer,
3357    );
3358
3359    Ok(())
3360}
3361
3362/// Edit the channel notes
3363async fn update_channel_buffer(
3364    request: proto::UpdateChannelBuffer,
3365    session: MessageContext,
3366) -> Result<()> {
3367    let db = session.db().await;
3368    let channel_id = ChannelId::from_proto(request.channel_id);
3369
3370    let (collaborators, epoch, version) = db
3371        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3372        .await?;
3373
3374    channel_buffer_updated(
3375        session.connection_id,
3376        collaborators.clone(),
3377        &proto::UpdateChannelBuffer {
3378            channel_id: channel_id.to_proto(),
3379            operations: request.operations,
3380        },
3381        &session.peer,
3382    );
3383
3384    let pool = &*session.connection_pool().await;
3385
3386    let non_collaborators =
3387        pool.channel_connection_ids(channel_id)
3388            .filter_map(|(connection_id, _)| {
3389                if collaborators.contains(&connection_id) {
3390                    None
3391                } else {
3392                    Some(connection_id)
3393                }
3394            });
3395
3396    broadcast(None, non_collaborators, |peer_id| {
3397        session.peer.send(
3398            peer_id,
3399            proto::UpdateChannels {
3400                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3401                    channel_id: channel_id.to_proto(),
3402                    epoch: epoch as u64,
3403                    version: version.clone(),
3404                }],
3405                ..Default::default()
3406            },
3407        )
3408    });
3409
3410    Ok(())
3411}
3412
3413/// Rejoin the channel notes after a connection blip
3414async fn rejoin_channel_buffers(
3415    request: proto::RejoinChannelBuffers,
3416    response: Response<proto::RejoinChannelBuffers>,
3417    session: MessageContext,
3418) -> Result<()> {
3419    let db = session.db().await;
3420    let buffers = db
3421        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3422        .await?;
3423
3424    for rejoined_buffer in &buffers {
3425        let collaborators_to_notify = rejoined_buffer
3426            .buffer
3427            .collaborators
3428            .iter()
3429            .filter_map(|c| Some(c.peer_id?.into()));
3430        channel_buffer_updated(
3431            session.connection_id,
3432            collaborators_to_notify,
3433            &proto::UpdateChannelBufferCollaborators {
3434                channel_id: rejoined_buffer.buffer.channel_id,
3435                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3436            },
3437            &session.peer,
3438        );
3439    }
3440
3441    response.send(proto::RejoinChannelBuffersResponse {
3442        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3443    })?;
3444
3445    Ok(())
3446}
3447
3448/// Stop editing the channel notes
3449async fn leave_channel_buffer(
3450    request: proto::LeaveChannelBuffer,
3451    response: Response<proto::LeaveChannelBuffer>,
3452    session: MessageContext,
3453) -> Result<()> {
3454    let db = session.db().await;
3455    let channel_id = ChannelId::from_proto(request.channel_id);
3456
3457    let left_buffer = db
3458        .leave_channel_buffer(channel_id, session.connection_id)
3459        .await?;
3460
3461    response.send(Ack {})?;
3462
3463    channel_buffer_updated(
3464        session.connection_id,
3465        left_buffer.connections,
3466        &proto::UpdateChannelBufferCollaborators {
3467            channel_id: channel_id.to_proto(),
3468            collaborators: left_buffer.collaborators,
3469        },
3470        &session.peer,
3471    );
3472
3473    Ok(())
3474}
3475
3476fn channel_buffer_updated<T: EnvelopedMessage>(
3477    sender_id: ConnectionId,
3478    collaborators: impl IntoIterator<Item = ConnectionId>,
3479    message: &T,
3480    peer: &Peer,
3481) {
3482    broadcast(Some(sender_id), collaborators, |peer_id| {
3483        peer.send(peer_id, message.clone())
3484    });
3485}
3486
3487fn send_notifications(
3488    connection_pool: &ConnectionPool,
3489    peer: &Peer,
3490    notifications: db::NotificationBatch,
3491) {
3492    for (user_id, notification) in notifications {
3493        for connection_id in connection_pool.user_connection_ids(user_id) {
3494            if let Err(error) = peer.send(
3495                connection_id,
3496                proto::AddNotification {
3497                    notification: Some(notification.clone()),
3498                },
3499            ) {
3500                tracing::error!(
3501                    "failed to send notification to {:?} {}",
3502                    connection_id,
3503                    error
3504                );
3505            }
3506        }
3507    }
3508}
3509
3510/// Send a message to the channel
3511async fn send_channel_message(
3512    _request: proto::SendChannelMessage,
3513    _response: Response<proto::SendChannelMessage>,
3514    _session: MessageContext,
3515) -> Result<()> {
3516    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3517}
3518
3519/// Delete a channel message
3520async fn remove_channel_message(
3521    _request: proto::RemoveChannelMessage,
3522    _response: Response<proto::RemoveChannelMessage>,
3523    _session: MessageContext,
3524) -> Result<()> {
3525    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3526}
3527
3528async fn update_channel_message(
3529    _request: proto::UpdateChannelMessage,
3530    _response: Response<proto::UpdateChannelMessage>,
3531    _session: MessageContext,
3532) -> Result<()> {
3533    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3534}
3535
3536/// Mark a channel message as read
3537async fn acknowledge_channel_message(
3538    _request: proto::AckChannelMessage,
3539    _session: MessageContext,
3540) -> Result<()> {
3541    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3542}
3543
3544/// Mark a buffer version as synced
3545async fn acknowledge_buffer_version(
3546    request: proto::AckBufferOperation,
3547    session: MessageContext,
3548) -> Result<()> {
3549    let buffer_id = BufferId::from_proto(request.buffer_id);
3550    session
3551        .db()
3552        .await
3553        .observe_buffer_version(
3554            buffer_id,
3555            session.user_id(),
3556            request.epoch as i32,
3557            &request.version,
3558        )
3559        .await?;
3560    Ok(())
3561}
3562
3563/// Start receiving chat updates for a channel
3564async fn join_channel_chat(
3565    _request: proto::JoinChannelChat,
3566    _response: Response<proto::JoinChannelChat>,
3567    _session: MessageContext,
3568) -> Result<()> {
3569    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3570}
3571
3572/// Stop receiving chat updates for a channel
3573async fn leave_channel_chat(
3574    _request: proto::LeaveChannelChat,
3575    _session: MessageContext,
3576) -> Result<()> {
3577    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3578}
3579
3580/// Retrieve the chat history for a channel
3581async fn get_channel_messages(
3582    _request: proto::GetChannelMessages,
3583    _response: Response<proto::GetChannelMessages>,
3584    _session: MessageContext,
3585) -> Result<()> {
3586    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3587}
3588
3589/// Retrieve specific chat messages
3590async fn get_channel_messages_by_id(
3591    _request: proto::GetChannelMessagesById,
3592    _response: Response<proto::GetChannelMessagesById>,
3593    _session: MessageContext,
3594) -> Result<()> {
3595    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3596}
3597
3598/// Retrieve the current users notifications
3599async fn get_notifications(
3600    request: proto::GetNotifications,
3601    response: Response<proto::GetNotifications>,
3602    session: MessageContext,
3603) -> Result<()> {
3604    let notifications = session
3605        .db()
3606        .await
3607        .get_notifications(
3608            session.user_id(),
3609            NOTIFICATION_COUNT_PER_PAGE,
3610            request.before_id.map(db::NotificationId::from_proto),
3611        )
3612        .await?;
3613    response.send(proto::GetNotificationsResponse {
3614        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3615        notifications,
3616    })?;
3617    Ok(())
3618}
3619
3620/// Mark notifications as read
3621async fn mark_notification_as_read(
3622    request: proto::MarkNotificationRead,
3623    response: Response<proto::MarkNotificationRead>,
3624    session: MessageContext,
3625) -> Result<()> {
3626    let database = &session.db().await;
3627    let notifications = database
3628        .mark_notification_as_read_by_id(
3629            session.user_id(),
3630            NotificationId::from_proto(request.notification_id),
3631        )
3632        .await?;
3633    send_notifications(
3634        &*session.connection_pool().await,
3635        &session.peer,
3636        notifications,
3637    );
3638    response.send(proto::Ack {})?;
3639    Ok(())
3640}
3641
3642fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3643    let message = match message {
3644        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3645        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3646        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3647        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3648        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3649            code: frame.code.into(),
3650            reason: frame.reason.as_str().to_owned().into(),
3651        })),
3652        // We should never receive a frame while reading the message, according
3653        // to the `tungstenite` maintainers:
3654        //
3655        // > It cannot occur when you read messages from the WebSocket, but it
3656        // > can be used when you want to send the raw frames (e.g. you want to
3657        // > send the frames to the WebSocket without composing the full message first).
3658        // >
3659        // > — https://github.com/snapview/tungstenite-rs/issues/268
3660        TungsteniteMessage::Frame(_) => {
3661            bail!("received an unexpected frame while reading the message")
3662        }
3663    };
3664
3665    Ok(message)
3666}
3667
3668fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3669    match message {
3670        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3671        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3672        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3673        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3674        AxumMessage::Close(frame) => {
3675            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3676                code: frame.code.into(),
3677                reason: frame.reason.as_ref().into(),
3678            }))
3679        }
3680    }
3681}
3682
3683fn notify_membership_updated(
3684    connection_pool: &mut ConnectionPool,
3685    result: MembershipUpdated,
3686    user_id: UserId,
3687    peer: &Peer,
3688) {
3689    for membership in &result.new_channels.channel_memberships {
3690        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3691    }
3692    for channel_id in &result.removed_channels {
3693        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3694    }
3695
3696    let user_channels_update = proto::UpdateUserChannels {
3697        channel_memberships: result
3698            .new_channels
3699            .channel_memberships
3700            .iter()
3701            .map(|cm| proto::ChannelMembership {
3702                channel_id: cm.channel_id.to_proto(),
3703                role: cm.role.into(),
3704            })
3705            .collect(),
3706        ..Default::default()
3707    };
3708
3709    let mut update = build_channels_update(result.new_channels);
3710    update.delete_channels = result
3711        .removed_channels
3712        .into_iter()
3713        .map(|id| id.to_proto())
3714        .collect();
3715    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3716
3717    for connection_id in connection_pool.user_connection_ids(user_id) {
3718        peer.send(connection_id, user_channels_update.clone())
3719            .trace_err();
3720        peer.send(connection_id, update.clone()).trace_err();
3721    }
3722}
3723
3724fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3725    proto::UpdateUserChannels {
3726        channel_memberships: channels
3727            .channel_memberships
3728            .iter()
3729            .map(|m| proto::ChannelMembership {
3730                channel_id: m.channel_id.to_proto(),
3731                role: m.role.into(),
3732            })
3733            .collect(),
3734        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3735    }
3736}
3737
3738fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3739    let mut update = proto::UpdateChannels::default();
3740
3741    for channel in channels.channels {
3742        update.channels.push(channel.to_proto());
3743    }
3744
3745    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3746
3747    for (channel_id, participants) in channels.channel_participants {
3748        update
3749            .channel_participants
3750            .push(proto::ChannelParticipants {
3751                channel_id: channel_id.to_proto(),
3752                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3753            });
3754    }
3755
3756    for channel in channels.invited_channels {
3757        update.channel_invitations.push(channel.to_proto());
3758    }
3759
3760    update
3761}
3762
3763fn build_initial_contacts_update(
3764    contacts: Vec<db::Contact>,
3765    pool: &ConnectionPool,
3766) -> proto::UpdateContacts {
3767    let mut update = proto::UpdateContacts::default();
3768
3769    for contact in contacts {
3770        match contact {
3771            db::Contact::Accepted { user_id, busy } => {
3772                update.contacts.push(contact_for_user(user_id, busy, pool));
3773            }
3774            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3775            db::Contact::Incoming { user_id } => {
3776                update
3777                    .incoming_requests
3778                    .push(proto::IncomingContactRequest {
3779                        requester_id: user_id.to_proto(),
3780                    })
3781            }
3782        }
3783    }
3784
3785    update
3786}
3787
3788fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3789    proto::Contact {
3790        user_id: user_id.to_proto(),
3791        online: pool.is_user_online(user_id),
3792        busy,
3793    }
3794}
3795
3796fn room_updated(room: &proto::Room, peer: &Peer) {
3797    broadcast(
3798        None,
3799        room.participants
3800            .iter()
3801            .filter_map(|participant| Some(participant.peer_id?.into())),
3802        |peer_id| {
3803            peer.send(
3804                peer_id,
3805                proto::RoomUpdated {
3806                    room: Some(room.clone()),
3807                },
3808            )
3809        },
3810    );
3811}
3812
3813fn channel_updated(
3814    channel: &db::channel::Model,
3815    room: &proto::Room,
3816    peer: &Peer,
3817    pool: &ConnectionPool,
3818) {
3819    let participants = room
3820        .participants
3821        .iter()
3822        .map(|p| p.user_id)
3823        .collect::<Vec<_>>();
3824
3825    broadcast(
3826        None,
3827        pool.channel_connection_ids(channel.root_id())
3828            .filter_map(|(channel_id, role)| {
3829                role.can_see_channel(channel.visibility)
3830                    .then_some(channel_id)
3831            }),
3832        |peer_id| {
3833            peer.send(
3834                peer_id,
3835                proto::UpdateChannels {
3836                    channel_participants: vec![proto::ChannelParticipants {
3837                        channel_id: channel.id.to_proto(),
3838                        participant_user_ids: participants.clone(),
3839                    }],
3840                    ..Default::default()
3841                },
3842            )
3843        },
3844    );
3845}
3846
3847async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3848    let db = session.db().await;
3849
3850    let contacts = db.get_contacts(user_id).await?;
3851    let busy = db.is_user_busy(user_id).await?;
3852
3853    let pool = session.connection_pool().await;
3854    let updated_contact = contact_for_user(user_id, busy, &pool);
3855    for contact in contacts {
3856        if let db::Contact::Accepted {
3857            user_id: contact_user_id,
3858            ..
3859        } = contact
3860        {
3861            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3862                session
3863                    .peer
3864                    .send(
3865                        contact_conn_id,
3866                        proto::UpdateContacts {
3867                            contacts: vec![updated_contact.clone()],
3868                            remove_contacts: Default::default(),
3869                            incoming_requests: Default::default(),
3870                            remove_incoming_requests: Default::default(),
3871                            outgoing_requests: Default::default(),
3872                            remove_outgoing_requests: Default::default(),
3873                        },
3874                    )
3875                    .trace_err();
3876            }
3877        }
3878    }
3879    Ok(())
3880}
3881
/// Removes `connection_id` from the room it is in, if any, and notifies
/// everyone affected.
///
/// Returns `Ok(())` without doing anything when the database reports no room
/// for the connection. Otherwise it, in order:
/// - notifies the other members of each left project (`project_left`);
/// - broadcasts the updated room and, for channel rooms, the channel's
///   participants;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for this user and those callees;
/// - removes the user from the LiveKit room, also deleting that room when the
///   database marked it as deleted.
///
/// `connection_id` need not be `session.connection_id`: callers may pass a
/// stale connection of the same user. Contact and LiveKit updates always use
/// `session.user_id()`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are assigned inside the `if let` below rather than via `let ... else`.
    // That way `left_room` is dropped when that block ends, as is the temporary
    // `session.db().await` guard in its condition, before the awaits further down.
    // NOTE(review): presumably this keeps the DB/room guard from being held
    // across those awaits; confirm before restructuring.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `left_room.room` itself is moved out just below. As a
        // result, the `room` broadcast by `room_updated` has a default (empty)
        // `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Only rooms attached to a channel have a channel participant list to refresh.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The connection-pool guard is scoped to this block because
    // `update_user_contacts` below acquires `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged by `trace_err` and
    // do not fail the call.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3955
3956async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3957    let left_channel_buffers = session
3958        .db()
3959        .await
3960        .leave_channel_buffers(session.connection_id)
3961        .await?;
3962
3963    for left_buffer in left_channel_buffers {
3964        channel_buffer_updated(
3965            session.connection_id,
3966            left_buffer.connections,
3967            &proto::UpdateChannelBufferCollaborators {
3968                channel_id: left_buffer.channel_id.to_proto(),
3969                collaborators: left_buffer.collaborators,
3970            },
3971            &session.peer,
3972        );
3973    }
3974
3975    Ok(())
3976}
3977
3978fn project_left(project: &db::LeftProject, session: &Session) {
3979    for connection_id in &project.connection_ids {
3980        if project.should_unshare {
3981            session
3982                .peer
3983                .send(
3984                    *connection_id,
3985                    proto::UnshareProject {
3986                        project_id: project.id.to_proto(),
3987                    },
3988                )
3989                .trace_err();
3990        } else {
3991            session
3992                .peer
3993                .send(
3994                    *connection_id,
3995                    proto::RemoveProjectCollaborator {
3996                        project_id: project.id.to_proto(),
3997                        peer_id: Some(session.connection_id.into()),
3998                    },
3999                )
4000                .trace_err();
4001        }
4002    }
4003}
4004
4005async fn share_agent_thread(
4006    request: proto::ShareAgentThread,
4007    response: Response<proto::ShareAgentThread>,
4008    session: MessageContext,
4009) -> Result<()> {
4010    let user_id = session.user_id();
4011
4012    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4013        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4014
4015    session
4016        .db()
4017        .await
4018        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4019        .await?;
4020
4021    response.send(proto::Ack {})?;
4022
4023    Ok(())
4024}
4025
4026async fn get_shared_agent_thread(
4027    request: proto::GetSharedAgentThread,
4028    response: Response<proto::GetSharedAgentThread>,
4029    session: MessageContext,
4030) -> Result<()> {
4031    let share_id = SharedThreadId::from_proto(request.session_id)
4032        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4033
4034    let result = session.db().await.get_shared_thread(share_id).await?;
4035
4036    match result {
4037        Some((thread, username)) => {
4038            response.send(proto::GetSharedAgentThreadResponse {
4039                title: thread.title,
4040                thread_data: thread.data,
4041                sharer_username: username,
4042                created_at: thread.created_at.and_utc().to_rfc3339(),
4043            })?;
4044        }
4045        None => {
4046            return Err(anyhow!("Shared thread not found").into());
4047        }
4048    }
4049
4050    Ok(())
4051}
4052
4053pub trait ResultExt {
4054    type Ok;
4055
4056    fn trace_err(self) -> Option<Self::Ok>;
4057}
4058
4059impl<T, E> ResultExt for Result<T, E>
4060where
4061    E: std::fmt::Debug,
4062{
4063    type Ok = T;
4064
4065    #[track_caller]
4066    fn trace_err(self) -> Option<T> {
4067        match self {
4068            Ok(value) => Some(value),
4069            Err(error) => {
4070                tracing::error!("{:?}", error);
4071                None
4072            }
4073        }
4074    }
4075}