rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
  75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  76
  77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  79
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  82
  83static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  84
  85const TOTAL_DURATION_MS: &str = "total_duration_ms";
  86const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  87const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  88const HOST_WAITING_MS: &str = "host_waiting_ms";
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
/// The collaboration RPC server: owns the peer, the connection pool and the table of
/// per-message-type handlers.
pub struct Server {
    /// Current server id; behind a mutex because `reset` (test-support) replaces it via `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by `reset`.
    teardown: watch::Sender<bool>,
}
 262
/// Lock guard for the connection pool, returned by `Session::connection_pool`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, which makes this guard `!Send`. A future holding it across an
    // `.await` therefore cannot be `Send` — presumably to stop handlers from holding the
    // synchronous pool lock across await points.
    _not_send: PhantomData<Rc<()>>,
}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 416            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 417            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 419            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 437            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
 439            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 440            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 441            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 442            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 443            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 444            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 445            .add_request_handler(share_agent_thread)
 446            .add_request_handler(get_shared_agent_thread)
 447            .add_request_handler(forward_project_search_chunk);
 448
 449        Arc::new(server)
 450    }
 451
    /// Schedules background cleanup of state the database reports as stale for this server
    /// (via `stale_server_resource_ids`) and returns immediately.
    ///
    /// The detached task first awaits a `CLEANUP_TIMEOUT` sleep, then:
    /// - deletes stale channel-chat participants;
    /// - for each stale channel buffer, clears stale collaborators and sends the refreshed
    ///   collaborator list to the buffer's remaining connections;
    /// - for each stale room:
    ///   - clears stale participants and broadcasts the updated room (and its channel, if any);
    ///   - sends `CallCanceled` to users whose calls were canceled;
    ///   - pushes updated contact info to accepted contacts of affected users;
    ///   - deletes the LiveKit room when no participants remain and a LiveKit client is
    ///     configured;
    /// - deletes stale channel-chat participants again, clears old worktree entries and
    ///   deletes stale servers.
    ///
    /// Every error inside the task is passed through `trace_err` (which presumably logs it)
    /// and otherwise ignored. This function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then tell every remaining
                    // connection about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify and
                        // no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both stale participants and users whose calls were canceled
                            // need their contacts refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the next
                        // `.await`; it is never held across an await point in this task.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups are awaited before the pool lock is taken.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): same call as at the top of the task; confirm whether the
                // second pass is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 622
 623    pub fn teardown(&self) {
 624        self.peer.teardown();
 625        self.connection_pool.lock().reset();
 626        let _ = self.teardown.send(true);
 627    }
 628
 629    #[cfg(feature = "test-support")]
 630    pub fn reset(&self, id: ServerId) {
 631        self.teardown();
 632        *self.id.lock() = id;
 633        self.peer.reset(id.0 as u32);
 634        let _ = self.teardown.send(false);
 635    }
 636
 637    #[cfg(feature = "test-support")]
 638    pub fn id(&self) -> ServerId {
 639        *self.id.lock()
 640    }
 641
    /// Registers `handler` for messages of type `M`, stored as a type-erased [`MessageHandler`].
    ///
    /// The stored closure:
    /// - downcasts the incoming envelope to `TypedEnvelope<M>`;
    /// - calls `handler` with a [`MessageContext`] carrying the session and span;
    /// - once the handler's future completes, records total, processing and queue durations
    ///   (in milliseconds) on the span and logs the outcome.
    ///
    /// Handler errors are logged, not propagated.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The handler is stored under `TypeId::of::<M>()`. Presumably the dispatcher
                // only routes `M` envelopes here, so this downcast cannot fail — confirm at the
                // dispatch site.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Total time is measured from receipt and processing time from just before
                    // the handler was invoked; the difference approximates the time spent
                    // queued.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 685
 686    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 687    where
 688        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 689        Fut: 'static + Send + Future<Output = Result<()>>,
 690        M: EnvelopedMessage,
 691    {
 692        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 693        self
 694    }
 695
 696    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 697    where
 698        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 699        Fut: Send + Future<Output = Result<()>>,
 700        M: RequestMessage,
 701    {
 702        let handler = Arc::new(handler);
 703        self.add_handler(move |envelope, session| {
 704            let receipt = envelope.receipt();
 705            let handler = handler.clone();
 706            async move {
 707                let peer = session.peer.clone();
 708                let responded = Arc::new(AtomicBool::default());
 709                let response = Response {
 710                    peer: peer.clone(),
 711                    responded: responded.clone(),
 712                    receipt,
 713                };
 714                match (handler)(envelope.payload, response, session).await {
 715                    Ok(()) => {
 716                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 717                            Ok(())
 718                        } else {
 719                            Err(anyhow!("handler did not send a response"))?
 720                        }
 721                    }
 722                    Err(error) => {
 723                        let proto_err = match &error {
 724                            Error::Internal(err) => err.to_proto(),
 725                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 726                        };
 727                        peer.respond_with_error(receipt, proto_err)?;
 728                        Err(error)
 729                    }
 730                }
 731            }
 732        })
 733    }
 734
    /// Builds the future that serves a single client connection from start to finish.
    ///
    /// When polled, the future:
    /// 1. Returns immediately, logging an error, if the teardown flag is already set.
    /// 2. Registers `connection` with the peer, builds the `Session`, and sends the initial
    ///    client update (see `send_initial_client_update`). It returns early if that fails.
    /// 3. Drops `connection_guard`, then dispatches incoming messages to the registered
    ///    handlers. This continues until the I/O future finishes, the incoming stream
    ///    ends, or the teardown channel changes.
    /// 4. Unless it left the loop because of teardown, drops any in-flight foreground
    ///    handlers and runs `connection_lost` for the session.
    ///
    /// How the connection metadata is used:
    /// - `address`, `user_agent` and `release_channel` feed tracing span fields.
    /// - `geoip_country_code` is recorded on the span and stored on the `Session`.
    /// - `system_id` is stored on the `Session`.
    /// - `send_connection_id`, if given, receives the assigned `ConnectionId`.
    ///
    /// NOTE(review): the per-message span declares `release_channel=field::Empty`, but this
    /// function never records it there. The value is consumed when it is recorded on the
    /// connection span. Confirm whether `principal.update_span` or a handler fills it in.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in below, or once the connection id is known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly. The current value is checked on entry, and `changed()` ends
        // the message loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): `connection_lost` is not run on this path. If the failure
                // happened after `send_initial_client_update` added the connection to the
                // pool, the pool entry and the peer connection are not cleaned up here.
                // Confirm whether this is handled elsewhere.
                return;
            }
            // Initial setup is done; release the guard acquired by the caller
            // (`handle_websocket_request` acquires it before the WebSocket upgrade).
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each dispatched handler, foreground or background, holds one permit until it
            // completes. At most `MAX_CONCURRENT_HANDLERS` messages from this connection are
            // therefore in flight at once.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // The permit is acquired *before* the next message is read. When all permits
                // are taken, no further messages are pulled from `incoming_rx`.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in order. Teardown comes first, then I/O
                // completion, then driving in-flight foreground handlers, and last accepting a
                // new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers that are still running.
            // Background handlers were spawned detached and keep running.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 905
 906    async fn send_initial_client_update(
 907        &self,
 908        connection_id: ConnectionId,
 909        zed_version: ZedVersion,
 910        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 911        session: &Session,
 912    ) -> Result<()> {
 913        self.peer.send(
 914            connection_id,
 915            proto::Hello {
 916                peer_id: Some(connection_id.into()),
 917            },
 918        )?;
 919        tracing::info!("sent hello message");
 920        if let Some(send_connection_id) = send_connection_id.take() {
 921            let _ = send_connection_id.send(connection_id);
 922        }
 923
 924        match &session.principal {
 925            Principal::User(user) => {
 926                if !user.connected_once {
 927                    self.peer.send(connection_id, proto::ShowContacts {})?;
 928                    self.app_state
 929                        .db
 930                        .set_user_connected_once(user.id, true)
 931                        .await?;
 932                }
 933
 934                let contacts = self.app_state.db.get_contacts(user.id).await?;
 935
 936                {
 937                    let mut pool = self.connection_pool.lock();
 938                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 939                    self.peer.send(
 940                        connection_id,
 941                        build_initial_contacts_update(contacts, &pool),
 942                    )?;
 943                }
 944
 945                if should_auto_subscribe_to_channels(&zed_version) {
 946                    subscribe_user_to_channels(user.id, session).await?;
 947                }
 948
 949                if let Some(incoming_call) =
 950                    self.app_state.db.incoming_call_for_user(user.id).await?
 951                {
 952                    self.peer.send(connection_id, incoming_call)?;
 953                }
 954
 955                update_user_contacts(user.id, session).await?;
 956            }
 957        }
 958
 959        Ok(())
 960    }
 961}
 962
/// Read access to the locked [`ConnectionPool`] through the guard's inner `guard` field.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
 970
/// Mutable access to the locked [`ConnectionPool`] through the guard's inner `guard` field.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
 976
/// In `test-support` builds, validates the pool's invariants every time a guard is
/// released. In other builds, dropping the guard does nothing extra.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
 983
 984fn broadcast<F>(
 985    sender_id: Option<ConnectionId>,
 986    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 987    mut f: F,
 988) where
 989    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 990{
 991    for receiver_id in receiver_ids {
 992        if Some(receiver_id) != sender_id
 993            && let Err(error) = f(receiver_id)
 994        {
 995            tracing::error!("failed to send to {:?} {}", receiver_id, error);
 996        }
 997    }
 998}
 999
1000pub struct ProtocolVersion(u32);
1001
1002impl Header for ProtocolVersion {
1003    fn name() -> &'static HeaderName {
1004        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1005        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1006    }
1007
1008    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1009    where
1010        Self: Sized,
1011        I: Iterator<Item = &'i axum::http::HeaderValue>,
1012    {
1013        let version = values
1014            .next()
1015            .ok_or_else(axum::headers::Error::invalid)?
1016            .to_str()
1017            .map_err(|_| axum::headers::Error::invalid())?
1018            .parse()
1019            .map_err(|_| axum::headers::Error::invalid())?;
1020        Ok(Self(version))
1021    }
1022
1023    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1024        values.extend([self.0.to_string().parse().unwrap()]);
1025    }
1026}
1027
1028pub struct AppVersionHeader(Version);
1029impl Header for AppVersionHeader {
1030    fn name() -> &'static HeaderName {
1031        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1033    }
1034
1035    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036    where
1037        Self: Sized,
1038        I: Iterator<Item = &'i axum::http::HeaderValue>,
1039    {
1040        let version = values
1041            .next()
1042            .ok_or_else(axum::headers::Error::invalid)?
1043            .to_str()
1044            .map_err(|_| axum::headers::Error::invalid())?
1045            .parse()
1046            .map_err(|_| axum::headers::Error::invalid())?;
1047        Ok(Self(version))
1048    }
1049
1050    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051        values.extend([self.0.to_string().parse().unwrap()]);
1052    }
1053}
1054
1055#[derive(Debug)]
1056pub struct ReleaseChannelHeader(String);
1057
1058impl Header for ReleaseChannelHeader {
1059    fn name() -> &'static HeaderName {
1060        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1061        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1062    }
1063
1064    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065    where
1066        Self: Sized,
1067        I: Iterator<Item = &'i axum::http::HeaderValue>,
1068    {
1069        Ok(Self(
1070            values
1071                .next()
1072                .ok_or_else(axum::headers::Error::invalid)?
1073                .to_str()
1074                .map_err(|_| axum::headers::Error::invalid())?
1075                .to_owned(),
1076        ))
1077    }
1078
1079    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080        values.extend([self.0.parse().unwrap()]);
1081    }
1082}
1083
/// Builds the collab HTTP router.
///
/// In axum, `.layer(...)` only wraps routes registered before it, so the call order here
/// matters:
/// - `/rpc` gets the `AppState` extension and the `auth::validate_header` middleware.
/// - `/metrics` is registered afterwards and is therefore not behind that middleware.
/// - The final `Extension(server)` layer applies to both routes.
///
/// NOTE(review): confirm that serving `/metrics` without `auth::validate_header` is
/// intended (e.g. it is only reachable on an internal network).
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1095
1096pub async fn handle_websocket_request(
1097    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1098    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1099    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1100    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1101    Extension(server): Extension<Arc<Server>>,
1102    Extension(principal): Extension<Principal>,
1103    user_agent: Option<TypedHeader<UserAgent>>,
1104    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1105    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1106    ws: WebSocketUpgrade,
1107) -> axum::response::Response {
1108    if protocol_version != rpc::PROTOCOL_VERSION {
1109        return (
1110            StatusCode::UPGRADE_REQUIRED,
1111            "client must be upgraded".to_string(),
1112        )
1113            .into_response();
1114    }
1115
1116    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1117        return (
1118            StatusCode::UPGRADE_REQUIRED,
1119            "no version header found".to_string(),
1120        )
1121            .into_response();
1122    };
1123
1124    let release_channel = release_channel_header.map(|header| header.0.0);
1125
1126    if !version.can_collaborate() {
1127        return (
1128            StatusCode::UPGRADE_REQUIRED,
1129            "client must be upgraded".to_string(),
1130        )
1131            .into_response();
1132    }
1133
1134    let socket_address = socket_address.to_string();
1135
1136    // Acquire connection guard before WebSocket upgrade
1137    let connection_guard = match ConnectionGuard::try_acquire() {
1138        Ok(guard) => guard,
1139        Err(()) => {
1140            return (
1141                StatusCode::SERVICE_UNAVAILABLE,
1142                "Too many concurrent connections",
1143            )
1144                .into_response();
1145        }
1146    };
1147
1148    ws.on_upgrade(move |socket| {
1149        let socket = socket
1150            .map_ok(to_tungstenite_message)
1151            .err_into()
1152            .with(|message| async move { to_axum_message(message) });
1153        let connection = Connection::new(Box::pin(socket));
1154        async move {
1155            server
1156                .handle_connection(
1157                    connection,
1158                    socket_address,
1159                    principal,
1160                    version,
1161                    release_channel,
1162                    user_agent.map(|header| header.to_string()),
1163                    country_code_header.map(|header| header.to_string()),
1164                    system_id_header.map(|header| header.to_string()),
1165                    None,
1166                    Executor::Production,
1167                    Some(connection_guard),
1168                )
1169                .await;
1170        }
1171    })
1172}
1173
1174pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1175    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1176    let connections_metric = CONNECTIONS_METRIC
1177        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1178
1179    let connections = server
1180        .connection_pool
1181        .lock()
1182        .connections()
1183        .filter(|connection| !connection.admin)
1184        .count();
1185    connections_metric.set(connections as _);
1186
1187    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1188    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1189        register_int_gauge!(
1190            "shared_projects",
1191            "number of open projects with one or more guests"
1192        )
1193        .unwrap()
1194    });
1195
1196    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1197    shared_projects_metric.set(shared_projects as _);
1198
1199    let encoder = prometheus::TextEncoder::new();
1200    let metric_families = prometheus::gather();
1201    let encoded_metrics = encoder
1202        .encode_to_string(&metric_families)
1203        .map_err(|err| anyhow!("{err}"))?;
1204    Ok(encoded_metrics)
1205}
1206
/// Cleans up after a client connection has gone away.
///
/// The following happens right away:
/// - The peer connection is disconnected.
/// - The connection is removed from the connection pool. If that fails, the error is
///   returned and the remaining steps are skipped.
/// - The lost connection is recorded in the database. Errors here are logged and ignored.
///
/// The function then waits for `RECONNECT_TIMEOUT`, unless the teardown channel changes
/// first. If the timeout elapses, the session:
/// - leaves its room and its channel buffers;
/// - declines any pending call, if the user has no other connection online;
/// - refreshes the user's contacts.
///
/// The delay presumably gives the client a window to reconnect and restore its state
/// (cf. `rejoin_room`); confirm. On teardown, none of the delayed cleanup runs.
///
/// # Errors
///
/// Returns the error from `remove_connection` or from the final `update_user_contacts`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if the timeout and the teardown change are ready on the same poll, the
    // timeout branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1253
1254/// Acknowledges a ping from a client, used to keep the connection alive.
1255async fn ping(
1256    _: proto::Ping,
1257    response: Response<proto::Ping>,
1258    _session: MessageContext,
1259) -> Result<()> {
1260    response.send(proto::Ack {})?;
1261    Ok(())
1262}
1263
1264/// Creates a new room for calling (outside of channels)
1265async fn create_room(
1266    _request: proto::CreateRoom,
1267    response: Response<proto::CreateRoom>,
1268    session: MessageContext,
1269) -> Result<()> {
1270    let livekit_room = nanoid::nanoid!(30);
1271
1272    let live_kit_connection_info = util::maybe!(async {
1273        let live_kit = session.app_state.livekit_client.as_ref();
1274        let live_kit = live_kit?;
1275        let user_id = session.user_id().to_string();
1276
1277        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1278
1279        Some(proto::LiveKitConnectionInfo {
1280            server_url: live_kit.url().into(),
1281            token,
1282            can_publish: true,
1283        })
1284    })
1285    .await;
1286
1287    let room = session
1288        .db()
1289        .await
1290        .create_room(session.user_id(), session.connection_id, &livekit_room)
1291        .await?;
1292
1293    response.send(proto::CreateRoomResponse {
1294        room: Some(room.clone()),
1295        live_kit_connection_info,
1296    })?;
1297
1298    update_user_contacts(session.user_id(), &session).await?;
1299    Ok(())
1300}
1301
1302/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1303async fn join_room(
1304    request: proto::JoinRoom,
1305    response: Response<proto::JoinRoom>,
1306    session: MessageContext,
1307) -> Result<()> {
1308    let room_id = RoomId::from_proto(request.id);
1309
1310    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1311
1312    if let Some(channel_id) = channel_id {
1313        return join_channel_internal(channel_id, Box::new(response), session).await;
1314    }
1315
1316    let joined_room = {
1317        let room = session
1318            .db()
1319            .await
1320            .join_room(room_id, session.user_id(), session.connection_id)
1321            .await?;
1322        room_updated(&room.room, &session.peer);
1323        room.into_inner()
1324    };
1325
1326    for connection_id in session
1327        .connection_pool()
1328        .await
1329        .user_connection_ids(session.user_id())
1330    {
1331        session
1332            .peer
1333            .send(
1334                connection_id,
1335                proto::CallCanceled {
1336                    room_id: room_id.to_proto(),
1337                },
1338            )
1339            .trace_err();
1340    }
1341
1342    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1343    {
1344        live_kit
1345            .room_token(
1346                &joined_room.room.livekit_room,
1347                &session.user_id().to_string(),
1348            )
1349            .trace_err()
1350            .map(|token| proto::LiveKitConnectionInfo {
1351                server_url: live_kit.url().into(),
1352                token,
1353                can_publish: true,
1354            })
1355    } else {
1356        None
1357    };
1358
1359    response.send(proto::JoinRoomResponse {
1360        room: Some(joined_room.room),
1361        channel_id: None,
1362        live_kit_connection_info,
1363    })?;
1364
1365    update_user_contacts(session.user_id(), &session).await?;
1366    Ok(())
1367}
1368
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Rejoins the room in the database for this connection, then, in order:
/// 1. Replies with the room, the reshared projects (id and collaborators), and the rejoined
///    projects.
/// 2. Broadcasts the updated room.
/// 3. For every reshared project, tells each collaborator that the peer id changed from
///    the project's old connection to this one. It then forwards an `UpdateProject` with
///    the project's worktrees to all collaborators except this connection. These are
///    presumably projects this user hosts; confirm.
/// 4. Calls `notify_rejoined_projects` for the rejoined projects.
/// 5. After the inner scope ends, calls `channel_updated` if the room belongs to a
///    channel, and finally refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database result so it (and whatever `into_inner` unwraps — presumably a
    // guard; confirm) is consumed before the channel/contacts updates below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Collaborators must learn the host's new peer id before receiving its updates.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1460
/// Brings everyone up to date after a connection rejoins a room.
///
/// Phase one: for every rejoined project, each collaborator listed on the
/// project is sent an `UpdateProjectCollaborator` replacing the rejoining
/// peer's old connection id (`project.old_connection_id`) with
/// `session.connection_id`. These sends are best-effort: failures are only
/// logged via `trace_err`.
///
/// Phase two: the rejoining connection itself is sent the state accumulated
/// in each `RejoinedProject`. For each worktree it gets entry updates (split
/// with `proto::split_worktree_update`), diagnostic summaries and settings
/// files. After that come the project's repository updates (split with
/// `split_repository_update`) and repository removals.
///
/// The worktree and repository collections are moved out with `mem::take`,
/// which is why the projects are borrowed mutably. Any failure sending to the
/// rejoining connection aborts with an error.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell existing collaborators about the rejoining peer's new connection id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Replay project state to the rejoining connection.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                root_repo_common_dir: worktree.root_repo_common_dir,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Flagged as the last update only when the most recently
                // completed scan is the current scan.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics: the first summary goes in
            // `summary`, any remaining ones ride along in `more_summaries`.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream this worktree's settings files, one message per file.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository state is tracked per project, not per worktree.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1547
1548/// leave room disconnects from the room.
1549async fn leave_room(
1550    _: proto::LeaveRoom,
1551    response: Response<proto::LeaveRoom>,
1552    session: MessageContext,
1553) -> Result<()> {
1554    leave_room_for_session(&session, session.connection_id).await?;
1555    response.send(proto::Ack {})?;
1556    Ok(())
1557}
1558
1559/// Updates the permissions of someone else in the room.
1560async fn set_room_participant_role(
1561    request: proto::SetRoomParticipantRole,
1562    response: Response<proto::SetRoomParticipantRole>,
1563    session: MessageContext,
1564) -> Result<()> {
1565    let user_id = UserId::from_proto(request.user_id);
1566    let role = ChannelRole::from(request.role());
1567
1568    let (livekit_room, can_publish) = {
1569        let room = session
1570            .db()
1571            .await
1572            .set_room_participant_role(
1573                session.user_id(),
1574                RoomId::from_proto(request.room_id),
1575                user_id,
1576                role,
1577            )
1578            .await?;
1579
1580        let livekit_room = room.livekit_room.clone();
1581        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1582        room_updated(&room, &session.peer);
1583        (livekit_room, can_publish)
1584    };
1585
1586    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1587        live_kit
1588            .update_participant(
1589                livekit_room.clone(),
1590                request.user_id.to_string(),
1591                livekit_api::proto::ParticipantPermission {
1592                    can_subscribe: true,
1593                    can_publish,
1594                    can_publish_data: can_publish,
1595                    hidden: false,
1596                    recorder: false,
1597                },
1598            )
1599            .await
1600            .trace_err();
1601    }
1602
1603    response.send(proto::Ack {})?;
1604    Ok(())
1605}
1606
1607/// Call someone else into the current room
1608async fn call(
1609    request: proto::Call,
1610    response: Response<proto::Call>,
1611    session: MessageContext,
1612) -> Result<()> {
1613    let room_id = RoomId::from_proto(request.room_id);
1614    let calling_user_id = session.user_id();
1615    let calling_connection_id = session.connection_id;
1616    let called_user_id = UserId::from_proto(request.called_user_id);
1617    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1618    if !session
1619        .db()
1620        .await
1621        .has_contact(calling_user_id, called_user_id)
1622        .await?
1623    {
1624        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1625    }
1626
1627    let incoming_call = {
1628        let (room, incoming_call) = &mut *session
1629            .db()
1630            .await
1631            .call(
1632                room_id,
1633                calling_user_id,
1634                calling_connection_id,
1635                called_user_id,
1636                initial_project_id,
1637            )
1638            .await?;
1639        room_updated(room, &session.peer);
1640        mem::take(incoming_call)
1641    };
1642    update_user_contacts(called_user_id, &session).await?;
1643
1644    let mut calls = session
1645        .connection_pool()
1646        .await
1647        .user_connection_ids(called_user_id)
1648        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1649        .collect::<FuturesUnordered<_>>();
1650
1651    while let Some(call_response) = calls.next().await {
1652        match call_response.as_ref() {
1653            Ok(_) => {
1654                response.send(proto::Ack {})?;
1655                return Ok(());
1656            }
1657            Err(_) => {
1658                call_response.trace_err();
1659            }
1660        }
1661    }
1662
1663    {
1664        let room = session
1665            .db()
1666            .await
1667            .call_failed(room_id, called_user_id)
1668            .await?;
1669        room_updated(&room, &session.peer);
1670    }
1671    update_user_contacts(called_user_id, &session).await?;
1672
1673    Err(anyhow!("failed to ring user"))?
1674}
1675
1676/// Cancel an outgoing call.
1677async fn cancel_call(
1678    request: proto::CancelCall,
1679    response: Response<proto::CancelCall>,
1680    session: MessageContext,
1681) -> Result<()> {
1682    let called_user_id = UserId::from_proto(request.called_user_id);
1683    let room_id = RoomId::from_proto(request.room_id);
1684    {
1685        let room = session
1686            .db()
1687            .await
1688            .cancel_call(room_id, session.connection_id, called_user_id)
1689            .await?;
1690        room_updated(&room, &session.peer);
1691    }
1692
1693    for connection_id in session
1694        .connection_pool()
1695        .await
1696        .user_connection_ids(called_user_id)
1697    {
1698        session
1699            .peer
1700            .send(
1701                connection_id,
1702                proto::CallCanceled {
1703                    room_id: room_id.to_proto(),
1704                },
1705            )
1706            .trace_err();
1707    }
1708    response.send(proto::Ack {})?;
1709
1710    update_user_contacts(called_user_id, &session).await?;
1711    Ok(())
1712}
1713
1714/// Decline an incoming call.
1715async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1716    let room_id = RoomId::from_proto(message.room_id);
1717    {
1718        let room = session
1719            .db()
1720            .await
1721            .decline_call(Some(room_id), session.user_id())
1722            .await?
1723            .context("declining call")?;
1724        room_updated(&room, &session.peer);
1725    }
1726
1727    for connection_id in session
1728        .connection_pool()
1729        .await
1730        .user_connection_ids(session.user_id())
1731    {
1732        session
1733            .peer
1734            .send(
1735                connection_id,
1736                proto::CallCanceled {
1737                    room_id: room_id.to_proto(),
1738                },
1739            )
1740            .trace_err();
1741    }
1742    update_user_contacts(session.user_id(), &session).await?;
1743    Ok(())
1744}
1745
1746/// Updates other participants in the room with your current location.
1747async fn update_participant_location(
1748    request: proto::UpdateParticipantLocation,
1749    response: Response<proto::UpdateParticipantLocation>,
1750    session: MessageContext,
1751) -> Result<()> {
1752    let room_id = RoomId::from_proto(request.room_id);
1753    let location = request.location.context("invalid location")?;
1754
1755    let db = session.db().await;
1756    let room = db
1757        .update_room_participant_location(room_id, session.connection_id, location)
1758        .await?;
1759
1760    room_updated(&room, &session.peer);
1761    response.send(proto::Ack {})?;
1762    Ok(())
1763}
1764
1765/// Share a project into the room.
1766async fn share_project(
1767    request: proto::ShareProject,
1768    response: Response<proto::ShareProject>,
1769    session: MessageContext,
1770) -> Result<()> {
1771    let (project_id, room) = &*session
1772        .db()
1773        .await
1774        .share_project(
1775            RoomId::from_proto(request.room_id),
1776            session.connection_id,
1777            &request.worktrees,
1778            request.is_ssh_project,
1779            request.windows_paths.unwrap_or(false),
1780            &request.features,
1781        )
1782        .await?;
1783    response.send(proto::ShareProjectResponse {
1784        project_id: project_id.to_proto(),
1785    })?;
1786    room_updated(room, &session.peer);
1787
1788    Ok(())
1789}
1790
1791/// Unshare a project from the room.
1792async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1793    let project_id = ProjectId::from_proto(message.project_id);
1794    unshare_project_internal(project_id, session.connection_id, &session).await
1795}
1796
1797async fn unshare_project_internal(
1798    project_id: ProjectId,
1799    connection_id: ConnectionId,
1800    session: &Session,
1801) -> Result<()> {
1802    let delete = {
1803        let room_guard = session
1804            .db()
1805            .await
1806            .unshare_project(project_id, connection_id)
1807            .await?;
1808
1809        let (delete, room, guest_connection_ids) = &*room_guard;
1810
1811        let message = proto::UnshareProject {
1812            project_id: project_id.to_proto(),
1813        };
1814
1815        broadcast(
1816            Some(connection_id),
1817            guest_connection_ids.iter().copied(),
1818            |conn_id| session.peer.send(conn_id, message.clone()),
1819        );
1820        if let Some(room) = room {
1821            room_updated(room, &session.peer);
1822        }
1823
1824        *delete
1825    };
1826
1827    if delete {
1828        let db = session.db().await;
1829        db.delete_project(project_id).await?;
1830    }
1831
1832    Ok(())
1833}
1834
1835/// Join someone elses shared project.
1836async fn join_project(
1837    request: proto::JoinProject,
1838    response: Response<proto::JoinProject>,
1839    session: MessageContext,
1840) -> Result<()> {
1841    let project_id = ProjectId::from_proto(request.project_id);
1842
1843    tracing::info!(%project_id, "join project");
1844
1845    let db = session.db().await;
1846    let project_model = db.get_project(project_id).await?;
1847    let host_features: Vec<String> =
1848        serde_json::from_str(&project_model.features).unwrap_or_default();
1849    let guest_features: HashSet<_> = request.features.iter().collect();
1850    let host_features_set: HashSet<_> = host_features.iter().collect();
1851    if guest_features != host_features_set {
1852        let host_connection_id = project_model.host_connection()?;
1853        let mut pool = session.connection_pool().await;
1854        let host_version = pool
1855            .connection(host_connection_id)
1856            .map(|c| c.zed_version.to_string());
1857        let guest_version = pool
1858            .connection(session.connection_id)
1859            .map(|c| c.zed_version.to_string());
1860        drop(pool);
1861        Err(anyhow!(
1862            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1863            host_version.as_deref().unwrap_or("unknown"),
1864            guest_version.as_deref().unwrap_or("unknown"),
1865        ))?;
1866    }
1867
1868    let (project, replica_id) = &mut *db
1869        .join_project(
1870            project_id,
1871            session.connection_id,
1872            session.user_id(),
1873            request.committer_name.clone(),
1874            request.committer_email.clone(),
1875        )
1876        .await?;
1877    drop(db);
1878
1879    tracing::info!(%project_id, "join remote project");
1880    let collaborators = project
1881        .collaborators
1882        .iter()
1883        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1884        .map(|collaborator| collaborator.to_proto())
1885        .collect::<Vec<_>>();
1886    let project_id = project.id;
1887    let guest_user_id = session.user_id();
1888
1889    let worktrees = project
1890        .worktrees
1891        .iter()
1892        .map(|(id, worktree)| proto::WorktreeMetadata {
1893            id: *id,
1894            root_name: worktree.root_name.clone(),
1895            visible: worktree.visible,
1896            abs_path: worktree.abs_path.clone(),
1897            root_repo_common_dir: None,
1898        })
1899        .collect::<Vec<_>>();
1900
1901    let add_project_collaborator = proto::AddProjectCollaborator {
1902        project_id: project_id.to_proto(),
1903        collaborator: Some(proto::Collaborator {
1904            peer_id: Some(session.connection_id.into()),
1905            replica_id: replica_id.0 as u32,
1906            user_id: guest_user_id.to_proto(),
1907            is_host: false,
1908            committer_name: request.committer_name.clone(),
1909            committer_email: request.committer_email.clone(),
1910        }),
1911    };
1912
1913    for collaborator in &collaborators {
1914        session
1915            .peer
1916            .send(
1917                collaborator.peer_id.unwrap().into(),
1918                add_project_collaborator.clone(),
1919            )
1920            .trace_err();
1921    }
1922
1923    // First, we send the metadata associated with each worktree.
1924    let (language_servers, language_server_capabilities) = project
1925        .language_servers
1926        .clone()
1927        .into_iter()
1928        .map(|server| (server.server, server.capabilities))
1929        .unzip();
1930    response.send(proto::JoinProjectResponse {
1931        project_id: project.id.0 as u64,
1932        worktrees,
1933        replica_id: replica_id.0 as u32,
1934        collaborators,
1935        language_servers,
1936        language_server_capabilities,
1937        role: project.role.into(),
1938        windows_paths: project.path_style == PathStyle::Windows,
1939        features: project.features.clone(),
1940    })?;
1941
1942    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1943        // Stream this worktree's entries.
1944        let message = proto::UpdateWorktree {
1945            project_id: project_id.to_proto(),
1946            worktree_id,
1947            abs_path: worktree.abs_path.clone(),
1948            root_name: worktree.root_name,
1949            root_repo_common_dir: worktree.root_repo_common_dir,
1950            updated_entries: worktree.entries,
1951            removed_entries: Default::default(),
1952            scan_id: worktree.scan_id,
1953            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1954            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1955            removed_repositories: Default::default(),
1956        };
1957        for update in proto::split_worktree_update(message) {
1958            session.peer.send(session.connection_id, update.clone())?;
1959        }
1960
1961        // Stream this worktree's diagnostics.
1962        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1963        if let Some(summary) = worktree_diagnostics.next() {
1964            let message = proto::UpdateDiagnosticSummary {
1965                project_id: project.id.to_proto(),
1966                worktree_id: worktree.id,
1967                summary: Some(summary),
1968                more_summaries: worktree_diagnostics.collect(),
1969            };
1970            session.peer.send(session.connection_id, message)?;
1971        }
1972
1973        for settings_file in worktree.settings_files {
1974            session.peer.send(
1975                session.connection_id,
1976                proto::UpdateWorktreeSettings {
1977                    project_id: project_id.to_proto(),
1978                    worktree_id: worktree.id,
1979                    path: settings_file.path,
1980                    content: Some(settings_file.content),
1981                    kind: Some(settings_file.kind.to_proto() as i32),
1982                    outside_worktree: Some(settings_file.outside_worktree),
1983                },
1984            )?;
1985        }
1986    }
1987
1988    for repository in mem::take(&mut project.repositories) {
1989        for update in split_repository_update(repository) {
1990            session.peer.send(session.connection_id, update)?;
1991        }
1992    }
1993
1994    for language_server in &project.language_servers {
1995        session.peer.send(
1996            session.connection_id,
1997            proto::UpdateLanguageServer {
1998                project_id: project_id.to_proto(),
1999                server_name: Some(language_server.server.name.clone()),
2000                language_server_id: language_server.server.id,
2001                variant: Some(
2002                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2003                        proto::LspDiskBasedDiagnosticsUpdated {},
2004                    ),
2005                ),
2006            },
2007        )?;
2008    }
2009
2010    Ok(())
2011}
2012
2013/// Leave someone elses shared project.
2014async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2015    let sender_id = session.connection_id;
2016    let project_id = ProjectId::from_proto(request.project_id);
2017    let db = session.db().await;
2018
2019    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2020    tracing::info!(
2021        %project_id,
2022        "leave project"
2023    );
2024
2025    project_left(project, &session);
2026    if let Some(room) = room {
2027        room_updated(room, &session.peer);
2028    }
2029
2030    Ok(())
2031}
2032
2033/// Updates other participants with changes to the project
2034async fn update_project(
2035    request: proto::UpdateProject,
2036    response: Response<proto::UpdateProject>,
2037    session: MessageContext,
2038) -> Result<()> {
2039    let project_id = ProjectId::from_proto(request.project_id);
2040    let (room, guest_connection_ids) = &*session
2041        .db()
2042        .await
2043        .update_project(project_id, session.connection_id, &request.worktrees)
2044        .await?;
2045    broadcast(
2046        Some(session.connection_id),
2047        guest_connection_ids.iter().copied(),
2048        |connection_id| {
2049            session
2050                .peer
2051                .forward_send(session.connection_id, connection_id, request.clone())
2052        },
2053    );
2054    if let Some(room) = room {
2055        room_updated(room, &session.peer);
2056    }
2057    response.send(proto::Ack {})?;
2058
2059    Ok(())
2060}
2061
2062/// Updates other participants with changes to the worktree
2063async fn update_worktree(
2064    request: proto::UpdateWorktree,
2065    response: Response<proto::UpdateWorktree>,
2066    session: MessageContext,
2067) -> Result<()> {
2068    let guest_connection_ids = session
2069        .db()
2070        .await
2071        .update_worktree(&request, session.connection_id)
2072        .await?;
2073
2074    broadcast(
2075        Some(session.connection_id),
2076        guest_connection_ids.iter().copied(),
2077        |connection_id| {
2078            session
2079                .peer
2080                .forward_send(session.connection_id, connection_id, request.clone())
2081        },
2082    );
2083    response.send(proto::Ack {})?;
2084    Ok(())
2085}
2086
2087async fn update_repository(
2088    request: proto::UpdateRepository,
2089    response: Response<proto::UpdateRepository>,
2090    session: MessageContext,
2091) -> Result<()> {
2092    let guest_connection_ids = session
2093        .db()
2094        .await
2095        .update_repository(&request, session.connection_id)
2096        .await?;
2097
2098    broadcast(
2099        Some(session.connection_id),
2100        guest_connection_ids.iter().copied(),
2101        |connection_id| {
2102            session
2103                .peer
2104                .forward_send(session.connection_id, connection_id, request.clone())
2105        },
2106    );
2107    response.send(proto::Ack {})?;
2108    Ok(())
2109}
2110
2111async fn remove_repository(
2112    request: proto::RemoveRepository,
2113    response: Response<proto::RemoveRepository>,
2114    session: MessageContext,
2115) -> Result<()> {
2116    let guest_connection_ids = session
2117        .db()
2118        .await
2119        .remove_repository(&request, session.connection_id)
2120        .await?;
2121
2122    broadcast(
2123        Some(session.connection_id),
2124        guest_connection_ids.iter().copied(),
2125        |connection_id| {
2126            session
2127                .peer
2128                .forward_send(session.connection_id, connection_id, request.clone())
2129        },
2130    );
2131    response.send(proto::Ack {})?;
2132    Ok(())
2133}
2134
2135/// Updates other participants with changes to the diagnostics
2136async fn update_diagnostic_summary(
2137    message: proto::UpdateDiagnosticSummary,
2138    session: MessageContext,
2139) -> Result<()> {
2140    let guest_connection_ids = session
2141        .db()
2142        .await
2143        .update_diagnostic_summary(&message, session.connection_id)
2144        .await?;
2145
2146    broadcast(
2147        Some(session.connection_id),
2148        guest_connection_ids.iter().copied(),
2149        |connection_id| {
2150            session
2151                .peer
2152                .forward_send(session.connection_id, connection_id, message.clone())
2153        },
2154    );
2155
2156    Ok(())
2157}
2158
2159/// Updates other participants with changes to the worktree settings
2160async fn update_worktree_settings(
2161    message: proto::UpdateWorktreeSettings,
2162    session: MessageContext,
2163) -> Result<()> {
2164    let guest_connection_ids = session
2165        .db()
2166        .await
2167        .update_worktree_settings(&message, session.connection_id)
2168        .await?;
2169
2170    broadcast(
2171        Some(session.connection_id),
2172        guest_connection_ids.iter().copied(),
2173        |connection_id| {
2174            session
2175                .peer
2176                .forward_send(session.connection_id, connection_id, message.clone())
2177        },
2178    );
2179
2180    Ok(())
2181}
2182
2183/// Notify other participants that a language server has started.
2184async fn start_language_server(
2185    request: proto::StartLanguageServer,
2186    session: MessageContext,
2187) -> Result<()> {
2188    let guest_connection_ids = session
2189        .db()
2190        .await
2191        .start_language_server(&request, session.connection_id)
2192        .await?;
2193
2194    broadcast(
2195        Some(session.connection_id),
2196        guest_connection_ids.iter().copied(),
2197        |connection_id| {
2198            session
2199                .peer
2200                .forward_send(session.connection_id, connection_id, request.clone())
2201        },
2202    );
2203    Ok(())
2204}
2205
2206/// Notify other participants that a language server has changed.
2207async fn update_language_server(
2208    request: proto::UpdateLanguageServer,
2209    session: MessageContext,
2210) -> Result<()> {
2211    let project_id = ProjectId::from_proto(request.project_id);
2212    let db = session.db().await;
2213
2214    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2215        && let Some(capabilities) = update.capabilities.clone()
2216    {
2217        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2218            .await?;
2219    }
2220
2221    let project_connection_ids = db
2222        .project_connection_ids(project_id, session.connection_id, true)
2223        .await?;
2224    broadcast(
2225        Some(session.connection_id),
2226        project_connection_ids.iter().copied(),
2227        |connection_id| {
2228            session
2229                .peer
2230                .forward_send(session.connection_id, connection_id, request.clone())
2231        },
2232    );
2233    Ok(())
2234}
2235
2236/// forward a project request to the host. These requests should be read only
2237/// as guests are allowed to send them.
2238async fn forward_read_only_project_request<T>(
2239    request: T,
2240    response: Response<T>,
2241    session: MessageContext,
2242) -> Result<()>
2243where
2244    T: EntityMessage + RequestMessage,
2245{
2246    let project_id = ProjectId::from_proto(request.remote_entity_id());
2247    let host_connection_id = session
2248        .db()
2249        .await
2250        .host_for_read_only_project_request(project_id, session.connection_id)
2251        .await?;
2252    let payload = session.forward_request(host_connection_id, request).await?;
2253    response.send(payload)?;
2254    Ok(())
2255}
2256
2257/// forward a project request to the host. These requests are disallowed
2258/// for guests.
2259async fn forward_mutating_project_request<T>(
2260    request: T,
2261    response: Response<T>,
2262    session: MessageContext,
2263) -> Result<()>
2264where
2265    T: EntityMessage + RequestMessage,
2266{
2267    let project_id = ProjectId::from_proto(request.remote_entity_id());
2268
2269    let host_connection_id = session
2270        .db()
2271        .await
2272        .host_for_mutating_project_request(project_id, session.connection_id)
2273        .await?;
2274    let payload = session.forward_request(host_connection_id, request).await?;
2275    response.send(payload)?;
2276    Ok(())
2277}
2278
2279async fn disallow_guest_request<T>(
2280    _request: T,
2281    response: Response<T>,
2282    _session: MessageContext,
2283) -> Result<()>
2284where
2285    T: RequestMessage,
2286{
2287    response.peer.respond_with_error(
2288        response.receipt,
2289        ErrorCode::Forbidden
2290            .message("request is not allowed for guests".to_string())
2291            .to_proto(),
2292    )?;
2293    response.responded.store(true, SeqCst);
2294    Ok(())
2295}
2296
2297async fn lsp_query(
2298    request: proto::LspQuery,
2299    response: Response<proto::LspQuery>,
2300    session: MessageContext,
2301) -> Result<()> {
2302    let (name, should_write) = request.query_name_and_write_permissions();
2303    tracing::Span::current().record("lsp_query_request", name);
2304    tracing::info!("lsp_query message received");
2305    if should_write {
2306        forward_mutating_project_request(request, response, session).await
2307    } else {
2308        forward_read_only_project_request(request, response, session).await
2309    }
2310}
2311
2312/// Notify other participants that a new buffer has been created
2313async fn create_buffer_for_peer(
2314    request: proto::CreateBufferForPeer,
2315    session: MessageContext,
2316) -> Result<()> {
2317    session
2318        .db()
2319        .await
2320        .check_user_is_project_host(
2321            ProjectId::from_proto(request.project_id),
2322            session.connection_id,
2323        )
2324        .await?;
2325    let peer_id = request.peer_id.context("invalid peer id")?;
2326    session
2327        .peer
2328        .forward_send(session.connection_id, peer_id.into(), request)?;
2329    Ok(())
2330}
2331
2332/// Notify other participants that a new image has been created
2333async fn create_image_for_peer(
2334    request: proto::CreateImageForPeer,
2335    session: MessageContext,
2336) -> Result<()> {
2337    session
2338        .db()
2339        .await
2340        .check_user_is_project_host(
2341            ProjectId::from_proto(request.project_id),
2342            session.connection_id,
2343        )
2344        .await?;
2345    let peer_id = request.peer_id.context("invalid peer id")?;
2346    session
2347        .peer
2348        .forward_send(session.connection_id, peer_id.into(), request)?;
2349    Ok(())
2350}
2351
2352/// Notify other participants that a buffer has been updated. This is
2353/// allowed for guests as long as the update is limited to selections.
2354async fn update_buffer(
2355    request: proto::UpdateBuffer,
2356    response: Response<proto::UpdateBuffer>,
2357    session: MessageContext,
2358) -> Result<()> {
2359    let project_id = ProjectId::from_proto(request.project_id);
2360    let mut capability = Capability::ReadOnly;
2361
2362    for op in request.operations.iter() {
2363        match op.variant {
2364            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2365            Some(_) => capability = Capability::ReadWrite,
2366        }
2367    }
2368
2369    let host = {
2370        let guard = session
2371            .db()
2372            .await
2373            .connections_for_buffer_update(project_id, session.connection_id, capability)
2374            .await?;
2375
2376        let (host, guests) = &*guard;
2377
2378        broadcast(
2379            Some(session.connection_id),
2380            guests.clone(),
2381            |connection_id| {
2382                session
2383                    .peer
2384                    .forward_send(session.connection_id, connection_id, request.clone())
2385            },
2386        );
2387
2388        *host
2389    };
2390
2391    if host != session.connection_id {
2392        session.forward_request(host, request.clone()).await?;
2393    }
2394
2395    response.send(proto::Ack {})?;
2396    Ok(())
2397}
2398
2399async fn forward_project_search_chunk(
2400    message: proto::FindSearchCandidatesChunk,
2401    response: Response<proto::FindSearchCandidatesChunk>,
2402    session: MessageContext,
2403) -> Result<()> {
2404    let peer_id = message.peer_id.context("missing peer_id")?;
2405    let payload = session
2406        .peer
2407        .forward_request(session.connection_id, peer_id.into(), message)
2408        .await?;
2409    response.send(payload)?;
2410    Ok(())
2411}
2412
2413/// Notify other participants that a project has been updated.
2414async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2415    request: T,
2416    session: MessageContext,
2417) -> Result<()> {
2418    let project_id = ProjectId::from_proto(request.remote_entity_id());
2419    let project_connection_ids = session
2420        .db()
2421        .await
2422        .project_connection_ids(project_id, session.connection_id, false)
2423        .await?;
2424
2425    broadcast(
2426        Some(session.connection_id),
2427        project_connection_ids.iter().copied(),
2428        |connection_id| {
2429            session
2430                .peer
2431                .forward_send(session.connection_id, connection_id, request.clone())
2432        },
2433    );
2434    Ok(())
2435}
2436
2437/// Start following another user in a call.
2438async fn follow(
2439    request: proto::Follow,
2440    response: Response<proto::Follow>,
2441    session: MessageContext,
2442) -> Result<()> {
2443    let room_id = RoomId::from_proto(request.room_id);
2444    let project_id = request.project_id.map(ProjectId::from_proto);
2445    let leader_id = request.leader_id.context("invalid leader id")?.into();
2446    let follower_id = session.connection_id;
2447
2448    session
2449        .db()
2450        .await
2451        .check_room_participants(room_id, leader_id, session.connection_id)
2452        .await?;
2453
2454    let response_payload = session.forward_request(leader_id, request).await?;
2455    response.send(response_payload)?;
2456
2457    if let Some(project_id) = project_id {
2458        let room = session
2459            .db()
2460            .await
2461            .follow(room_id, project_id, leader_id, follower_id)
2462            .await?;
2463        room_updated(&room, &session.peer);
2464    }
2465
2466    Ok(())
2467}
2468
2469/// Stop following another user in a call.
2470async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2471    let room_id = RoomId::from_proto(request.room_id);
2472    let project_id = request.project_id.map(ProjectId::from_proto);
2473    let leader_id = request.leader_id.context("invalid leader id")?.into();
2474    let follower_id = session.connection_id;
2475
2476    session
2477        .db()
2478        .await
2479        .check_room_participants(room_id, leader_id, session.connection_id)
2480        .await?;
2481
2482    session
2483        .peer
2484        .forward_send(session.connection_id, leader_id, request)?;
2485
2486    if let Some(project_id) = project_id {
2487        let room = session
2488            .db()
2489            .await
2490            .unfollow(room_id, project_id, leader_id, follower_id)
2491            .await?;
2492        room_updated(&room, &session.peer);
2493    }
2494
2495    Ok(())
2496}
2497
2498/// Notify everyone following you of your current location.
2499async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2500    let room_id = RoomId::from_proto(request.room_id);
2501    let database = session.db.lock().await;
2502
2503    let connection_ids = if let Some(project_id) = request.project_id {
2504        let project_id = ProjectId::from_proto(project_id);
2505        database
2506            .project_connection_ids(project_id, session.connection_id, true)
2507            .await?
2508    } else {
2509        database
2510            .room_connection_ids(room_id, session.connection_id)
2511            .await?
2512    };
2513
2514    // For now, don't send view update messages back to that view's current leader.
2515    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2516        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2517        _ => None,
2518    });
2519
2520    for connection_id in connection_ids.iter().cloned() {
2521        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2522            session
2523                .peer
2524                .forward_send(session.connection_id, connection_id, request.clone())?;
2525        }
2526    }
2527    Ok(())
2528}
2529
2530/// Get public data about users.
2531async fn get_users(
2532    request: proto::GetUsers,
2533    response: Response<proto::GetUsers>,
2534    session: MessageContext,
2535) -> Result<()> {
2536    let user_ids = request
2537        .user_ids
2538        .into_iter()
2539        .map(UserId::from_proto)
2540        .collect();
2541    let users = session
2542        .db()
2543        .await
2544        .get_users_by_ids(user_ids)
2545        .await?
2546        .into_iter()
2547        .map(|user| proto::User {
2548            id: user.id.to_proto(),
2549            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2550            github_login: user.github_login,
2551            name: user.name,
2552        })
2553        .collect();
2554    response.send(proto::UsersResponse { users })?;
2555    Ok(())
2556}
2557
2558/// Search for users (to invite) buy Github login
2559async fn fuzzy_search_users(
2560    request: proto::FuzzySearchUsers,
2561    response: Response<proto::FuzzySearchUsers>,
2562    session: MessageContext,
2563) -> Result<()> {
2564    let query = request.query;
2565    let users = match query.len() {
2566        0 => vec![],
2567        1 | 2 => session
2568            .db()
2569            .await
2570            .get_user_by_github_login(&query)
2571            .await?
2572            .into_iter()
2573            .collect(),
2574        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2575    };
2576    let users = users
2577        .into_iter()
2578        .filter(|user| user.id != session.user_id())
2579        .map(|user| proto::User {
2580            id: user.id.to_proto(),
2581            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2582            github_login: user.github_login,
2583            name: user.name,
2584        })
2585        .collect();
2586    response.send(proto::UsersResponse { users })?;
2587    Ok(())
2588}
2589
2590/// Send a contact request to another user.
2591async fn request_contact(
2592    request: proto::RequestContact,
2593    response: Response<proto::RequestContact>,
2594    session: MessageContext,
2595) -> Result<()> {
2596    let requester_id = session.user_id();
2597    let responder_id = UserId::from_proto(request.responder_id);
2598    if requester_id == responder_id {
2599        return Err(anyhow!("cannot add yourself as a contact"))?;
2600    }
2601
2602    let notifications = session
2603        .db()
2604        .await
2605        .send_contact_request(requester_id, responder_id)
2606        .await?;
2607
2608    // Update outgoing contact requests of requester
2609    let mut update = proto::UpdateContacts::default();
2610    update.outgoing_requests.push(responder_id.to_proto());
2611    for connection_id in session
2612        .connection_pool()
2613        .await
2614        .user_connection_ids(requester_id)
2615    {
2616        session.peer.send(connection_id, update.clone())?;
2617    }
2618
2619    // Update incoming contact requests of responder
2620    let mut update = proto::UpdateContacts::default();
2621    update
2622        .incoming_requests
2623        .push(proto::IncomingContactRequest {
2624            requester_id: requester_id.to_proto(),
2625        });
2626    let connection_pool = session.connection_pool().await;
2627    for connection_id in connection_pool.user_connection_ids(responder_id) {
2628        session.peer.send(connection_id, update.clone())?;
2629    }
2630
2631    send_notifications(&connection_pool, &session.peer, notifications);
2632
2633    response.send(proto::Ack {})?;
2634    Ok(())
2635}
2636
2637/// Accept or decline a contact request
2638async fn respond_to_contact_request(
2639    request: proto::RespondToContactRequest,
2640    response: Response<proto::RespondToContactRequest>,
2641    session: MessageContext,
2642) -> Result<()> {
2643    let responder_id = session.user_id();
2644    let requester_id = UserId::from_proto(request.requester_id);
2645    let db = session.db().await;
2646    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2647        db.dismiss_contact_notification(responder_id, requester_id)
2648            .await?;
2649    } else {
2650        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2651
2652        let notifications = db
2653            .respond_to_contact_request(responder_id, requester_id, accept)
2654            .await?;
2655        let requester_busy = db.is_user_busy(requester_id).await?;
2656        let responder_busy = db.is_user_busy(responder_id).await?;
2657
2658        let pool = session.connection_pool().await;
2659        // Update responder with new contact
2660        let mut update = proto::UpdateContacts::default();
2661        if accept {
2662            update
2663                .contacts
2664                .push(contact_for_user(requester_id, requester_busy, &pool));
2665        }
2666        update
2667            .remove_incoming_requests
2668            .push(requester_id.to_proto());
2669        for connection_id in pool.user_connection_ids(responder_id) {
2670            session.peer.send(connection_id, update.clone())?;
2671        }
2672
2673        // Update requester with new contact
2674        let mut update = proto::UpdateContacts::default();
2675        if accept {
2676            update
2677                .contacts
2678                .push(contact_for_user(responder_id, responder_busy, &pool));
2679        }
2680        update
2681            .remove_outgoing_requests
2682            .push(responder_id.to_proto());
2683
2684        for connection_id in pool.user_connection_ids(requester_id) {
2685            session.peer.send(connection_id, update.clone())?;
2686        }
2687
2688        send_notifications(&pool, &session.peer, notifications);
2689    }
2690
2691    response.send(proto::Ack {})?;
2692    Ok(())
2693}
2694
2695/// Remove a contact.
2696async fn remove_contact(
2697    request: proto::RemoveContact,
2698    response: Response<proto::RemoveContact>,
2699    session: MessageContext,
2700) -> Result<()> {
2701    let requester_id = session.user_id();
2702    let responder_id = UserId::from_proto(request.user_id);
2703    let db = session.db().await;
2704    let (contact_accepted, deleted_notification_id) =
2705        db.remove_contact(requester_id, responder_id).await?;
2706
2707    let pool = session.connection_pool().await;
2708    // Update outgoing contact requests of requester
2709    let mut update = proto::UpdateContacts::default();
2710    if contact_accepted {
2711        update.remove_contacts.push(responder_id.to_proto());
2712    } else {
2713        update
2714            .remove_outgoing_requests
2715            .push(responder_id.to_proto());
2716    }
2717    for connection_id in pool.user_connection_ids(requester_id) {
2718        session.peer.send(connection_id, update.clone())?;
2719    }
2720
2721    // Update incoming contact requests of responder
2722    let mut update = proto::UpdateContacts::default();
2723    if contact_accepted {
2724        update.remove_contacts.push(requester_id.to_proto());
2725    } else {
2726        update
2727            .remove_incoming_requests
2728            .push(requester_id.to_proto());
2729    }
2730    for connection_id in pool.user_connection_ids(responder_id) {
2731        session.peer.send(connection_id, update.clone())?;
2732        if let Some(notification_id) = deleted_notification_id {
2733            session.peer.send(
2734                connection_id,
2735                proto::DeleteNotification {
2736                    notification_id: notification_id.to_proto(),
2737                },
2738            )?;
2739        }
2740    }
2741
2742    response.send(proto::Ack {})?;
2743    Ok(())
2744}
2745
2746fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2747    version.0.minor < 139
2748}
2749
2750async fn subscribe_to_channels(
2751    _: proto::SubscribeToChannels,
2752    session: MessageContext,
2753) -> Result<()> {
2754    subscribe_user_to_channels(session.user_id(), &session).await?;
2755    Ok(())
2756}
2757
2758async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2759    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2760    let mut pool = session.connection_pool().await;
2761    for membership in &channels_for_user.channel_memberships {
2762        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2763    }
2764    session.peer.send(
2765        session.connection_id,
2766        build_update_user_channels(&channels_for_user),
2767    )?;
2768    session.peer.send(
2769        session.connection_id,
2770        build_channels_update(channels_for_user),
2771    )?;
2772    Ok(())
2773}
2774
2775/// Creates a new channel.
2776async fn create_channel(
2777    request: proto::CreateChannel,
2778    response: Response<proto::CreateChannel>,
2779    session: MessageContext,
2780) -> Result<()> {
2781    let db = session.db().await;
2782
2783    let parent_id = request.parent_id.map(ChannelId::from_proto);
2784    let (channel, membership) = db
2785        .create_channel(&request.name, parent_id, session.user_id())
2786        .await?;
2787
2788    let root_id = channel.root_id();
2789    let channel = Channel::from_model(channel);
2790
2791    response.send(proto::CreateChannelResponse {
2792        channel: Some(channel.to_proto()),
2793        parent_id: request.parent_id,
2794    })?;
2795
2796    let mut connection_pool = session.connection_pool().await;
2797    if let Some(membership) = membership {
2798        connection_pool.subscribe_to_channel(
2799            membership.user_id,
2800            membership.channel_id,
2801            membership.role,
2802        );
2803        let update = proto::UpdateUserChannels {
2804            channel_memberships: vec![proto::ChannelMembership {
2805                channel_id: membership.channel_id.to_proto(),
2806                role: membership.role.into(),
2807            }],
2808            ..Default::default()
2809        };
2810        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2811            session.peer.send(connection_id, update.clone())?;
2812        }
2813    }
2814
2815    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2816        if !role.can_see_channel(channel.visibility) {
2817            continue;
2818        }
2819
2820        let update = proto::UpdateChannels {
2821            channels: vec![channel.to_proto()],
2822            ..Default::default()
2823        };
2824        session.peer.send(connection_id, update.clone())?;
2825    }
2826
2827    Ok(())
2828}
2829
2830/// Delete a channel
2831async fn delete_channel(
2832    request: proto::DeleteChannel,
2833    response: Response<proto::DeleteChannel>,
2834    session: MessageContext,
2835) -> Result<()> {
2836    let db = session.db().await;
2837
2838    let channel_id = request.channel_id;
2839    let (root_channel, removed_channels) = db
2840        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2841        .await?;
2842    response.send(proto::Ack {})?;
2843
2844    // Notify members of removed channels
2845    let mut update = proto::UpdateChannels::default();
2846    update
2847        .delete_channels
2848        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2849
2850    let connection_pool = session.connection_pool().await;
2851    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2852        session.peer.send(connection_id, update.clone())?;
2853    }
2854
2855    Ok(())
2856}
2857
2858/// Invite someone to join a channel.
2859async fn invite_channel_member(
2860    request: proto::InviteChannelMember,
2861    response: Response<proto::InviteChannelMember>,
2862    session: MessageContext,
2863) -> Result<()> {
2864    let db = session.db().await;
2865    let channel_id = ChannelId::from_proto(request.channel_id);
2866    let invitee_id = UserId::from_proto(request.user_id);
2867    let InviteMemberResult {
2868        channel,
2869        notifications,
2870    } = db
2871        .invite_channel_member(
2872            channel_id,
2873            invitee_id,
2874            session.user_id(),
2875            request.role().into(),
2876        )
2877        .await?;
2878
2879    let update = proto::UpdateChannels {
2880        channel_invitations: vec![channel.to_proto()],
2881        ..Default::default()
2882    };
2883
2884    let connection_pool = session.connection_pool().await;
2885    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2886        session.peer.send(connection_id, update.clone())?;
2887    }
2888
2889    send_notifications(&connection_pool, &session.peer, notifications);
2890
2891    response.send(proto::Ack {})?;
2892    Ok(())
2893}
2894
2895/// remove someone from a channel
2896async fn remove_channel_member(
2897    request: proto::RemoveChannelMember,
2898    response: Response<proto::RemoveChannelMember>,
2899    session: MessageContext,
2900) -> Result<()> {
2901    let db = session.db().await;
2902    let channel_id = ChannelId::from_proto(request.channel_id);
2903    let member_id = UserId::from_proto(request.user_id);
2904
2905    let RemoveChannelMemberResult {
2906        membership_update,
2907        notification_id,
2908    } = db
2909        .remove_channel_member(channel_id, member_id, session.user_id())
2910        .await?;
2911
2912    let mut connection_pool = session.connection_pool().await;
2913    notify_membership_updated(
2914        &mut connection_pool,
2915        membership_update,
2916        member_id,
2917        &session.peer,
2918    );
2919    for connection_id in connection_pool.user_connection_ids(member_id) {
2920        if let Some(notification_id) = notification_id {
2921            session
2922                .peer
2923                .send(
2924                    connection_id,
2925                    proto::DeleteNotification {
2926                        notification_id: notification_id.to_proto(),
2927                    },
2928                )
2929                .trace_err();
2930        }
2931    }
2932
2933    response.send(proto::Ack {})?;
2934    Ok(())
2935}
2936
2937/// Toggle the channel between public and private.
2938/// Care is taken to maintain the invariant that public channels only descend from public channels,
2939/// (though members-only channels can appear at any point in the hierarchy).
2940async fn set_channel_visibility(
2941    request: proto::SetChannelVisibility,
2942    response: Response<proto::SetChannelVisibility>,
2943    session: MessageContext,
2944) -> Result<()> {
2945    let db = session.db().await;
2946    let channel_id = ChannelId::from_proto(request.channel_id);
2947    let visibility = request.visibility().into();
2948
2949    let channel_model = db
2950        .set_channel_visibility(channel_id, visibility, session.user_id())
2951        .await?;
2952    let root_id = channel_model.root_id();
2953    let channel = Channel::from_model(channel_model);
2954
2955    let mut connection_pool = session.connection_pool().await;
2956    for (user_id, role) in connection_pool
2957        .channel_user_ids(root_id)
2958        .collect::<Vec<_>>()
2959        .into_iter()
2960    {
2961        let update = if role.can_see_channel(channel.visibility) {
2962            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2963            proto::UpdateChannels {
2964                channels: vec![channel.to_proto()],
2965                ..Default::default()
2966            }
2967        } else {
2968            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2969            proto::UpdateChannels {
2970                delete_channels: vec![channel.id.to_proto()],
2971                ..Default::default()
2972            }
2973        };
2974
2975        for connection_id in connection_pool.user_connection_ids(user_id) {
2976            session.peer.send(connection_id, update.clone())?;
2977        }
2978    }
2979
2980    response.send(proto::Ack {})?;
2981    Ok(())
2982}
2983
2984/// Alter the role for a user in the channel.
2985async fn set_channel_member_role(
2986    request: proto::SetChannelMemberRole,
2987    response: Response<proto::SetChannelMemberRole>,
2988    session: MessageContext,
2989) -> Result<()> {
2990    let db = session.db().await;
2991    let channel_id = ChannelId::from_proto(request.channel_id);
2992    let member_id = UserId::from_proto(request.user_id);
2993    let result = db
2994        .set_channel_member_role(
2995            channel_id,
2996            session.user_id(),
2997            member_id,
2998            request.role().into(),
2999        )
3000        .await?;
3001
3002    match result {
3003        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3004            let mut connection_pool = session.connection_pool().await;
3005            notify_membership_updated(
3006                &mut connection_pool,
3007                membership_update,
3008                member_id,
3009                &session.peer,
3010            )
3011        }
3012        db::SetMemberRoleResult::InviteUpdated(channel) => {
3013            let update = proto::UpdateChannels {
3014                channel_invitations: vec![channel.to_proto()],
3015                ..Default::default()
3016            };
3017
3018            for connection_id in session
3019                .connection_pool()
3020                .await
3021                .user_connection_ids(member_id)
3022            {
3023                session.peer.send(connection_id, update.clone())?;
3024            }
3025        }
3026    }
3027
3028    response.send(proto::Ack {})?;
3029    Ok(())
3030}
3031
3032/// Change the name of a channel
3033async fn rename_channel(
3034    request: proto::RenameChannel,
3035    response: Response<proto::RenameChannel>,
3036    session: MessageContext,
3037) -> Result<()> {
3038    let db = session.db().await;
3039    let channel_id = ChannelId::from_proto(request.channel_id);
3040    let channel_model = db
3041        .rename_channel(channel_id, session.user_id(), &request.name)
3042        .await?;
3043    let root_id = channel_model.root_id();
3044    let channel = Channel::from_model(channel_model);
3045
3046    response.send(proto::RenameChannelResponse {
3047        channel: Some(channel.to_proto()),
3048    })?;
3049
3050    let connection_pool = session.connection_pool().await;
3051    let update = proto::UpdateChannels {
3052        channels: vec![channel.to_proto()],
3053        ..Default::default()
3054    };
3055    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3056        if role.can_see_channel(channel.visibility) {
3057            session.peer.send(connection_id, update.clone())?;
3058        }
3059    }
3060
3061    Ok(())
3062}
3063
3064/// Move a channel to a new parent.
3065async fn move_channel(
3066    request: proto::MoveChannel,
3067    response: Response<proto::MoveChannel>,
3068    session: MessageContext,
3069) -> Result<()> {
3070    let channel_id = ChannelId::from_proto(request.channel_id);
3071    let to = ChannelId::from_proto(request.to);
3072
3073    let (root_id, channels) = session
3074        .db()
3075        .await
3076        .move_channel(channel_id, to, session.user_id())
3077        .await?;
3078
3079    let connection_pool = session.connection_pool().await;
3080    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3081        let channels = channels
3082            .iter()
3083            .filter_map(|channel| {
3084                if role.can_see_channel(channel.visibility) {
3085                    Some(channel.to_proto())
3086                } else {
3087                    None
3088                }
3089            })
3090            .collect::<Vec<_>>();
3091        if channels.is_empty() {
3092            continue;
3093        }
3094
3095        let update = proto::UpdateChannels {
3096            channels,
3097            ..Default::default()
3098        };
3099
3100        session.peer.send(connection_id, update.clone())?;
3101    }
3102
3103    response.send(Ack {})?;
3104    Ok(())
3105}
3106
3107async fn reorder_channel(
3108    request: proto::ReorderChannel,
3109    response: Response<proto::ReorderChannel>,
3110    session: MessageContext,
3111) -> Result<()> {
3112    let channel_id = ChannelId::from_proto(request.channel_id);
3113    let direction = request.direction();
3114
3115    let updated_channels = session
3116        .db()
3117        .await
3118        .reorder_channel(channel_id, direction, session.user_id())
3119        .await?;
3120
3121    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3122        let connection_pool = session.connection_pool().await;
3123        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3124            let channels = updated_channels
3125                .iter()
3126                .filter_map(|channel| {
3127                    if role.can_see_channel(channel.visibility) {
3128                        Some(channel.to_proto())
3129                    } else {
3130                        None
3131                    }
3132                })
3133                .collect::<Vec<_>>();
3134
3135            if channels.is_empty() {
3136                continue;
3137            }
3138
3139            let update = proto::UpdateChannels {
3140                channels,
3141                ..Default::default()
3142            };
3143
3144            session.peer.send(connection_id, update.clone())?;
3145        }
3146    }
3147
3148    response.send(Ack {})?;
3149    Ok(())
3150}
3151
3152/// Get the list of channel members
3153async fn get_channel_members(
3154    request: proto::GetChannelMembers,
3155    response: Response<proto::GetChannelMembers>,
3156    session: MessageContext,
3157) -> Result<()> {
3158    let db = session.db().await;
3159    let channel_id = ChannelId::from_proto(request.channel_id);
3160    let limit = if request.limit == 0 {
3161        u16::MAX as u64
3162    } else {
3163        request.limit
3164    };
3165    let (members, users) = db
3166        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3167        .await?;
3168    response.send(proto::GetChannelMembersResponse { members, users })?;
3169    Ok(())
3170}
3171
3172/// Accept or decline a channel invitation.
3173async fn respond_to_channel_invite(
3174    request: proto::RespondToChannelInvite,
3175    response: Response<proto::RespondToChannelInvite>,
3176    session: MessageContext,
3177) -> Result<()> {
3178    let db = session.db().await;
3179    let channel_id = ChannelId::from_proto(request.channel_id);
3180    let RespondToChannelInvite {
3181        membership_update,
3182        notifications,
3183    } = db
3184        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3185        .await?;
3186
3187    let mut connection_pool = session.connection_pool().await;
3188    if let Some(membership_update) = membership_update {
3189        notify_membership_updated(
3190            &mut connection_pool,
3191            membership_update,
3192            session.user_id(),
3193            &session.peer,
3194        );
3195    } else {
3196        let update = proto::UpdateChannels {
3197            remove_channel_invitations: vec![channel_id.to_proto()],
3198            ..Default::default()
3199        };
3200
3201        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3202            session.peer.send(connection_id, update.clone())?;
3203        }
3204    };
3205
3206    send_notifications(&connection_pool, &session.peer, notifications);
3207
3208    response.send(proto::Ack {})?;
3209
3210    Ok(())
3211}
3212
3213/// Join the channels' room
3214async fn join_channel(
3215    request: proto::JoinChannel,
3216    response: Response<proto::JoinChannel>,
3217    session: MessageContext,
3218) -> Result<()> {
3219    let channel_id = ChannelId::from_proto(request.channel_id);
3220    join_channel_internal(channel_id, Box::new(response), session).await
3221}
3222
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinChannel>::send` via an explicit path.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinRoom>::send` via an explicit path.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3236
/// Join the room belonging to `channel_id` on behalf of `session`, replying
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The database guard is released during the leave and
///    re-acquired afterwards.
/// 2. Join the channel's room in the database. This yields the joined room,
///    an optional membership update, and the user's channel role.
/// 3. Build LiveKit connection info, if a LiveKit client is configured:
///    - guests get a guest token with `can_publish = false`;
///    - everyone else gets a room token with `can_publish = true`.
///
///    If token creation fails, the error goes through `trace_err` and the
///    info is `None`.
/// 4. Send the response, apply any membership update to the connection pool,
///    and broadcast the room.
/// 5. After the database and pool guards from the inner block are dropped,
///    broadcast the channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // take it again for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when token creation
        // fails (`trace_err()?` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish; all other
                    // roles receive a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database and connection-pool guards.
        joined_room
    };

    // NOTE(review): the client has already been answered above, so a missing
    // channel only surfaces here as an error from this handler.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3330
3331/// Start editing the channel notes
3332async fn join_channel_buffer(
3333    request: proto::JoinChannelBuffer,
3334    response: Response<proto::JoinChannelBuffer>,
3335    session: MessageContext,
3336) -> Result<()> {
3337    let db = session.db().await;
3338    let channel_id = ChannelId::from_proto(request.channel_id);
3339
3340    let open_response = db
3341        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3342        .await?;
3343
3344    let collaborators = open_response.collaborators.clone();
3345    response.send(open_response)?;
3346
3347    let update = UpdateChannelBufferCollaborators {
3348        channel_id: channel_id.to_proto(),
3349        collaborators: collaborators.clone(),
3350    };
3351    channel_buffer_updated(
3352        session.connection_id,
3353        collaborators
3354            .iter()
3355            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3356        &update,
3357        &session.peer,
3358    );
3359
3360    Ok(())
3361}
3362
3363/// Edit the channel notes
3364async fn update_channel_buffer(
3365    request: proto::UpdateChannelBuffer,
3366    session: MessageContext,
3367) -> Result<()> {
3368    let db = session.db().await;
3369    let channel_id = ChannelId::from_proto(request.channel_id);
3370
3371    let (collaborators, epoch, version) = db
3372        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3373        .await?;
3374
3375    channel_buffer_updated(
3376        session.connection_id,
3377        collaborators.clone(),
3378        &proto::UpdateChannelBuffer {
3379            channel_id: channel_id.to_proto(),
3380            operations: request.operations,
3381        },
3382        &session.peer,
3383    );
3384
3385    let pool = &*session.connection_pool().await;
3386
3387    let non_collaborators =
3388        pool.channel_connection_ids(channel_id)
3389            .filter_map(|(connection_id, _)| {
3390                if collaborators.contains(&connection_id) {
3391                    None
3392                } else {
3393                    Some(connection_id)
3394                }
3395            });
3396
3397    broadcast(None, non_collaborators, |peer_id| {
3398        session.peer.send(
3399            peer_id,
3400            proto::UpdateChannels {
3401                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3402                    channel_id: channel_id.to_proto(),
3403                    epoch: epoch as u64,
3404                    version: version.clone(),
3405                }],
3406                ..Default::default()
3407            },
3408        )
3409    });
3410
3411    Ok(())
3412}
3413
3414/// Rejoin the channel notes after a connection blip
3415async fn rejoin_channel_buffers(
3416    request: proto::RejoinChannelBuffers,
3417    response: Response<proto::RejoinChannelBuffers>,
3418    session: MessageContext,
3419) -> Result<()> {
3420    let db = session.db().await;
3421    let buffers = db
3422        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3423        .await?;
3424
3425    for rejoined_buffer in &buffers {
3426        let collaborators_to_notify = rejoined_buffer
3427            .buffer
3428            .collaborators
3429            .iter()
3430            .filter_map(|c| Some(c.peer_id?.into()));
3431        channel_buffer_updated(
3432            session.connection_id,
3433            collaborators_to_notify,
3434            &proto::UpdateChannelBufferCollaborators {
3435                channel_id: rejoined_buffer.buffer.channel_id,
3436                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3437            },
3438            &session.peer,
3439        );
3440    }
3441
3442    response.send(proto::RejoinChannelBuffersResponse {
3443        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3444    })?;
3445
3446    Ok(())
3447}
3448
3449/// Stop editing the channel notes
3450async fn leave_channel_buffer(
3451    request: proto::LeaveChannelBuffer,
3452    response: Response<proto::LeaveChannelBuffer>,
3453    session: MessageContext,
3454) -> Result<()> {
3455    let db = session.db().await;
3456    let channel_id = ChannelId::from_proto(request.channel_id);
3457
3458    let left_buffer = db
3459        .leave_channel_buffer(channel_id, session.connection_id)
3460        .await?;
3461
3462    response.send(Ack {})?;
3463
3464    channel_buffer_updated(
3465        session.connection_id,
3466        left_buffer.connections,
3467        &proto::UpdateChannelBufferCollaborators {
3468            channel_id: channel_id.to_proto(),
3469            collaborators: left_buffer.collaborators,
3470        },
3471        &session.peer,
3472    );
3473
3474    Ok(())
3475}
3476
3477fn channel_buffer_updated<T: EnvelopedMessage>(
3478    sender_id: ConnectionId,
3479    collaborators: impl IntoIterator<Item = ConnectionId>,
3480    message: &T,
3481    peer: &Peer,
3482) {
3483    broadcast(Some(sender_id), collaborators, |peer_id| {
3484        peer.send(peer_id, message.clone())
3485    });
3486}
3487
3488fn send_notifications(
3489    connection_pool: &ConnectionPool,
3490    peer: &Peer,
3491    notifications: db::NotificationBatch,
3492) {
3493    for (user_id, notification) in notifications {
3494        for connection_id in connection_pool.user_connection_ids(user_id) {
3495            if let Err(error) = peer.send(
3496                connection_id,
3497                proto::AddNotification {
3498                    notification: Some(notification.clone()),
3499                },
3500            ) {
3501                tracing::error!(
3502                    "failed to send notification to {:?} {}",
3503                    connection_id,
3504                    error
3505                );
3506            }
3507        }
3508    }
3509}
3510
3511/// Send a message to the channel
3512async fn send_channel_message(
3513    _request: proto::SendChannelMessage,
3514    _response: Response<proto::SendChannelMessage>,
3515    _session: MessageContext,
3516) -> Result<()> {
3517    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3518}
3519
3520/// Delete a channel message
3521async fn remove_channel_message(
3522    _request: proto::RemoveChannelMessage,
3523    _response: Response<proto::RemoveChannelMessage>,
3524    _session: MessageContext,
3525) -> Result<()> {
3526    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3527}
3528
3529async fn update_channel_message(
3530    _request: proto::UpdateChannelMessage,
3531    _response: Response<proto::UpdateChannelMessage>,
3532    _session: MessageContext,
3533) -> Result<()> {
3534    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3535}
3536
3537/// Mark a channel message as read
3538async fn acknowledge_channel_message(
3539    _request: proto::AckChannelMessage,
3540    _session: MessageContext,
3541) -> Result<()> {
3542    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3543}
3544
3545/// Mark a buffer version as synced
3546async fn acknowledge_buffer_version(
3547    request: proto::AckBufferOperation,
3548    session: MessageContext,
3549) -> Result<()> {
3550    let buffer_id = BufferId::from_proto(request.buffer_id);
3551    session
3552        .db()
3553        .await
3554        .observe_buffer_version(
3555            buffer_id,
3556            session.user_id(),
3557            request.epoch as i32,
3558            &request.version,
3559        )
3560        .await?;
3561    Ok(())
3562}
3563
3564/// Start receiving chat updates for a channel
3565async fn join_channel_chat(
3566    _request: proto::JoinChannelChat,
3567    _response: Response<proto::JoinChannelChat>,
3568    _session: MessageContext,
3569) -> Result<()> {
3570    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3571}
3572
3573/// Stop receiving chat updates for a channel
3574async fn leave_channel_chat(
3575    _request: proto::LeaveChannelChat,
3576    _session: MessageContext,
3577) -> Result<()> {
3578    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3579}
3580
3581/// Retrieve the chat history for a channel
3582async fn get_channel_messages(
3583    _request: proto::GetChannelMessages,
3584    _response: Response<proto::GetChannelMessages>,
3585    _session: MessageContext,
3586) -> Result<()> {
3587    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3588}
3589
3590/// Retrieve specific chat messages
3591async fn get_channel_messages_by_id(
3592    _request: proto::GetChannelMessagesById,
3593    _response: Response<proto::GetChannelMessagesById>,
3594    _session: MessageContext,
3595) -> Result<()> {
3596    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3597}
3598
3599/// Retrieve the current users notifications
3600async fn get_notifications(
3601    request: proto::GetNotifications,
3602    response: Response<proto::GetNotifications>,
3603    session: MessageContext,
3604) -> Result<()> {
3605    let notifications = session
3606        .db()
3607        .await
3608        .get_notifications(
3609            session.user_id(),
3610            NOTIFICATION_COUNT_PER_PAGE,
3611            request.before_id.map(db::NotificationId::from_proto),
3612        )
3613        .await?;
3614    response.send(proto::GetNotificationsResponse {
3615        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3616        notifications,
3617    })?;
3618    Ok(())
3619}
3620
3621/// Mark notifications as read
3622async fn mark_notification_as_read(
3623    request: proto::MarkNotificationRead,
3624    response: Response<proto::MarkNotificationRead>,
3625    session: MessageContext,
3626) -> Result<()> {
3627    let database = &session.db().await;
3628    let notifications = database
3629        .mark_notification_as_read_by_id(
3630            session.user_id(),
3631            NotificationId::from_proto(request.notification_id),
3632        )
3633        .await?;
3634    send_notifications(
3635        &*session.connection_pool().await,
3636        &session.peer,
3637        notifications,
3638    );
3639    response.send(proto::Ack {})?;
3640    Ok(())
3641}
3642
3643fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3644    let message = match message {
3645        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3646        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3647        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3648        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3649        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3650            code: frame.code.into(),
3651            reason: frame.reason.as_str().to_owned().into(),
3652        })),
3653        // We should never receive a frame while reading the message, according
3654        // to the `tungstenite` maintainers:
3655        //
3656        // > It cannot occur when you read messages from the WebSocket, but it
3657        // > can be used when you want to send the raw frames (e.g. you want to
3658        // > send the frames to the WebSocket without composing the full message first).
3659        // >
3660        // > — https://github.com/snapview/tungstenite-rs/issues/268
3661        TungsteniteMessage::Frame(_) => {
3662            bail!("received an unexpected frame while reading the message")
3663        }
3664    };
3665
3666    Ok(message)
3667}
3668
3669fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3670    match message {
3671        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3672        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3673        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3674        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3675        AxumMessage::Close(frame) => {
3676            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3677                code: frame.code.into(),
3678                reason: frame.reason.as_ref().into(),
3679            }))
3680        }
3681    }
3682}
3683
3684fn notify_membership_updated(
3685    connection_pool: &mut ConnectionPool,
3686    result: MembershipUpdated,
3687    user_id: UserId,
3688    peer: &Peer,
3689) {
3690    for membership in &result.new_channels.channel_memberships {
3691        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3692    }
3693    for channel_id in &result.removed_channels {
3694        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3695    }
3696
3697    let user_channels_update = proto::UpdateUserChannels {
3698        channel_memberships: result
3699            .new_channels
3700            .channel_memberships
3701            .iter()
3702            .map(|cm| proto::ChannelMembership {
3703                channel_id: cm.channel_id.to_proto(),
3704                role: cm.role.into(),
3705            })
3706            .collect(),
3707        ..Default::default()
3708    };
3709
3710    let mut update = build_channels_update(result.new_channels);
3711    update.delete_channels = result
3712        .removed_channels
3713        .into_iter()
3714        .map(|id| id.to_proto())
3715        .collect();
3716    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3717
3718    for connection_id in connection_pool.user_connection_ids(user_id) {
3719        peer.send(connection_id, user_channels_update.clone())
3720            .trace_err();
3721        peer.send(connection_id, update.clone()).trace_err();
3722    }
3723}
3724
3725fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3726    proto::UpdateUserChannels {
3727        channel_memberships: channels
3728            .channel_memberships
3729            .iter()
3730            .map(|m| proto::ChannelMembership {
3731                channel_id: m.channel_id.to_proto(),
3732                role: m.role.into(),
3733            })
3734            .collect(),
3735        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3736    }
3737}
3738
3739fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3740    let mut update = proto::UpdateChannels::default();
3741
3742    for channel in channels.channels {
3743        update.channels.push(channel.to_proto());
3744    }
3745
3746    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3747
3748    for (channel_id, participants) in channels.channel_participants {
3749        update
3750            .channel_participants
3751            .push(proto::ChannelParticipants {
3752                channel_id: channel_id.to_proto(),
3753                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3754            });
3755    }
3756
3757    for channel in channels.invited_channels {
3758        update.channel_invitations.push(channel.to_proto());
3759    }
3760
3761    update
3762}
3763
3764fn build_initial_contacts_update(
3765    contacts: Vec<db::Contact>,
3766    pool: &ConnectionPool,
3767) -> proto::UpdateContacts {
3768    let mut update = proto::UpdateContacts::default();
3769
3770    for contact in contacts {
3771        match contact {
3772            db::Contact::Accepted { user_id, busy } => {
3773                update.contacts.push(contact_for_user(user_id, busy, pool));
3774            }
3775            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3776            db::Contact::Incoming { user_id } => {
3777                update
3778                    .incoming_requests
3779                    .push(proto::IncomingContactRequest {
3780                        requester_id: user_id.to_proto(),
3781                    })
3782            }
3783        }
3784    }
3785
3786    update
3787}
3788
3789fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3790    proto::Contact {
3791        user_id: user_id.to_proto(),
3792        online: pool.is_user_online(user_id),
3793        busy,
3794    }
3795}
3796
3797fn room_updated(room: &proto::Room, peer: &Peer) {
3798    broadcast(
3799        None,
3800        room.participants
3801            .iter()
3802            .filter_map(|participant| Some(participant.peer_id?.into())),
3803        |peer_id| {
3804            peer.send(
3805                peer_id,
3806                proto::RoomUpdated {
3807                    room: Some(room.clone()),
3808                },
3809            )
3810        },
3811    );
3812}
3813
3814fn channel_updated(
3815    channel: &db::channel::Model,
3816    room: &proto::Room,
3817    peer: &Peer,
3818    pool: &ConnectionPool,
3819) {
3820    let participants = room
3821        .participants
3822        .iter()
3823        .map(|p| p.user_id)
3824        .collect::<Vec<_>>();
3825
3826    broadcast(
3827        None,
3828        pool.channel_connection_ids(channel.root_id())
3829            .filter_map(|(channel_id, role)| {
3830                role.can_see_channel(channel.visibility)
3831                    .then_some(channel_id)
3832            }),
3833        |peer_id| {
3834            peer.send(
3835                peer_id,
3836                proto::UpdateChannels {
3837                    channel_participants: vec![proto::ChannelParticipants {
3838                        channel_id: channel.id.to_proto(),
3839                        participant_user_ids: participants.clone(),
3840                    }],
3841                    ..Default::default()
3842                },
3843            )
3844        },
3845    );
3846}
3847
3848async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3849    let db = session.db().await;
3850
3851    let contacts = db.get_contacts(user_id).await?;
3852    let busy = db.is_user_busy(user_id).await?;
3853
3854    let pool = session.connection_pool().await;
3855    let updated_contact = contact_for_user(user_id, busy, &pool);
3856    for contact in contacts {
3857        if let db::Contact::Accepted {
3858            user_id: contact_user_id,
3859            ..
3860        } = contact
3861        {
3862            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3863                session
3864                    .peer
3865                    .send(
3866                        contact_conn_id,
3867                        proto::UpdateContacts {
3868                            contacts: vec![updated_contact.clone()],
3869                            remove_contacts: Default::default(),
3870                            incoming_requests: Default::default(),
3871                            remove_incoming_requests: Default::default(),
3872                            outgoing_requests: Default::default(),
3873                            remove_outgoing_requests: Default::default(),
3874                        },
3875                    )
3876                    .trace_err();
3877            }
3878        }
3879    }
3880    Ok(())
3881}
3882
/// Removes `connection_id` from whatever room it is in and notifies the
/// affected parties.
///
/// `connection_id` may differ from `session.connection_id`. For example, it is
/// called with a stale connection of the same user when cleaning up before a
/// rejoin. If the database reports that the connection was not in a room, this
/// returns `Ok(())` without doing anything. Otherwise, in this order, it:
/// - notifies the other connections of each left project (see `project_left`),
/// - broadcasts the updated room and, if the room belongs to a channel, the
///   channel's participant list,
/// - sends `CallCanceled` to every connection of the users in
///   `canceled_calls_to_user_ids`,
/// - refreshes the contact status of this user and of those users,
/// - removes this user from the LiveKit room (if a LiveKit client is
///   configured), deleting the LiveKit room when `left_room.deleted` is set.
///
/// NOTE(review): this acquires `session.db()` itself. Callers appear to drop
/// their own `session.db()` handle before calling it, presumably because it
/// is a lock guard. Confirm before calling with one held.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are only assigned in the `if let` branch
    // below. The `else` branch returns early, so they are always initialized
    // afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db().await` value is a temporary in this
    // `if let` scrutinee, so it presumably stays alive until the end of the
    // `if` branch (through `project_left` and `room_updated`). Keep that in
    // mind before adding db calls inside the branch.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        // The leaving user's own contact status changes.
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `room` broadcast to
        // clients carries the `Default` value in `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    // Rooms that belong to a channel also refresh that channel's participant
    // list for its subscribers.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`,
    // which acquires `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Note: this `connection_id` shadows the function parameter.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3956
3957async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3958    let left_channel_buffers = session
3959        .db()
3960        .await
3961        .leave_channel_buffers(session.connection_id)
3962        .await?;
3963
3964    for left_buffer in left_channel_buffers {
3965        channel_buffer_updated(
3966            session.connection_id,
3967            left_buffer.connections,
3968            &proto::UpdateChannelBufferCollaborators {
3969                channel_id: left_buffer.channel_id.to_proto(),
3970                collaborators: left_buffer.collaborators,
3971            },
3972            &session.peer,
3973        );
3974    }
3975
3976    Ok(())
3977}
3978
3979fn project_left(project: &db::LeftProject, session: &Session) {
3980    for connection_id in &project.connection_ids {
3981        if project.should_unshare {
3982            session
3983                .peer
3984                .send(
3985                    *connection_id,
3986                    proto::UnshareProject {
3987                        project_id: project.id.to_proto(),
3988                    },
3989                )
3990                .trace_err();
3991        } else {
3992            session
3993                .peer
3994                .send(
3995                    *connection_id,
3996                    proto::RemoveProjectCollaborator {
3997                        project_id: project.id.to_proto(),
3998                        peer_id: Some(session.connection_id.into()),
3999                    },
4000                )
4001                .trace_err();
4002        }
4003    }
4004}
4005
4006async fn share_agent_thread(
4007    request: proto::ShareAgentThread,
4008    response: Response<proto::ShareAgentThread>,
4009    session: MessageContext,
4010) -> Result<()> {
4011    let user_id = session.user_id();
4012
4013    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4014        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4015
4016    session
4017        .db()
4018        .await
4019        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4020        .await?;
4021
4022    response.send(proto::Ack {})?;
4023
4024    Ok(())
4025}
4026
4027async fn get_shared_agent_thread(
4028    request: proto::GetSharedAgentThread,
4029    response: Response<proto::GetSharedAgentThread>,
4030    session: MessageContext,
4031) -> Result<()> {
4032    let share_id = SharedThreadId::from_proto(request.session_id)
4033        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4034
4035    let result = session.db().await.get_shared_thread(share_id).await?;
4036
4037    match result {
4038        Some((thread, username)) => {
4039            response.send(proto::GetSharedAgentThreadResponse {
4040                title: thread.title,
4041                thread_data: thread.data,
4042                sharer_username: username,
4043                created_at: thread.created_at.and_utc().to_rfc3339(),
4044            })?;
4045        }
4046        None => {
4047            return Err(anyhow!("Shared thread not found").into());
4048        }
4049    }
4050
4051    Ok(())
4052}
4053
4054pub trait ResultExt {
4055    type Ok;
4056
4057    fn trace_err(self) -> Option<Self::Ok>;
4058}
4059
4060impl<T, E> ResultExt for Result<T, E>
4061where
4062    E: std::fmt::Debug,
4063{
4064    type Ok = T;
4065
4066    #[track_caller]
4067    fn trace_err(self) -> Option<T> {
4068        match self {
4069            Ok(value) => Some(value),
4070            Err(error) => {
4071                tracing::error!("{:?}", error);
4072                None
4073            }
4074        }
4075    }
4076}