rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
  75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  76
  77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  79
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  82
  83static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  84
  85const TOTAL_DURATION_MS: &str = "total_duration_ms";
  86const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  87const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  88const HOST_WAITING_MS: &str = "host_waiting_ms";
  89
/// Type-erased message handler stored in `Server::handlers` (keyed by the message's
/// `TypeId`). It receives the boxed envelope, the sender's `Session`, and the tracing
/// span for the message, and returns a boxed `'static` future that processes it.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
/// The collab RPC server: owns the peer, the connection pool, and the table of
/// per-message-type handlers registered in `Server::new`.
pub struct Server {
    // Behind a mutex because `reset` (test-support) replaces the id in place.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Handlers keyed by the `TypeId` of the message they handle; `add_handler`
    // panics if a type is registered twice.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown` and back to `false` by `reset`; receivers are
    // subscribed outside this excerpt.
    teardown: watch::Sender<bool>,
}
 262
/// Lock guard for the connection pool returned by `Session::connection_pool`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes this guard `!Send`. Holding it across an
    // `.await` in a handler future (which must be `Send`) is therefore a compile
    // error rather than a potential deadlock.
    _not_send: PhantomData<Rc<()>>,
}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 416            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 417            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 419            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 437            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 438            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 439            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 440            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 441            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 442            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 443            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 444            .add_request_handler(share_agent_thread)
 445            .add_request_handler(get_shared_agent_thread)
 446            .add_request_handler(forward_project_search_chunk);
 447
 448        Arc::new(server)
 449    }
 450
    /// Spawns a detached background task that cleans up resources left behind by
    /// previous servers in the same environment, then returns immediately.
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task performs these steps in order.
    /// Each step is best-effort: failures go through `trace_err` (which yields
    /// `Option`; presumably it logs the error) and do not abort later steps.
    /// 1. Deletes stale channel-chat participants.
    /// 2. Clears stale collaborators from each affected channel buffer and sends
    ///    `UpdateChannelBufferCollaborators` to the remaining connections.
    /// 3. For each affected room:
    ///    - removes stale participants and broadcasts the room (and channel) update;
    ///    - sends `CallCanceled` to the users the DB reports as having canceled calls;
    ///    - sends `UpdateContacts` to accepted contacts of every affected user;
    ///    - deletes the LiveKit room when no participants remain and a LiveKit
    ///      client is configured.
    /// 4. Deletes stale channel-chat participants again, clears old worktree
    ///    entries, and deletes stale server records.
    ///
    /// Always returns `Ok(())`; errors in the background task are not surfaced.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                // Give the previous servers' pods time to terminate before reclaiming
                // their resources (see `CLEANUP_TIMEOUT`).
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to
                        // notify and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move the values out of `refreshed_room` for use after
                            // this `if let` scope ends.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy
                        // state included) to all of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the call made before room cleanup and
                // looks redundant — confirm whether it is needed, e.g. to catch
                // participants that appeared while the cleanup above ran.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                // Runs last, after the per-resource cleanup above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 621
 622    pub fn teardown(&self) {
 623        self.peer.teardown();
 624        self.connection_pool.lock().reset();
 625        let _ = self.teardown.send(true);
 626    }
 627
 628    #[cfg(feature = "test-support")]
 629    pub fn reset(&self, id: ServerId) {
 630        self.teardown();
 631        *self.id.lock() = id;
 632        self.peer.reset(id.0 as u32);
 633        let _ = self.teardown.send(false);
 634    }
 635
 636    #[cfg(feature = "test-support")]
 637    pub fn id(&self) -> ServerId {
 638        *self.id.lock()
 639    }
 640
 641    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 642    where
 643        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 644        Fut: 'static + Send + Future<Output = Result<()>>,
 645        M: EnvelopedMessage,
 646    {
 647        let prev_handler = self.handlers.insert(
 648            TypeId::of::<M>(),
 649            Box::new(move |envelope, session, span| {
 650                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 651                let received_at = envelope.received_at;
 652                tracing::info!("message received");
 653                let start_time = Instant::now();
 654                let future = (handler)(
 655                    *envelope,
 656                    MessageContext {
 657                        session,
 658                        span: span.clone(),
 659                    },
 660                );
 661                async move {
 662                    let result = future.await;
 663                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 664                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 665                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 666                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 667                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 668                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 669                    match result {
 670                        Err(error) => {
 671                            tracing::error!(?error, "error handling message")
 672                        }
 673                        Ok(()) => tracing::info!("finished handling message"),
 674                    }
 675                }
 676                .boxed()
 677            }),
 678        );
 679        if prev_handler.is_some() {
 680            panic!("registered a handler for the same message twice");
 681        }
 682        self
 683    }
 684
 685    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 686    where
 687        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 688        Fut: 'static + Send + Future<Output = Result<()>>,
 689        M: EnvelopedMessage,
 690    {
 691        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 692        self
 693    }
 694
 695    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 696    where
 697        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 698        Fut: Send + Future<Output = Result<()>>,
 699        M: RequestMessage,
 700    {
 701        let handler = Arc::new(handler);
 702        self.add_handler(move |envelope, session| {
 703            let receipt = envelope.receipt();
 704            let handler = handler.clone();
 705            async move {
 706                let peer = session.peer.clone();
 707                let responded = Arc::new(AtomicBool::default());
 708                let response = Response {
 709                    peer: peer.clone(),
 710                    responded: responded.clone(),
 711                    receipt,
 712                };
 713                match (handler)(envelope.payload, response, session).await {
 714                    Ok(()) => {
 715                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 716                            Ok(())
 717                        } else {
 718                            Err(anyhow!("handler did not send a response"))?
 719                        }
 720                    }
 721                    Err(error) => {
 722                        let proto_err = match &error {
 723                            Error::Internal(err) => err.to_proto(),
 724                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 725                        };
 726                        peer.respond_with_error(receipt, proto_err)?;
 727                        Err(error)
 728                    }
 729                }
 730            }
 731        })
 732    }
 733
 734    pub fn handle_connection(
 735        self: &Arc<Self>,
 736        connection: Connection,
 737        address: String,
 738        principal: Principal,
 739        zed_version: ZedVersion,
 740        release_channel: Option<String>,
 741        user_agent: Option<String>,
 742        geoip_country_code: Option<String>,
 743        system_id: Option<String>,
 744        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 745        executor: Executor,
 746        connection_guard: Option<ConnectionGuard>,
 747    ) -> impl Future<Output = ()> + use<> {
 748        let this = self.clone();
 749        let span = info_span!("handle connection", %address,
 750            connection_id=field::Empty,
 751            user_id=field::Empty,
 752            login=field::Empty,
 753            user_agent=field::Empty,
 754            geoip_country_code=field::Empty,
 755            release_channel=field::Empty,
 756        );
 757        principal.update_span(&span);
 758        if let Some(user_agent) = user_agent {
 759            span.record("user_agent", user_agent);
 760        }
 761        if let Some(release_channel) = release_channel {
 762            span.record("release_channel", release_channel);
 763        }
 764
 765        if let Some(country_code) = geoip_country_code.as_ref() {
 766            span.record("geoip_country_code", country_code);
 767        }
 768
 769        let mut teardown = self.teardown.subscribe();
 770        async move {
 771            if *teardown.borrow() {
 772                tracing::error!("server is tearing down");
 773                return;
 774            }
 775
 776            let (connection_id, handle_io, mut incoming_rx) =
 777                this.peer.add_connection(connection, {
 778                    let executor = executor.clone();
 779                    move |duration| executor.sleep(duration)
 780                });
 781            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 782
 783            tracing::info!("connection opened");
 784
 785            let session = Session {
 786                principal: principal.clone(),
 787                connection_id,
 788                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 789                peer: this.peer.clone(),
 790                connection_pool: this.connection_pool.clone(),
 791                app_state: this.app_state.clone(),
 792                geoip_country_code,
 793                system_id,
 794                _executor: executor.clone(),
 795            };
 796
 797            if let Err(error) = this
 798                .send_initial_client_update(
 799                    connection_id,
 800                    zed_version,
 801                    send_connection_id,
 802                    &session,
 803                )
 804                .await
 805            {
 806                tracing::error!(?error, "failed to send initial client update");
 807                return;
 808            }
 809            drop(connection_guard);
 810
 811            let handle_io = handle_io.fuse();
 812            futures::pin_mut!(handle_io);
 813
 814            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 815            // This prevents deadlocks when e.g., client A performs a request to client B and
 816            // client B performs a request to client A. If both clients stop processing further
 817            // messages until their respective request completes, they won't have a chance to
 818            // respond to the other client's request and cause a deadlock.
 819            //
 820            // This arrangement ensures we will attempt to process earlier messages first, but fall
 821            // back to processing messages arrived later in the spirit of making progress.
 822            const MAX_CONCURRENT_HANDLERS: usize = 256;
 823            let mut foreground_message_handlers = FuturesUnordered::new();
 824            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 825            let get_concurrent_handlers = {
 826                let concurrent_handlers = concurrent_handlers.clone();
 827                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 828            };
 829            loop {
 830                let next_message = async {
 831                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 832                    let message = incoming_rx.next().await;
 833                    // Cache the concurrent_handlers here, so that we know what the
 834                    // queue looks like as each handler starts
 835                    (permit, message, get_concurrent_handlers())
 836                }
 837                .fuse();
 838                futures::pin_mut!(next_message);
 839                futures::select_biased! {
 840                    _ = teardown.changed().fuse() => return,
 841                    result = handle_io => {
 842                        if let Err(error) = result {
 843                            tracing::error!(?error, "error handling I/O");
 844                        }
 845                        break;
 846                    }
 847                    _ = foreground_message_handlers.next() => {}
 848                    next_message = next_message => {
 849                        let (permit, message, concurrent_handlers) = next_message;
 850                        if let Some(message) = message {
 851                            let type_name = message.payload_type_name();
 852                            // note: we copy all the fields from the parent span so we can query them in the logs.
 853                            // (https://github.com/tokio-rs/tracing/issues/2670).
 854                            let span = tracing::info_span!("receive message",
 855                                %connection_id,
 856                                %address,
 857                                type_name,
 858                                concurrent_handlers,
 859                                user_id=field::Empty,
 860                                login=field::Empty,
 861                                lsp_query_request=field::Empty,
 862                                release_channel=field::Empty,
 863                                { TOTAL_DURATION_MS }=field::Empty,
 864                                { PROCESSING_DURATION_MS }=field::Empty,
 865                                { QUEUE_DURATION_MS }=field::Empty,
 866                                { HOST_WAITING_MS }=field::Empty
 867                            );
 868                            principal.update_span(&span);
 869                            let span_enter = span.enter();
 870                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 871                                let is_background = message.is_background();
 872                                let handle_message = (handler)(message, session.clone(), span.clone());
 873                                drop(span_enter);
 874
 875                                let handle_message = async move {
 876                                    handle_message.await;
 877                                    drop(permit);
 878                                }.instrument(span);
 879                                if is_background {
 880                                    executor.spawn_detached(handle_message);
 881                                } else {
 882                                    foreground_message_handlers.push(handle_message);
 883                                }
 884                            } else {
 885                                tracing::error!("no message handler");
 886                            }
 887                        } else {
 888                            tracing::info!("connection closed");
 889                            break;
 890                        }
 891                    }
 892                }
 893            }
 894
 895            drop(foreground_message_handlers);
 896            let concurrent_handlers = get_concurrent_handlers();
 897            tracing::info!(concurrent_handlers, "signing out");
 898            if let Err(error) = connection_lost(session, teardown, executor).await {
 899                tracing::error!(?error, "error signing out");
 900            }
 901        }
 902        .instrument(span)
 903    }
 904
 905    async fn send_initial_client_update(
 906        &self,
 907        connection_id: ConnectionId,
 908        zed_version: ZedVersion,
 909        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 910        session: &Session,
 911    ) -> Result<()> {
 912        self.peer.send(
 913            connection_id,
 914            proto::Hello {
 915                peer_id: Some(connection_id.into()),
 916            },
 917        )?;
 918        tracing::info!("sent hello message");
 919        if let Some(send_connection_id) = send_connection_id.take() {
 920            let _ = send_connection_id.send(connection_id);
 921        }
 922
 923        match &session.principal {
 924            Principal::User(user) => {
 925                if !user.connected_once {
 926                    self.peer.send(connection_id, proto::ShowContacts {})?;
 927                    self.app_state
 928                        .db
 929                        .set_user_connected_once(user.id, true)
 930                        .await?;
 931                }
 932
 933                let contacts = self.app_state.db.get_contacts(user.id).await?;
 934
 935                {
 936                    let mut pool = self.connection_pool.lock();
 937                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 938                    self.peer.send(
 939                        connection_id,
 940                        build_initial_contacts_update(contacts, &pool),
 941                    )?;
 942                }
 943
 944                if should_auto_subscribe_to_channels(&zed_version) {
 945                    subscribe_user_to_channels(user.id, session).await?;
 946                }
 947
 948                if let Some(incoming_call) =
 949                    self.app_state.db.incoming_call_for_user(user.id).await?
 950                {
 951                    self.peer.send(connection_id, incoming_call)?;
 952                }
 953
 954                update_user_contacts(user.id, session).await?;
 955            }
 956        }
 957
 958        Ok(())
 959    }
 960}
 961
 962impl Deref for ConnectionPoolGuard<'_> {
 963    type Target = ConnectionPool;
 964
 965    fn deref(&self) -> &Self::Target {
 966        &self.guard
 967    }
 968}
 969
 970impl DerefMut for ConnectionPoolGuard<'_> {
 971    fn deref_mut(&mut self) -> &mut Self::Target {
 972        &mut self.guard
 973    }
 974}
 975
 976impl Drop for ConnectionPoolGuard<'_> {
 977    fn drop(&mut self) {
 978        #[cfg(feature = "test-support")]
 979        self.check_invariants();
 980    }
 981}
 982
 983fn broadcast<F>(
 984    sender_id: Option<ConnectionId>,
 985    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 986    mut f: F,
 987) where
 988    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 989{
 990    for receiver_id in receiver_ids {
 991        if Some(receiver_id) != sender_id
 992            && let Err(error) = f(receiver_id)
 993        {
 994            tracing::error!("failed to send to {:?} {}", receiver_id, error);
 995        }
 996    }
 997}
 998
 999pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002    fn name() -> &'static HeaderName {
1003        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005    }
1006
1007    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008    where
1009        Self: Sized,
1010        I: Iterator<Item = &'i axum::http::HeaderValue>,
1011    {
1012        let version = values
1013            .next()
1014            .ok_or_else(axum::headers::Error::invalid)?
1015            .to_str()
1016            .map_err(|_| axum::headers::Error::invalid())?
1017            .parse()
1018            .map_err(|_| axum::headers::Error::invalid())?;
1019        Ok(Self(version))
1020    }
1021
1022    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023        values.extend([self.0.to_string().parse().unwrap()]);
1024    }
1025}
1026
1027pub struct AppVersionHeader(Version);
1028impl Header for AppVersionHeader {
1029    fn name() -> &'static HeaderName {
1030        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032    }
1033
1034    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035    where
1036        Self: Sized,
1037        I: Iterator<Item = &'i axum::http::HeaderValue>,
1038    {
1039        let version = values
1040            .next()
1041            .ok_or_else(axum::headers::Error::invalid)?
1042            .to_str()
1043            .map_err(|_| axum::headers::Error::invalid())?
1044            .parse()
1045            .map_err(|_| axum::headers::Error::invalid())?;
1046        Ok(Self(version))
1047    }
1048
1049    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050        values.extend([self.0.to_string().parse().unwrap()]);
1051    }
1052}
1053
1054#[derive(Debug)]
1055pub struct ReleaseChannelHeader(String);
1056
1057impl Header for ReleaseChannelHeader {
1058    fn name() -> &'static HeaderName {
1059        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1060        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1061    }
1062
1063    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064    where
1065        Self: Sized,
1066        I: Iterator<Item = &'i axum::http::HeaderValue>,
1067    {
1068        Ok(Self(
1069            values
1070                .next()
1071                .ok_or_else(axum::headers::Error::invalid)?
1072                .to_str()
1073                .map_err(|_| axum::headers::Error::invalid())?
1074                .to_owned(),
1075        ))
1076    }
1077
1078    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079        values.extend([self.0.parse().unwrap()]);
1080    }
1081}
1082
1083pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1084    Router::new()
1085        .route("/rpc", get(handle_websocket_request))
1086        .layer(
1087            ServiceBuilder::new()
1088                .layer(Extension(server.app_state.clone()))
1089                .layer(middleware::from_fn(auth::validate_header)),
1090        )
1091        .route("/metrics", get(handle_metrics))
1092        .layer(Extension(server))
1093}
1094
1095pub async fn handle_websocket_request(
1096    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1097    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1098    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1099    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1100    Extension(server): Extension<Arc<Server>>,
1101    Extension(principal): Extension<Principal>,
1102    user_agent: Option<TypedHeader<UserAgent>>,
1103    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1104    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1105    ws: WebSocketUpgrade,
1106) -> axum::response::Response {
1107    if protocol_version != rpc::PROTOCOL_VERSION {
1108        return (
1109            StatusCode::UPGRADE_REQUIRED,
1110            "client must be upgraded".to_string(),
1111        )
1112            .into_response();
1113    }
1114
1115    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1116        return (
1117            StatusCode::UPGRADE_REQUIRED,
1118            "no version header found".to_string(),
1119        )
1120            .into_response();
1121    };
1122
1123    let release_channel = release_channel_header.map(|header| header.0.0);
1124
1125    if !version.can_collaborate() {
1126        return (
1127            StatusCode::UPGRADE_REQUIRED,
1128            "client must be upgraded".to_string(),
1129        )
1130            .into_response();
1131    }
1132
1133    let socket_address = socket_address.to_string();
1134
1135    // Acquire connection guard before WebSocket upgrade
1136    let connection_guard = match ConnectionGuard::try_acquire() {
1137        Ok(guard) => guard,
1138        Err(()) => {
1139            return (
1140                StatusCode::SERVICE_UNAVAILABLE,
1141                "Too many concurrent connections",
1142            )
1143                .into_response();
1144        }
1145    };
1146
1147    ws.on_upgrade(move |socket| {
1148        let socket = socket
1149            .map_ok(to_tungstenite_message)
1150            .err_into()
1151            .with(|message| async move { to_axum_message(message) });
1152        let connection = Connection::new(Box::pin(socket));
1153        async move {
1154            server
1155                .handle_connection(
1156                    connection,
1157                    socket_address,
1158                    principal,
1159                    version,
1160                    release_channel,
1161                    user_agent.map(|header| header.to_string()),
1162                    country_code_header.map(|header| header.to_string()),
1163                    system_id_header.map(|header| header.to_string()),
1164                    None,
1165                    Executor::Production,
1166                    Some(connection_guard),
1167                )
1168                .await;
1169        }
1170    })
1171}
1172
1173pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1174    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175    let connections_metric = CONNECTIONS_METRIC
1176        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1177
1178    let connections = server
1179        .connection_pool
1180        .lock()
1181        .connections()
1182        .filter(|connection| !connection.admin)
1183        .count();
1184    connections_metric.set(connections as _);
1185
1186    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1187    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1188        register_int_gauge!(
1189            "shared_projects",
1190            "number of open projects with one or more guests"
1191        )
1192        .unwrap()
1193    });
1194
1195    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1196    shared_projects_metric.set(shared_projects as _);
1197
1198    let encoder = prometheus::TextEncoder::new();
1199    let metric_families = prometheus::gather();
1200    let encoded_metrics = encoder
1201        .encode_to_string(&metric_families)
1202        .map_err(|err| anyhow!("{err}"))?;
1203    Ok(encoded_metrics)
1204}
1205
1206#[instrument(err, skip(executor))]
1207async fn connection_lost(
1208    session: Session,
1209    mut teardown: watch::Receiver<bool>,
1210    executor: Executor,
1211) -> Result<()> {
1212    session.peer.disconnect(session.connection_id);
1213    session
1214        .connection_pool()
1215        .await
1216        .remove_connection(session.connection_id)?;
1217
1218    session
1219        .db()
1220        .await
1221        .connection_lost(session.connection_id)
1222        .await
1223        .trace_err();
1224
1225    futures::select_biased! {
1226        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1227
1228            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1229            leave_room_for_session(&session, session.connection_id).await.trace_err();
1230            leave_channel_buffers_for_session(&session)
1231                .await
1232                .trace_err();
1233
1234            if !session
1235                .connection_pool()
1236                .await
1237                .is_user_online(session.user_id())
1238            {
1239                let db = session.db().await;
1240                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1241                    room_updated(&room, &session.peer);
1242                }
1243            }
1244
1245            update_user_contacts(session.user_id(), &session).await?;
1246        },
1247        _ = teardown.changed().fuse() => {}
1248    }
1249
1250    Ok(())
1251}
1252
1253/// Acknowledges a ping from a client, used to keep the connection alive.
1254async fn ping(
1255    _: proto::Ping,
1256    response: Response<proto::Ping>,
1257    _session: MessageContext,
1258) -> Result<()> {
1259    response.send(proto::Ack {})?;
1260    Ok(())
1261}
1262
1263/// Creates a new room for calling (outside of channels)
1264async fn create_room(
1265    _request: proto::CreateRoom,
1266    response: Response<proto::CreateRoom>,
1267    session: MessageContext,
1268) -> Result<()> {
1269    let livekit_room = nanoid::nanoid!(30);
1270
1271    let live_kit_connection_info = util::maybe!(async {
1272        let live_kit = session.app_state.livekit_client.as_ref();
1273        let live_kit = live_kit?;
1274        let user_id = session.user_id().to_string();
1275
1276        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1277
1278        Some(proto::LiveKitConnectionInfo {
1279            server_url: live_kit.url().into(),
1280            token,
1281            can_publish: true,
1282        })
1283    })
1284    .await;
1285
1286    let room = session
1287        .db()
1288        .await
1289        .create_room(session.user_id(), session.connection_id, &livekit_room)
1290        .await?;
1291
1292    response.send(proto::CreateRoomResponse {
1293        room: Some(room.clone()),
1294        live_kit_connection_info,
1295    })?;
1296
1297    update_user_contacts(session.user_id(), &session).await?;
1298    Ok(())
1299}
1300
1301/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1302async fn join_room(
1303    request: proto::JoinRoom,
1304    response: Response<proto::JoinRoom>,
1305    session: MessageContext,
1306) -> Result<()> {
1307    let room_id = RoomId::from_proto(request.id);
1308
1309    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1310
1311    if let Some(channel_id) = channel_id {
1312        return join_channel_internal(channel_id, Box::new(response), session).await;
1313    }
1314
1315    let joined_room = {
1316        let room = session
1317            .db()
1318            .await
1319            .join_room(room_id, session.user_id(), session.connection_id)
1320            .await?;
1321        room_updated(&room.room, &session.peer);
1322        room.into_inner()
1323    };
1324
1325    for connection_id in session
1326        .connection_pool()
1327        .await
1328        .user_connection_ids(session.user_id())
1329    {
1330        session
1331            .peer
1332            .send(
1333                connection_id,
1334                proto::CallCanceled {
1335                    room_id: room_id.to_proto(),
1336                },
1337            )
1338            .trace_err();
1339    }
1340
1341    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1342    {
1343        live_kit
1344            .room_token(
1345                &joined_room.room.livekit_room,
1346                &session.user_id().to_string(),
1347            )
1348            .trace_err()
1349            .map(|token| proto::LiveKitConnectionInfo {
1350                server_url: live_kit.url().into(),
1351                token,
1352                can_publish: true,
1353            })
1354    } else {
1355        None
1356    };
1357
1358    response.send(proto::JoinRoomResponse {
1359        room: Some(joined_room.room),
1360        channel_id: None,
1361        live_kit_connection_info,
1362    })?;
1363
1364    update_user_contacts(session.user_id(), &session).await?;
1365    Ok(())
1366}
1367
1368/// Rejoin room is used to reconnect to a room after connection errors.
1369async fn rejoin_room(
1370    request: proto::RejoinRoom,
1371    response: Response<proto::RejoinRoom>,
1372    session: MessageContext,
1373) -> Result<()> {
1374    let room;
1375    let channel;
1376    {
1377        let mut rejoined_room = session
1378            .db()
1379            .await
1380            .rejoin_room(request, session.user_id(), session.connection_id)
1381            .await?;
1382
1383        response.send(proto::RejoinRoomResponse {
1384            room: Some(rejoined_room.room.clone()),
1385            reshared_projects: rejoined_room
1386                .reshared_projects
1387                .iter()
1388                .map(|project| proto::ResharedProject {
1389                    id: project.id.to_proto(),
1390                    collaborators: project
1391                        .collaborators
1392                        .iter()
1393                        .map(|collaborator| collaborator.to_proto())
1394                        .collect(),
1395                })
1396                .collect(),
1397            rejoined_projects: rejoined_room
1398                .rejoined_projects
1399                .iter()
1400                .map(|rejoined_project| rejoined_project.to_proto())
1401                .collect(),
1402        })?;
1403        room_updated(&rejoined_room.room, &session.peer);
1404
1405        for project in &rejoined_room.reshared_projects {
1406            for collaborator in &project.collaborators {
1407                session
1408                    .peer
1409                    .send(
1410                        collaborator.connection_id,
1411                        proto::UpdateProjectCollaborator {
1412                            project_id: project.id.to_proto(),
1413                            old_peer_id: Some(project.old_connection_id.into()),
1414                            new_peer_id: Some(session.connection_id.into()),
1415                        },
1416                    )
1417                    .trace_err();
1418            }
1419
1420            broadcast(
1421                Some(session.connection_id),
1422                project
1423                    .collaborators
1424                    .iter()
1425                    .map(|collaborator| collaborator.connection_id),
1426                |connection_id| {
1427                    session.peer.forward_send(
1428                        session.connection_id,
1429                        connection_id,
1430                        proto::UpdateProject {
1431                            project_id: project.id.to_proto(),
1432                            worktrees: project.worktrees.clone(),
1433                        },
1434                    )
1435                },
1436            );
1437        }
1438
1439        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1440
1441        let rejoined_room = rejoined_room.into_inner();
1442
1443        room = rejoined_room.room;
1444        channel = rejoined_room.channel;
1445    }
1446
1447    if let Some(channel) = channel {
1448        channel_updated(
1449            &channel,
1450            &room,
1451            &session.peer,
1452            &*session.connection_pool().await,
1453        );
1454    }
1455
1456    update_user_contacts(session.user_id(), &session).await?;
1457    Ok(())
1458}
1459
1460fn notify_rejoined_projects(
1461    rejoined_projects: &mut Vec<RejoinedProject>,
1462    session: &Session,
1463) -> Result<()> {
1464    for project in rejoined_projects.iter() {
1465        for collaborator in &project.collaborators {
1466            session
1467                .peer
1468                .send(
1469                    collaborator.connection_id,
1470                    proto::UpdateProjectCollaborator {
1471                        project_id: project.id.to_proto(),
1472                        old_peer_id: Some(project.old_connection_id.into()),
1473                        new_peer_id: Some(session.connection_id.into()),
1474                    },
1475                )
1476                .trace_err();
1477        }
1478    }
1479
1480    for project in rejoined_projects {
1481        for worktree in mem::take(&mut project.worktrees) {
1482            // Stream this worktree's entries.
1483            let message = proto::UpdateWorktree {
1484                project_id: project.id.to_proto(),
1485                worktree_id: worktree.id,
1486                abs_path: worktree.abs_path.clone(),
1487                root_name: worktree.root_name,
1488                updated_entries: worktree.updated_entries,
1489                removed_entries: worktree.removed_entries,
1490                scan_id: worktree.scan_id,
1491                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1492                updated_repositories: worktree.updated_repositories,
1493                removed_repositories: worktree.removed_repositories,
1494            };
1495            for update in proto::split_worktree_update(message) {
1496                session.peer.send(session.connection_id, update)?;
1497            }
1498
1499            // Stream this worktree's diagnostics.
1500            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1501            if let Some(summary) = worktree_diagnostics.next() {
1502                let message = proto::UpdateDiagnosticSummary {
1503                    project_id: project.id.to_proto(),
1504                    worktree_id: worktree.id,
1505                    summary: Some(summary),
1506                    more_summaries: worktree_diagnostics.collect(),
1507                };
1508                session.peer.send(session.connection_id, message)?;
1509            }
1510
1511            for settings_file in worktree.settings_files {
1512                session.peer.send(
1513                    session.connection_id,
1514                    proto::UpdateWorktreeSettings {
1515                        project_id: project.id.to_proto(),
1516                        worktree_id: worktree.id,
1517                        path: settings_file.path,
1518                        content: Some(settings_file.content),
1519                        kind: Some(settings_file.kind.to_proto().into()),
1520                        outside_worktree: Some(settings_file.outside_worktree),
1521                    },
1522                )?;
1523            }
1524        }
1525
1526        for repository in mem::take(&mut project.updated_repositories) {
1527            for update in split_repository_update(repository) {
1528                session.peer.send(session.connection_id, update)?;
1529            }
1530        }
1531
1532        for id in mem::take(&mut project.removed_repositories) {
1533            session.peer.send(
1534                session.connection_id,
1535                proto::RemoveRepository {
1536                    project_id: project.id.to_proto(),
1537                    id,
1538                },
1539            )?;
1540        }
1541    }
1542
1543    Ok(())
1544}
1545
1546/// leave room disconnects from the room.
1547async fn leave_room(
1548    _: proto::LeaveRoom,
1549    response: Response<proto::LeaveRoom>,
1550    session: MessageContext,
1551) -> Result<()> {
1552    leave_room_for_session(&session, session.connection_id).await?;
1553    response.send(proto::Ack {})?;
1554    Ok(())
1555}
1556
1557/// Updates the permissions of someone else in the room.
1558async fn set_room_participant_role(
1559    request: proto::SetRoomParticipantRole,
1560    response: Response<proto::SetRoomParticipantRole>,
1561    session: MessageContext,
1562) -> Result<()> {
1563    let user_id = UserId::from_proto(request.user_id);
1564    let role = ChannelRole::from(request.role());
1565
1566    let (livekit_room, can_publish) = {
1567        let room = session
1568            .db()
1569            .await
1570            .set_room_participant_role(
1571                session.user_id(),
1572                RoomId::from_proto(request.room_id),
1573                user_id,
1574                role,
1575            )
1576            .await?;
1577
1578        let livekit_room = room.livekit_room.clone();
1579        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1580        room_updated(&room, &session.peer);
1581        (livekit_room, can_publish)
1582    };
1583
1584    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1585        live_kit
1586            .update_participant(
1587                livekit_room.clone(),
1588                request.user_id.to_string(),
1589                livekit_api::proto::ParticipantPermission {
1590                    can_subscribe: true,
1591                    can_publish,
1592                    can_publish_data: can_publish,
1593                    hidden: false,
1594                    recorder: false,
1595                },
1596            )
1597            .await
1598            .trace_err();
1599    }
1600
1601    response.send(proto::Ack {})?;
1602    Ok(())
1603}
1604
1605/// Call someone else into the current room
1606async fn call(
1607    request: proto::Call,
1608    response: Response<proto::Call>,
1609    session: MessageContext,
1610) -> Result<()> {
1611    let room_id = RoomId::from_proto(request.room_id);
1612    let calling_user_id = session.user_id();
1613    let calling_connection_id = session.connection_id;
1614    let called_user_id = UserId::from_proto(request.called_user_id);
1615    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1616    if !session
1617        .db()
1618        .await
1619        .has_contact(calling_user_id, called_user_id)
1620        .await?
1621    {
1622        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1623    }
1624
1625    let incoming_call = {
1626        let (room, incoming_call) = &mut *session
1627            .db()
1628            .await
1629            .call(
1630                room_id,
1631                calling_user_id,
1632                calling_connection_id,
1633                called_user_id,
1634                initial_project_id,
1635            )
1636            .await?;
1637        room_updated(room, &session.peer);
1638        mem::take(incoming_call)
1639    };
1640    update_user_contacts(called_user_id, &session).await?;
1641
1642    let mut calls = session
1643        .connection_pool()
1644        .await
1645        .user_connection_ids(called_user_id)
1646        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1647        .collect::<FuturesUnordered<_>>();
1648
1649    while let Some(call_response) = calls.next().await {
1650        match call_response.as_ref() {
1651            Ok(_) => {
1652                response.send(proto::Ack {})?;
1653                return Ok(());
1654            }
1655            Err(_) => {
1656                call_response.trace_err();
1657            }
1658        }
1659    }
1660
1661    {
1662        let room = session
1663            .db()
1664            .await
1665            .call_failed(room_id, called_user_id)
1666            .await?;
1667        room_updated(&room, &session.peer);
1668    }
1669    update_user_contacts(called_user_id, &session).await?;
1670
1671    Err(anyhow!("failed to ring user"))?
1672}
1673
1674/// Cancel an outgoing call.
1675async fn cancel_call(
1676    request: proto::CancelCall,
1677    response: Response<proto::CancelCall>,
1678    session: MessageContext,
1679) -> Result<()> {
1680    let called_user_id = UserId::from_proto(request.called_user_id);
1681    let room_id = RoomId::from_proto(request.room_id);
1682    {
1683        let room = session
1684            .db()
1685            .await
1686            .cancel_call(room_id, session.connection_id, called_user_id)
1687            .await?;
1688        room_updated(&room, &session.peer);
1689    }
1690
1691    for connection_id in session
1692        .connection_pool()
1693        .await
1694        .user_connection_ids(called_user_id)
1695    {
1696        session
1697            .peer
1698            .send(
1699                connection_id,
1700                proto::CallCanceled {
1701                    room_id: room_id.to_proto(),
1702                },
1703            )
1704            .trace_err();
1705    }
1706    response.send(proto::Ack {})?;
1707
1708    update_user_contacts(called_user_id, &session).await?;
1709    Ok(())
1710}
1711
1712/// Decline an incoming call.
1713async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1714    let room_id = RoomId::from_proto(message.room_id);
1715    {
1716        let room = session
1717            .db()
1718            .await
1719            .decline_call(Some(room_id), session.user_id())
1720            .await?
1721            .context("declining call")?;
1722        room_updated(&room, &session.peer);
1723    }
1724
1725    for connection_id in session
1726        .connection_pool()
1727        .await
1728        .user_connection_ids(session.user_id())
1729    {
1730        session
1731            .peer
1732            .send(
1733                connection_id,
1734                proto::CallCanceled {
1735                    room_id: room_id.to_proto(),
1736                },
1737            )
1738            .trace_err();
1739    }
1740    update_user_contacts(session.user_id(), &session).await?;
1741    Ok(())
1742}
1743
1744/// Updates other participants in the room with your current location.
1745async fn update_participant_location(
1746    request: proto::UpdateParticipantLocation,
1747    response: Response<proto::UpdateParticipantLocation>,
1748    session: MessageContext,
1749) -> Result<()> {
1750    let room_id = RoomId::from_proto(request.room_id);
1751    let location = request.location.context("invalid location")?;
1752
1753    let db = session.db().await;
1754    let room = db
1755        .update_room_participant_location(room_id, session.connection_id, location)
1756        .await?;
1757
1758    room_updated(&room, &session.peer);
1759    response.send(proto::Ack {})?;
1760    Ok(())
1761}
1762
1763/// Share a project into the room.
1764async fn share_project(
1765    request: proto::ShareProject,
1766    response: Response<proto::ShareProject>,
1767    session: MessageContext,
1768) -> Result<()> {
1769    let (project_id, room) = &*session
1770        .db()
1771        .await
1772        .share_project(
1773            RoomId::from_proto(request.room_id),
1774            session.connection_id,
1775            &request.worktrees,
1776            request.is_ssh_project,
1777            request.windows_paths.unwrap_or(false),
1778            &request.features,
1779        )
1780        .await?;
1781    response.send(proto::ShareProjectResponse {
1782        project_id: project_id.to_proto(),
1783    })?;
1784    room_updated(room, &session.peer);
1785
1786    Ok(())
1787}
1788
1789/// Unshare a project from the room.
1790async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1791    let project_id = ProjectId::from_proto(message.project_id);
1792    unshare_project_internal(project_id, session.connection_id, &session).await
1793}
1794
1795async fn unshare_project_internal(
1796    project_id: ProjectId,
1797    connection_id: ConnectionId,
1798    session: &Session,
1799) -> Result<()> {
1800    let delete = {
1801        let room_guard = session
1802            .db()
1803            .await
1804            .unshare_project(project_id, connection_id)
1805            .await?;
1806
1807        let (delete, room, guest_connection_ids) = &*room_guard;
1808
1809        let message = proto::UnshareProject {
1810            project_id: project_id.to_proto(),
1811        };
1812
1813        broadcast(
1814            Some(connection_id),
1815            guest_connection_ids.iter().copied(),
1816            |conn_id| session.peer.send(conn_id, message.clone()),
1817        );
1818        if let Some(room) = room {
1819            room_updated(room, &session.peer);
1820        }
1821
1822        *delete
1823    };
1824
1825    if delete {
1826        let db = session.db().await;
1827        db.delete_project(project_id).await?;
1828    }
1829
1830    Ok(())
1831}
1832
1833/// Join someone elses shared project.
1834async fn join_project(
1835    request: proto::JoinProject,
1836    response: Response<proto::JoinProject>,
1837    session: MessageContext,
1838) -> Result<()> {
1839    let project_id = ProjectId::from_proto(request.project_id);
1840
1841    tracing::info!(%project_id, "join project");
1842
1843    let db = session.db().await;
1844    let project_model = db.get_project(project_id).await?;
1845    let host_features: Vec<String> =
1846        serde_json::from_str(&project_model.features).unwrap_or_default();
1847    let guest_features: HashSet<_> = request.features.iter().collect();
1848    let host_features_set: HashSet<_> = host_features.iter().collect();
1849    if guest_features != host_features_set {
1850        let host_connection_id = project_model.host_connection()?;
1851        let mut pool = session.connection_pool().await;
1852        let host_version = pool
1853            .connection(host_connection_id)
1854            .map(|c| c.zed_version.to_string());
1855        let guest_version = pool
1856            .connection(session.connection_id)
1857            .map(|c| c.zed_version.to_string());
1858        drop(pool);
1859        Err(anyhow!(
1860            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1861            host_version.as_deref().unwrap_or("unknown"),
1862            guest_version.as_deref().unwrap_or("unknown"),
1863        ))?;
1864    }
1865
1866    let (project, replica_id) = &mut *db
1867        .join_project(
1868            project_id,
1869            session.connection_id,
1870            session.user_id(),
1871            request.committer_name.clone(),
1872            request.committer_email.clone(),
1873        )
1874        .await?;
1875    drop(db);
1876
1877    tracing::info!(%project_id, "join remote project");
1878    let collaborators = project
1879        .collaborators
1880        .iter()
1881        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1882        .map(|collaborator| collaborator.to_proto())
1883        .collect::<Vec<_>>();
1884    let project_id = project.id;
1885    let guest_user_id = session.user_id();
1886
1887    let worktrees = project
1888        .worktrees
1889        .iter()
1890        .map(|(id, worktree)| proto::WorktreeMetadata {
1891            id: *id,
1892            root_name: worktree.root_name.clone(),
1893            visible: worktree.visible,
1894            abs_path: worktree.abs_path.clone(),
1895        })
1896        .collect::<Vec<_>>();
1897
1898    let add_project_collaborator = proto::AddProjectCollaborator {
1899        project_id: project_id.to_proto(),
1900        collaborator: Some(proto::Collaborator {
1901            peer_id: Some(session.connection_id.into()),
1902            replica_id: replica_id.0 as u32,
1903            user_id: guest_user_id.to_proto(),
1904            is_host: false,
1905            committer_name: request.committer_name.clone(),
1906            committer_email: request.committer_email.clone(),
1907        }),
1908    };
1909
1910    for collaborator in &collaborators {
1911        session
1912            .peer
1913            .send(
1914                collaborator.peer_id.unwrap().into(),
1915                add_project_collaborator.clone(),
1916            )
1917            .trace_err();
1918    }
1919
1920    // First, we send the metadata associated with each worktree.
1921    let (language_servers, language_server_capabilities) = project
1922        .language_servers
1923        .clone()
1924        .into_iter()
1925        .map(|server| (server.server, server.capabilities))
1926        .unzip();
1927    response.send(proto::JoinProjectResponse {
1928        project_id: project.id.0 as u64,
1929        worktrees,
1930        replica_id: replica_id.0 as u32,
1931        collaborators,
1932        language_servers,
1933        language_server_capabilities,
1934        role: project.role.into(),
1935        windows_paths: project.path_style == PathStyle::Windows,
1936        features: project.features.clone(),
1937    })?;
1938
1939    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1940        // Stream this worktree's entries.
1941        let message = proto::UpdateWorktree {
1942            project_id: project_id.to_proto(),
1943            worktree_id,
1944            abs_path: worktree.abs_path.clone(),
1945            root_name: worktree.root_name,
1946            updated_entries: worktree.entries,
1947            removed_entries: Default::default(),
1948            scan_id: worktree.scan_id,
1949            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1950            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1951            removed_repositories: Default::default(),
1952        };
1953        for update in proto::split_worktree_update(message) {
1954            session.peer.send(session.connection_id, update.clone())?;
1955        }
1956
1957        // Stream this worktree's diagnostics.
1958        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1959        if let Some(summary) = worktree_diagnostics.next() {
1960            let message = proto::UpdateDiagnosticSummary {
1961                project_id: project.id.to_proto(),
1962                worktree_id: worktree.id,
1963                summary: Some(summary),
1964                more_summaries: worktree_diagnostics.collect(),
1965            };
1966            session.peer.send(session.connection_id, message)?;
1967        }
1968
1969        for settings_file in worktree.settings_files {
1970            session.peer.send(
1971                session.connection_id,
1972                proto::UpdateWorktreeSettings {
1973                    project_id: project_id.to_proto(),
1974                    worktree_id: worktree.id,
1975                    path: settings_file.path,
1976                    content: Some(settings_file.content),
1977                    kind: Some(settings_file.kind.to_proto() as i32),
1978                    outside_worktree: Some(settings_file.outside_worktree),
1979                },
1980            )?;
1981        }
1982    }
1983
1984    for repository in mem::take(&mut project.repositories) {
1985        for update in split_repository_update(repository) {
1986            session.peer.send(session.connection_id, update)?;
1987        }
1988    }
1989
1990    for language_server in &project.language_servers {
1991        session.peer.send(
1992            session.connection_id,
1993            proto::UpdateLanguageServer {
1994                project_id: project_id.to_proto(),
1995                server_name: Some(language_server.server.name.clone()),
1996                language_server_id: language_server.server.id,
1997                variant: Some(
1998                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1999                        proto::LspDiskBasedDiagnosticsUpdated {},
2000                    ),
2001                ),
2002            },
2003        )?;
2004    }
2005
2006    Ok(())
2007}
2008
2009/// Leave someone elses shared project.
2010async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2011    let sender_id = session.connection_id;
2012    let project_id = ProjectId::from_proto(request.project_id);
2013    let db = session.db().await;
2014
2015    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2016    tracing::info!(
2017        %project_id,
2018        "leave project"
2019    );
2020
2021    project_left(project, &session);
2022    if let Some(room) = room {
2023        room_updated(room, &session.peer);
2024    }
2025
2026    Ok(())
2027}
2028
2029/// Updates other participants with changes to the project
2030async fn update_project(
2031    request: proto::UpdateProject,
2032    response: Response<proto::UpdateProject>,
2033    session: MessageContext,
2034) -> Result<()> {
2035    let project_id = ProjectId::from_proto(request.project_id);
2036    let (room, guest_connection_ids) = &*session
2037        .db()
2038        .await
2039        .update_project(project_id, session.connection_id, &request.worktrees)
2040        .await?;
2041    broadcast(
2042        Some(session.connection_id),
2043        guest_connection_ids.iter().copied(),
2044        |connection_id| {
2045            session
2046                .peer
2047                .forward_send(session.connection_id, connection_id, request.clone())
2048        },
2049    );
2050    if let Some(room) = room {
2051        room_updated(room, &session.peer);
2052    }
2053    response.send(proto::Ack {})?;
2054
2055    Ok(())
2056}
2057
2058/// Updates other participants with changes to the worktree
2059async fn update_worktree(
2060    request: proto::UpdateWorktree,
2061    response: Response<proto::UpdateWorktree>,
2062    session: MessageContext,
2063) -> Result<()> {
2064    let guest_connection_ids = session
2065        .db()
2066        .await
2067        .update_worktree(&request, session.connection_id)
2068        .await?;
2069
2070    broadcast(
2071        Some(session.connection_id),
2072        guest_connection_ids.iter().copied(),
2073        |connection_id| {
2074            session
2075                .peer
2076                .forward_send(session.connection_id, connection_id, request.clone())
2077        },
2078    );
2079    response.send(proto::Ack {})?;
2080    Ok(())
2081}
2082
2083async fn update_repository(
2084    request: proto::UpdateRepository,
2085    response: Response<proto::UpdateRepository>,
2086    session: MessageContext,
2087) -> Result<()> {
2088    let guest_connection_ids = session
2089        .db()
2090        .await
2091        .update_repository(&request, session.connection_id)
2092        .await?;
2093
2094    broadcast(
2095        Some(session.connection_id),
2096        guest_connection_ids.iter().copied(),
2097        |connection_id| {
2098            session
2099                .peer
2100                .forward_send(session.connection_id, connection_id, request.clone())
2101        },
2102    );
2103    response.send(proto::Ack {})?;
2104    Ok(())
2105}
2106
2107async fn remove_repository(
2108    request: proto::RemoveRepository,
2109    response: Response<proto::RemoveRepository>,
2110    session: MessageContext,
2111) -> Result<()> {
2112    let guest_connection_ids = session
2113        .db()
2114        .await
2115        .remove_repository(&request, session.connection_id)
2116        .await?;
2117
2118    broadcast(
2119        Some(session.connection_id),
2120        guest_connection_ids.iter().copied(),
2121        |connection_id| {
2122            session
2123                .peer
2124                .forward_send(session.connection_id, connection_id, request.clone())
2125        },
2126    );
2127    response.send(proto::Ack {})?;
2128    Ok(())
2129}
2130
2131/// Updates other participants with changes to the diagnostics
2132async fn update_diagnostic_summary(
2133    message: proto::UpdateDiagnosticSummary,
2134    session: MessageContext,
2135) -> Result<()> {
2136    let guest_connection_ids = session
2137        .db()
2138        .await
2139        .update_diagnostic_summary(&message, session.connection_id)
2140        .await?;
2141
2142    broadcast(
2143        Some(session.connection_id),
2144        guest_connection_ids.iter().copied(),
2145        |connection_id| {
2146            session
2147                .peer
2148                .forward_send(session.connection_id, connection_id, message.clone())
2149        },
2150    );
2151
2152    Ok(())
2153}
2154
2155/// Updates other participants with changes to the worktree settings
2156async fn update_worktree_settings(
2157    message: proto::UpdateWorktreeSettings,
2158    session: MessageContext,
2159) -> Result<()> {
2160    let guest_connection_ids = session
2161        .db()
2162        .await
2163        .update_worktree_settings(&message, session.connection_id)
2164        .await?;
2165
2166    broadcast(
2167        Some(session.connection_id),
2168        guest_connection_ids.iter().copied(),
2169        |connection_id| {
2170            session
2171                .peer
2172                .forward_send(session.connection_id, connection_id, message.clone())
2173        },
2174    );
2175
2176    Ok(())
2177}
2178
2179/// Notify other participants that a language server has started.
2180async fn start_language_server(
2181    request: proto::StartLanguageServer,
2182    session: MessageContext,
2183) -> Result<()> {
2184    let guest_connection_ids = session
2185        .db()
2186        .await
2187        .start_language_server(&request, session.connection_id)
2188        .await?;
2189
2190    broadcast(
2191        Some(session.connection_id),
2192        guest_connection_ids.iter().copied(),
2193        |connection_id| {
2194            session
2195                .peer
2196                .forward_send(session.connection_id, connection_id, request.clone())
2197        },
2198    );
2199    Ok(())
2200}
2201
2202/// Notify other participants that a language server has changed.
2203async fn update_language_server(
2204    request: proto::UpdateLanguageServer,
2205    session: MessageContext,
2206) -> Result<()> {
2207    let project_id = ProjectId::from_proto(request.project_id);
2208    let db = session.db().await;
2209
2210    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2211        && let Some(capabilities) = update.capabilities.clone()
2212    {
2213        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2214            .await?;
2215    }
2216
2217    let project_connection_ids = db
2218        .project_connection_ids(project_id, session.connection_id, true)
2219        .await?;
2220    broadcast(
2221        Some(session.connection_id),
2222        project_connection_ids.iter().copied(),
2223        |connection_id| {
2224            session
2225                .peer
2226                .forward_send(session.connection_id, connection_id, request.clone())
2227        },
2228    );
2229    Ok(())
2230}
2231
2232/// forward a project request to the host. These requests should be read only
2233/// as guests are allowed to send them.
2234async fn forward_read_only_project_request<T>(
2235    request: T,
2236    response: Response<T>,
2237    session: MessageContext,
2238) -> Result<()>
2239where
2240    T: EntityMessage + RequestMessage,
2241{
2242    let project_id = ProjectId::from_proto(request.remote_entity_id());
2243    let host_connection_id = session
2244        .db()
2245        .await
2246        .host_for_read_only_project_request(project_id, session.connection_id)
2247        .await?;
2248    let payload = session.forward_request(host_connection_id, request).await?;
2249    response.send(payload)?;
2250    Ok(())
2251}
2252
2253/// forward a project request to the host. These requests are disallowed
2254/// for guests.
2255async fn forward_mutating_project_request<T>(
2256    request: T,
2257    response: Response<T>,
2258    session: MessageContext,
2259) -> Result<()>
2260where
2261    T: EntityMessage + RequestMessage,
2262{
2263    let project_id = ProjectId::from_proto(request.remote_entity_id());
2264
2265    let host_connection_id = session
2266        .db()
2267        .await
2268        .host_for_mutating_project_request(project_id, session.connection_id)
2269        .await?;
2270    let payload = session.forward_request(host_connection_id, request).await?;
2271    response.send(payload)?;
2272    Ok(())
2273}
2274
2275async fn disallow_guest_request<T>(
2276    _request: T,
2277    response: Response<T>,
2278    _session: MessageContext,
2279) -> Result<()>
2280where
2281    T: RequestMessage,
2282{
2283    response.peer.respond_with_error(
2284        response.receipt,
2285        ErrorCode::Forbidden
2286            .message("request is not allowed for guests".to_string())
2287            .to_proto(),
2288    )?;
2289    response.responded.store(true, SeqCst);
2290    Ok(())
2291}
2292
2293async fn lsp_query(
2294    request: proto::LspQuery,
2295    response: Response<proto::LspQuery>,
2296    session: MessageContext,
2297) -> Result<()> {
2298    let (name, should_write) = request.query_name_and_write_permissions();
2299    tracing::Span::current().record("lsp_query_request", name);
2300    tracing::info!("lsp_query message received");
2301    if should_write {
2302        forward_mutating_project_request(request, response, session).await
2303    } else {
2304        forward_read_only_project_request(request, response, session).await
2305    }
2306}
2307
2308/// Notify other participants that a new buffer has been created
2309async fn create_buffer_for_peer(
2310    request: proto::CreateBufferForPeer,
2311    session: MessageContext,
2312) -> Result<()> {
2313    session
2314        .db()
2315        .await
2316        .check_user_is_project_host(
2317            ProjectId::from_proto(request.project_id),
2318            session.connection_id,
2319        )
2320        .await?;
2321    let peer_id = request.peer_id.context("invalid peer id")?;
2322    session
2323        .peer
2324        .forward_send(session.connection_id, peer_id.into(), request)?;
2325    Ok(())
2326}
2327
2328/// Notify other participants that a new image has been created
2329async fn create_image_for_peer(
2330    request: proto::CreateImageForPeer,
2331    session: MessageContext,
2332) -> Result<()> {
2333    session
2334        .db()
2335        .await
2336        .check_user_is_project_host(
2337            ProjectId::from_proto(request.project_id),
2338            session.connection_id,
2339        )
2340        .await?;
2341    let peer_id = request.peer_id.context("invalid peer id")?;
2342    session
2343        .peer
2344        .forward_send(session.connection_id, peer_id.into(), request)?;
2345    Ok(())
2346}
2347
2348/// Notify other participants that a buffer has been updated. This is
2349/// allowed for guests as long as the update is limited to selections.
2350async fn update_buffer(
2351    request: proto::UpdateBuffer,
2352    response: Response<proto::UpdateBuffer>,
2353    session: MessageContext,
2354) -> Result<()> {
2355    let project_id = ProjectId::from_proto(request.project_id);
2356    let mut capability = Capability::ReadOnly;
2357
2358    for op in request.operations.iter() {
2359        match op.variant {
2360            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2361            Some(_) => capability = Capability::ReadWrite,
2362        }
2363    }
2364
2365    let host = {
2366        let guard = session
2367            .db()
2368            .await
2369            .connections_for_buffer_update(project_id, session.connection_id, capability)
2370            .await?;
2371
2372        let (host, guests) = &*guard;
2373
2374        broadcast(
2375            Some(session.connection_id),
2376            guests.clone(),
2377            |connection_id| {
2378                session
2379                    .peer
2380                    .forward_send(session.connection_id, connection_id, request.clone())
2381            },
2382        );
2383
2384        *host
2385    };
2386
2387    if host != session.connection_id {
2388        session.forward_request(host, request.clone()).await?;
2389    }
2390
2391    response.send(proto::Ack {})?;
2392    Ok(())
2393}
2394
2395async fn forward_project_search_chunk(
2396    message: proto::FindSearchCandidatesChunk,
2397    response: Response<proto::FindSearchCandidatesChunk>,
2398    session: MessageContext,
2399) -> Result<()> {
2400    let peer_id = message.peer_id.context("missing peer_id")?;
2401    let payload = session
2402        .peer
2403        .forward_request(session.connection_id, peer_id.into(), message)
2404        .await?;
2405    response.send(payload)?;
2406    Ok(())
2407}
2408
2409/// Notify other participants that a project has been updated.
2410async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2411    request: T,
2412    session: MessageContext,
2413) -> Result<()> {
2414    let project_id = ProjectId::from_proto(request.remote_entity_id());
2415    let project_connection_ids = session
2416        .db()
2417        .await
2418        .project_connection_ids(project_id, session.connection_id, false)
2419        .await?;
2420
2421    broadcast(
2422        Some(session.connection_id),
2423        project_connection_ids.iter().copied(),
2424        |connection_id| {
2425            session
2426                .peer
2427                .forward_send(session.connection_id, connection_id, request.clone())
2428        },
2429    );
2430    Ok(())
2431}
2432
2433/// Start following another user in a call.
2434async fn follow(
2435    request: proto::Follow,
2436    response: Response<proto::Follow>,
2437    session: MessageContext,
2438) -> Result<()> {
2439    let room_id = RoomId::from_proto(request.room_id);
2440    let project_id = request.project_id.map(ProjectId::from_proto);
2441    let leader_id = request.leader_id.context("invalid leader id")?.into();
2442    let follower_id = session.connection_id;
2443
2444    session
2445        .db()
2446        .await
2447        .check_room_participants(room_id, leader_id, session.connection_id)
2448        .await?;
2449
2450    let response_payload = session.forward_request(leader_id, request).await?;
2451    response.send(response_payload)?;
2452
2453    if let Some(project_id) = project_id {
2454        let room = session
2455            .db()
2456            .await
2457            .follow(room_id, project_id, leader_id, follower_id)
2458            .await?;
2459        room_updated(&room, &session.peer);
2460    }
2461
2462    Ok(())
2463}
2464
2465/// Stop following another user in a call.
2466async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2467    let room_id = RoomId::from_proto(request.room_id);
2468    let project_id = request.project_id.map(ProjectId::from_proto);
2469    let leader_id = request.leader_id.context("invalid leader id")?.into();
2470    let follower_id = session.connection_id;
2471
2472    session
2473        .db()
2474        .await
2475        .check_room_participants(room_id, leader_id, session.connection_id)
2476        .await?;
2477
2478    session
2479        .peer
2480        .forward_send(session.connection_id, leader_id, request)?;
2481
2482    if let Some(project_id) = project_id {
2483        let room = session
2484            .db()
2485            .await
2486            .unfollow(room_id, project_id, leader_id, follower_id)
2487            .await?;
2488        room_updated(&room, &session.peer);
2489    }
2490
2491    Ok(())
2492}
2493
2494/// Notify everyone following you of your current location.
2495async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2496    let room_id = RoomId::from_proto(request.room_id);
2497    let database = session.db.lock().await;
2498
2499    let connection_ids = if let Some(project_id) = request.project_id {
2500        let project_id = ProjectId::from_proto(project_id);
2501        database
2502            .project_connection_ids(project_id, session.connection_id, true)
2503            .await?
2504    } else {
2505        database
2506            .room_connection_ids(room_id, session.connection_id)
2507            .await?
2508    };
2509
2510    // For now, don't send view update messages back to that view's current leader.
2511    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2512        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2513        _ => None,
2514    });
2515
2516    for connection_id in connection_ids.iter().cloned() {
2517        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2518            session
2519                .peer
2520                .forward_send(session.connection_id, connection_id, request.clone())?;
2521        }
2522    }
2523    Ok(())
2524}
2525
2526/// Get public data about users.
2527async fn get_users(
2528    request: proto::GetUsers,
2529    response: Response<proto::GetUsers>,
2530    session: MessageContext,
2531) -> Result<()> {
2532    let user_ids = request
2533        .user_ids
2534        .into_iter()
2535        .map(UserId::from_proto)
2536        .collect();
2537    let users = session
2538        .db()
2539        .await
2540        .get_users_by_ids(user_ids)
2541        .await?
2542        .into_iter()
2543        .map(|user| proto::User {
2544            id: user.id.to_proto(),
2545            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2546            github_login: user.github_login,
2547            name: user.name,
2548        })
2549        .collect();
2550    response.send(proto::UsersResponse { users })?;
2551    Ok(())
2552}
2553
2554/// Search for users (to invite) buy Github login
2555async fn fuzzy_search_users(
2556    request: proto::FuzzySearchUsers,
2557    response: Response<proto::FuzzySearchUsers>,
2558    session: MessageContext,
2559) -> Result<()> {
2560    let query = request.query;
2561    let users = match query.len() {
2562        0 => vec![],
2563        1 | 2 => session
2564            .db()
2565            .await
2566            .get_user_by_github_login(&query)
2567            .await?
2568            .into_iter()
2569            .collect(),
2570        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2571    };
2572    let users = users
2573        .into_iter()
2574        .filter(|user| user.id != session.user_id())
2575        .map(|user| proto::User {
2576            id: user.id.to_proto(),
2577            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2578            github_login: user.github_login,
2579            name: user.name,
2580        })
2581        .collect();
2582    response.send(proto::UsersResponse { users })?;
2583    Ok(())
2584}
2585
2586/// Send a contact request to another user.
2587async fn request_contact(
2588    request: proto::RequestContact,
2589    response: Response<proto::RequestContact>,
2590    session: MessageContext,
2591) -> Result<()> {
2592    let requester_id = session.user_id();
2593    let responder_id = UserId::from_proto(request.responder_id);
2594    if requester_id == responder_id {
2595        return Err(anyhow!("cannot add yourself as a contact"))?;
2596    }
2597
2598    let notifications = session
2599        .db()
2600        .await
2601        .send_contact_request(requester_id, responder_id)
2602        .await?;
2603
2604    // Update outgoing contact requests of requester
2605    let mut update = proto::UpdateContacts::default();
2606    update.outgoing_requests.push(responder_id.to_proto());
2607    for connection_id in session
2608        .connection_pool()
2609        .await
2610        .user_connection_ids(requester_id)
2611    {
2612        session.peer.send(connection_id, update.clone())?;
2613    }
2614
2615    // Update incoming contact requests of responder
2616    let mut update = proto::UpdateContacts::default();
2617    update
2618        .incoming_requests
2619        .push(proto::IncomingContactRequest {
2620            requester_id: requester_id.to_proto(),
2621        });
2622    let connection_pool = session.connection_pool().await;
2623    for connection_id in connection_pool.user_connection_ids(responder_id) {
2624        session.peer.send(connection_id, update.clone())?;
2625    }
2626
2627    send_notifications(&connection_pool, &session.peer, notifications);
2628
2629    response.send(proto::Ack {})?;
2630    Ok(())
2631}
2632
2633/// Accept or decline a contact request
2634async fn respond_to_contact_request(
2635    request: proto::RespondToContactRequest,
2636    response: Response<proto::RespondToContactRequest>,
2637    session: MessageContext,
2638) -> Result<()> {
2639    let responder_id = session.user_id();
2640    let requester_id = UserId::from_proto(request.requester_id);
2641    let db = session.db().await;
2642    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2643        db.dismiss_contact_notification(responder_id, requester_id)
2644            .await?;
2645    } else {
2646        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2647
2648        let notifications = db
2649            .respond_to_contact_request(responder_id, requester_id, accept)
2650            .await?;
2651        let requester_busy = db.is_user_busy(requester_id).await?;
2652        let responder_busy = db.is_user_busy(responder_id).await?;
2653
2654        let pool = session.connection_pool().await;
2655        // Update responder with new contact
2656        let mut update = proto::UpdateContacts::default();
2657        if accept {
2658            update
2659                .contacts
2660                .push(contact_for_user(requester_id, requester_busy, &pool));
2661        }
2662        update
2663            .remove_incoming_requests
2664            .push(requester_id.to_proto());
2665        for connection_id in pool.user_connection_ids(responder_id) {
2666            session.peer.send(connection_id, update.clone())?;
2667        }
2668
2669        // Update requester with new contact
2670        let mut update = proto::UpdateContacts::default();
2671        if accept {
2672            update
2673                .contacts
2674                .push(contact_for_user(responder_id, responder_busy, &pool));
2675        }
2676        update
2677            .remove_outgoing_requests
2678            .push(responder_id.to_proto());
2679
2680        for connection_id in pool.user_connection_ids(requester_id) {
2681            session.peer.send(connection_id, update.clone())?;
2682        }
2683
2684        send_notifications(&pool, &session.peer, notifications);
2685    }
2686
2687    response.send(proto::Ack {})?;
2688    Ok(())
2689}
2690
2691/// Remove a contact.
2692async fn remove_contact(
2693    request: proto::RemoveContact,
2694    response: Response<proto::RemoveContact>,
2695    session: MessageContext,
2696) -> Result<()> {
2697    let requester_id = session.user_id();
2698    let responder_id = UserId::from_proto(request.user_id);
2699    let db = session.db().await;
2700    let (contact_accepted, deleted_notification_id) =
2701        db.remove_contact(requester_id, responder_id).await?;
2702
2703    let pool = session.connection_pool().await;
2704    // Update outgoing contact requests of requester
2705    let mut update = proto::UpdateContacts::default();
2706    if contact_accepted {
2707        update.remove_contacts.push(responder_id.to_proto());
2708    } else {
2709        update
2710            .remove_outgoing_requests
2711            .push(responder_id.to_proto());
2712    }
2713    for connection_id in pool.user_connection_ids(requester_id) {
2714        session.peer.send(connection_id, update.clone())?;
2715    }
2716
2717    // Update incoming contact requests of responder
2718    let mut update = proto::UpdateContacts::default();
2719    if contact_accepted {
2720        update.remove_contacts.push(requester_id.to_proto());
2721    } else {
2722        update
2723            .remove_incoming_requests
2724            .push(requester_id.to_proto());
2725    }
2726    for connection_id in pool.user_connection_ids(responder_id) {
2727        session.peer.send(connection_id, update.clone())?;
2728        if let Some(notification_id) = deleted_notification_id {
2729            session.peer.send(
2730                connection_id,
2731                proto::DeleteNotification {
2732                    notification_id: notification_id.to_proto(),
2733                },
2734            )?;
2735        }
2736    }
2737
2738    response.send(proto::Ack {})?;
2739    Ok(())
2740}
2741
2742fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2743    version.0.minor < 139
2744}
2745
2746async fn subscribe_to_channels(
2747    _: proto::SubscribeToChannels,
2748    session: MessageContext,
2749) -> Result<()> {
2750    subscribe_user_to_channels(session.user_id(), &session).await?;
2751    Ok(())
2752}
2753
2754async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2755    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2756    let mut pool = session.connection_pool().await;
2757    for membership in &channels_for_user.channel_memberships {
2758        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2759    }
2760    session.peer.send(
2761        session.connection_id,
2762        build_update_user_channels(&channels_for_user),
2763    )?;
2764    session.peer.send(
2765        session.connection_id,
2766        build_channels_update(channels_for_user),
2767    )?;
2768    Ok(())
2769}
2770
2771/// Creates a new channel.
2772async fn create_channel(
2773    request: proto::CreateChannel,
2774    response: Response<proto::CreateChannel>,
2775    session: MessageContext,
2776) -> Result<()> {
2777    let db = session.db().await;
2778
2779    let parent_id = request.parent_id.map(ChannelId::from_proto);
2780    let (channel, membership) = db
2781        .create_channel(&request.name, parent_id, session.user_id())
2782        .await?;
2783
2784    let root_id = channel.root_id();
2785    let channel = Channel::from_model(channel);
2786
2787    response.send(proto::CreateChannelResponse {
2788        channel: Some(channel.to_proto()),
2789        parent_id: request.parent_id,
2790    })?;
2791
2792    let mut connection_pool = session.connection_pool().await;
2793    if let Some(membership) = membership {
2794        connection_pool.subscribe_to_channel(
2795            membership.user_id,
2796            membership.channel_id,
2797            membership.role,
2798        );
2799        let update = proto::UpdateUserChannels {
2800            channel_memberships: vec![proto::ChannelMembership {
2801                channel_id: membership.channel_id.to_proto(),
2802                role: membership.role.into(),
2803            }],
2804            ..Default::default()
2805        };
2806        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2807            session.peer.send(connection_id, update.clone())?;
2808        }
2809    }
2810
2811    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2812        if !role.can_see_channel(channel.visibility) {
2813            continue;
2814        }
2815
2816        let update = proto::UpdateChannels {
2817            channels: vec![channel.to_proto()],
2818            ..Default::default()
2819        };
2820        session.peer.send(connection_id, update.clone())?;
2821    }
2822
2823    Ok(())
2824}
2825
2826/// Delete a channel
2827async fn delete_channel(
2828    request: proto::DeleteChannel,
2829    response: Response<proto::DeleteChannel>,
2830    session: MessageContext,
2831) -> Result<()> {
2832    let db = session.db().await;
2833
2834    let channel_id = request.channel_id;
2835    let (root_channel, removed_channels) = db
2836        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2837        .await?;
2838    response.send(proto::Ack {})?;
2839
2840    // Notify members of removed channels
2841    let mut update = proto::UpdateChannels::default();
2842    update
2843        .delete_channels
2844        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2845
2846    let connection_pool = session.connection_pool().await;
2847    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2848        session.peer.send(connection_id, update.clone())?;
2849    }
2850
2851    Ok(())
2852}
2853
2854/// Invite someone to join a channel.
2855async fn invite_channel_member(
2856    request: proto::InviteChannelMember,
2857    response: Response<proto::InviteChannelMember>,
2858    session: MessageContext,
2859) -> Result<()> {
2860    let db = session.db().await;
2861    let channel_id = ChannelId::from_proto(request.channel_id);
2862    let invitee_id = UserId::from_proto(request.user_id);
2863    let InviteMemberResult {
2864        channel,
2865        notifications,
2866    } = db
2867        .invite_channel_member(
2868            channel_id,
2869            invitee_id,
2870            session.user_id(),
2871            request.role().into(),
2872        )
2873        .await?;
2874
2875    let update = proto::UpdateChannels {
2876        channel_invitations: vec![channel.to_proto()],
2877        ..Default::default()
2878    };
2879
2880    let connection_pool = session.connection_pool().await;
2881    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2882        session.peer.send(connection_id, update.clone())?;
2883    }
2884
2885    send_notifications(&connection_pool, &session.peer, notifications);
2886
2887    response.send(proto::Ack {})?;
2888    Ok(())
2889}
2890
2891/// remove someone from a channel
2892async fn remove_channel_member(
2893    request: proto::RemoveChannelMember,
2894    response: Response<proto::RemoveChannelMember>,
2895    session: MessageContext,
2896) -> Result<()> {
2897    let db = session.db().await;
2898    let channel_id = ChannelId::from_proto(request.channel_id);
2899    let member_id = UserId::from_proto(request.user_id);
2900
2901    let RemoveChannelMemberResult {
2902        membership_update,
2903        notification_id,
2904    } = db
2905        .remove_channel_member(channel_id, member_id, session.user_id())
2906        .await?;
2907
2908    let mut connection_pool = session.connection_pool().await;
2909    notify_membership_updated(
2910        &mut connection_pool,
2911        membership_update,
2912        member_id,
2913        &session.peer,
2914    );
2915    for connection_id in connection_pool.user_connection_ids(member_id) {
2916        if let Some(notification_id) = notification_id {
2917            session
2918                .peer
2919                .send(
2920                    connection_id,
2921                    proto::DeleteNotification {
2922                        notification_id: notification_id.to_proto(),
2923                    },
2924                )
2925                .trace_err();
2926        }
2927    }
2928
2929    response.send(proto::Ack {})?;
2930    Ok(())
2931}
2932
2933/// Toggle the channel between public and private.
2934/// Care is taken to maintain the invariant that public channels only descend from public channels,
2935/// (though members-only channels can appear at any point in the hierarchy).
2936async fn set_channel_visibility(
2937    request: proto::SetChannelVisibility,
2938    response: Response<proto::SetChannelVisibility>,
2939    session: MessageContext,
2940) -> Result<()> {
2941    let db = session.db().await;
2942    let channel_id = ChannelId::from_proto(request.channel_id);
2943    let visibility = request.visibility().into();
2944
2945    let channel_model = db
2946        .set_channel_visibility(channel_id, visibility, session.user_id())
2947        .await?;
2948    let root_id = channel_model.root_id();
2949    let channel = Channel::from_model(channel_model);
2950
2951    let mut connection_pool = session.connection_pool().await;
2952    for (user_id, role) in connection_pool
2953        .channel_user_ids(root_id)
2954        .collect::<Vec<_>>()
2955        .into_iter()
2956    {
2957        let update = if role.can_see_channel(channel.visibility) {
2958            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2959            proto::UpdateChannels {
2960                channels: vec![channel.to_proto()],
2961                ..Default::default()
2962            }
2963        } else {
2964            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2965            proto::UpdateChannels {
2966                delete_channels: vec![channel.id.to_proto()],
2967                ..Default::default()
2968            }
2969        };
2970
2971        for connection_id in connection_pool.user_connection_ids(user_id) {
2972            session.peer.send(connection_id, update.clone())?;
2973        }
2974    }
2975
2976    response.send(proto::Ack {})?;
2977    Ok(())
2978}
2979
2980/// Alter the role for a user in the channel.
2981async fn set_channel_member_role(
2982    request: proto::SetChannelMemberRole,
2983    response: Response<proto::SetChannelMemberRole>,
2984    session: MessageContext,
2985) -> Result<()> {
2986    let db = session.db().await;
2987    let channel_id = ChannelId::from_proto(request.channel_id);
2988    let member_id = UserId::from_proto(request.user_id);
2989    let result = db
2990        .set_channel_member_role(
2991            channel_id,
2992            session.user_id(),
2993            member_id,
2994            request.role().into(),
2995        )
2996        .await?;
2997
2998    match result {
2999        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3000            let mut connection_pool = session.connection_pool().await;
3001            notify_membership_updated(
3002                &mut connection_pool,
3003                membership_update,
3004                member_id,
3005                &session.peer,
3006            )
3007        }
3008        db::SetMemberRoleResult::InviteUpdated(channel) => {
3009            let update = proto::UpdateChannels {
3010                channel_invitations: vec![channel.to_proto()],
3011                ..Default::default()
3012            };
3013
3014            for connection_id in session
3015                .connection_pool()
3016                .await
3017                .user_connection_ids(member_id)
3018            {
3019                session.peer.send(connection_id, update.clone())?;
3020            }
3021        }
3022    }
3023
3024    response.send(proto::Ack {})?;
3025    Ok(())
3026}
3027
3028/// Change the name of a channel
3029async fn rename_channel(
3030    request: proto::RenameChannel,
3031    response: Response<proto::RenameChannel>,
3032    session: MessageContext,
3033) -> Result<()> {
3034    let db = session.db().await;
3035    let channel_id = ChannelId::from_proto(request.channel_id);
3036    let channel_model = db
3037        .rename_channel(channel_id, session.user_id(), &request.name)
3038        .await?;
3039    let root_id = channel_model.root_id();
3040    let channel = Channel::from_model(channel_model);
3041
3042    response.send(proto::RenameChannelResponse {
3043        channel: Some(channel.to_proto()),
3044    })?;
3045
3046    let connection_pool = session.connection_pool().await;
3047    let update = proto::UpdateChannels {
3048        channels: vec![channel.to_proto()],
3049        ..Default::default()
3050    };
3051    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3052        if role.can_see_channel(channel.visibility) {
3053            session.peer.send(connection_id, update.clone())?;
3054        }
3055    }
3056
3057    Ok(())
3058}
3059
3060/// Move a channel to a new parent.
3061async fn move_channel(
3062    request: proto::MoveChannel,
3063    response: Response<proto::MoveChannel>,
3064    session: MessageContext,
3065) -> Result<()> {
3066    let channel_id = ChannelId::from_proto(request.channel_id);
3067    let to = ChannelId::from_proto(request.to);
3068
3069    let (root_id, channels) = session
3070        .db()
3071        .await
3072        .move_channel(channel_id, to, session.user_id())
3073        .await?;
3074
3075    let connection_pool = session.connection_pool().await;
3076    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3077        let channels = channels
3078            .iter()
3079            .filter_map(|channel| {
3080                if role.can_see_channel(channel.visibility) {
3081                    Some(channel.to_proto())
3082                } else {
3083                    None
3084                }
3085            })
3086            .collect::<Vec<_>>();
3087        if channels.is_empty() {
3088            continue;
3089        }
3090
3091        let update = proto::UpdateChannels {
3092            channels,
3093            ..Default::default()
3094        };
3095
3096        session.peer.send(connection_id, update.clone())?;
3097    }
3098
3099    response.send(Ack {})?;
3100    Ok(())
3101}
3102
3103async fn reorder_channel(
3104    request: proto::ReorderChannel,
3105    response: Response<proto::ReorderChannel>,
3106    session: MessageContext,
3107) -> Result<()> {
3108    let channel_id = ChannelId::from_proto(request.channel_id);
3109    let direction = request.direction();
3110
3111    let updated_channels = session
3112        .db()
3113        .await
3114        .reorder_channel(channel_id, direction, session.user_id())
3115        .await?;
3116
3117    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3118        let connection_pool = session.connection_pool().await;
3119        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3120            let channels = updated_channels
3121                .iter()
3122                .filter_map(|channel| {
3123                    if role.can_see_channel(channel.visibility) {
3124                        Some(channel.to_proto())
3125                    } else {
3126                        None
3127                    }
3128                })
3129                .collect::<Vec<_>>();
3130
3131            if channels.is_empty() {
3132                continue;
3133            }
3134
3135            let update = proto::UpdateChannels {
3136                channels,
3137                ..Default::default()
3138            };
3139
3140            session.peer.send(connection_id, update.clone())?;
3141        }
3142    }
3143
3144    response.send(Ack {})?;
3145    Ok(())
3146}
3147
3148/// Get the list of channel members
3149async fn get_channel_members(
3150    request: proto::GetChannelMembers,
3151    response: Response<proto::GetChannelMembers>,
3152    session: MessageContext,
3153) -> Result<()> {
3154    let db = session.db().await;
3155    let channel_id = ChannelId::from_proto(request.channel_id);
3156    let limit = if request.limit == 0 {
3157        u16::MAX as u64
3158    } else {
3159        request.limit
3160    };
3161    let (members, users) = db
3162        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3163        .await?;
3164    response.send(proto::GetChannelMembersResponse { members, users })?;
3165    Ok(())
3166}
3167
3168/// Accept or decline a channel invitation.
3169async fn respond_to_channel_invite(
3170    request: proto::RespondToChannelInvite,
3171    response: Response<proto::RespondToChannelInvite>,
3172    session: MessageContext,
3173) -> Result<()> {
3174    let db = session.db().await;
3175    let channel_id = ChannelId::from_proto(request.channel_id);
3176    let RespondToChannelInvite {
3177        membership_update,
3178        notifications,
3179    } = db
3180        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3181        .await?;
3182
3183    let mut connection_pool = session.connection_pool().await;
3184    if let Some(membership_update) = membership_update {
3185        notify_membership_updated(
3186            &mut connection_pool,
3187            membership_update,
3188            session.user_id(),
3189            &session.peer,
3190        );
3191    } else {
3192        let update = proto::UpdateChannels {
3193            remove_channel_invitations: vec![channel_id.to_proto()],
3194            ..Default::default()
3195        };
3196
3197        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3198            session.peer.send(connection_id, update.clone())?;
3199        }
3200    };
3201
3202    send_notifications(&connection_pool, &session.peer, notifications);
3203
3204    response.send(proto::Ack {})?;
3205
3206    Ok(())
3207}
3208
3209/// Join the channels' room
3210async fn join_channel(
3211    request: proto::JoinChannel,
3212    response: Response<proto::JoinChannel>,
3213    session: MessageContext,
3214) -> Result<()> {
3215    let channel_id = ChannelId::from_proto(request.channel_id);
3216    join_channel_internal(channel_id, Box::new(response), session).await
3217}
3218
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and `JoinRoom` response types, so
/// `join_channel_internal` can reply to either kind of request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path picks `Response`'s inherent `send`, which
        // takes precedence over this trait method, so this delegates rather
        // than recursing.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3232
/// Join the room of `channel_id` on behalf of the session's user.
///
/// Shared join path: `response` may be either a `JoinChannel` or a `JoinRoom`
/// response handle (see `JoinChannelInternalResponse`). The steps are:
/// 1. Clean up a stale room connection of the same user, if the database reports one.
/// 2. Join through the database.
/// 3. Build LiveKit credentials when a LiveKit client is configured.
/// 4. Reply with the joined room.
/// 5. Notify membership changes, the room, and the channel.
/// 6. Refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The database and connection-pool guards taken in this block are dropped
    // at its end; the connection pool is locked again for `channel_updated`
    // below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard before `leave_room_for_session`, which takes
            // the whole session (presumably it locks the db itself — TODO
            // confirm), then re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a
        // room token. A token error is traced and the `?` inside the closure
        // turns it into `None`, so the reply then carries no LiveKit info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Only present when the join changed the user's channel membership.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Fails with "channel not returned" if the database result has no channel.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3326
3327/// Start editing the channel notes
3328async fn join_channel_buffer(
3329    request: proto::JoinChannelBuffer,
3330    response: Response<proto::JoinChannelBuffer>,
3331    session: MessageContext,
3332) -> Result<()> {
3333    let db = session.db().await;
3334    let channel_id = ChannelId::from_proto(request.channel_id);
3335
3336    let open_response = db
3337        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3338        .await?;
3339
3340    let collaborators = open_response.collaborators.clone();
3341    response.send(open_response)?;
3342
3343    let update = UpdateChannelBufferCollaborators {
3344        channel_id: channel_id.to_proto(),
3345        collaborators: collaborators.clone(),
3346    };
3347    channel_buffer_updated(
3348        session.connection_id,
3349        collaborators
3350            .iter()
3351            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3352        &update,
3353        &session.peer,
3354    );
3355
3356    Ok(())
3357}
3358
3359/// Edit the channel notes
3360async fn update_channel_buffer(
3361    request: proto::UpdateChannelBuffer,
3362    session: MessageContext,
3363) -> Result<()> {
3364    let db = session.db().await;
3365    let channel_id = ChannelId::from_proto(request.channel_id);
3366
3367    let (collaborators, epoch, version) = db
3368        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3369        .await?;
3370
3371    channel_buffer_updated(
3372        session.connection_id,
3373        collaborators.clone(),
3374        &proto::UpdateChannelBuffer {
3375            channel_id: channel_id.to_proto(),
3376            operations: request.operations,
3377        },
3378        &session.peer,
3379    );
3380
3381    let pool = &*session.connection_pool().await;
3382
3383    let non_collaborators =
3384        pool.channel_connection_ids(channel_id)
3385            .filter_map(|(connection_id, _)| {
3386                if collaborators.contains(&connection_id) {
3387                    None
3388                } else {
3389                    Some(connection_id)
3390                }
3391            });
3392
3393    broadcast(None, non_collaborators, |peer_id| {
3394        session.peer.send(
3395            peer_id,
3396            proto::UpdateChannels {
3397                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3398                    channel_id: channel_id.to_proto(),
3399                    epoch: epoch as u64,
3400                    version: version.clone(),
3401                }],
3402                ..Default::default()
3403            },
3404        )
3405    });
3406
3407    Ok(())
3408}
3409
3410/// Rejoin the channel notes after a connection blip
3411async fn rejoin_channel_buffers(
3412    request: proto::RejoinChannelBuffers,
3413    response: Response<proto::RejoinChannelBuffers>,
3414    session: MessageContext,
3415) -> Result<()> {
3416    let db = session.db().await;
3417    let buffers = db
3418        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3419        .await?;
3420
3421    for rejoined_buffer in &buffers {
3422        let collaborators_to_notify = rejoined_buffer
3423            .buffer
3424            .collaborators
3425            .iter()
3426            .filter_map(|c| Some(c.peer_id?.into()));
3427        channel_buffer_updated(
3428            session.connection_id,
3429            collaborators_to_notify,
3430            &proto::UpdateChannelBufferCollaborators {
3431                channel_id: rejoined_buffer.buffer.channel_id,
3432                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3433            },
3434            &session.peer,
3435        );
3436    }
3437
3438    response.send(proto::RejoinChannelBuffersResponse {
3439        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3440    })?;
3441
3442    Ok(())
3443}
3444
3445/// Stop editing the channel notes
3446async fn leave_channel_buffer(
3447    request: proto::LeaveChannelBuffer,
3448    response: Response<proto::LeaveChannelBuffer>,
3449    session: MessageContext,
3450) -> Result<()> {
3451    let db = session.db().await;
3452    let channel_id = ChannelId::from_proto(request.channel_id);
3453
3454    let left_buffer = db
3455        .leave_channel_buffer(channel_id, session.connection_id)
3456        .await?;
3457
3458    response.send(Ack {})?;
3459
3460    channel_buffer_updated(
3461        session.connection_id,
3462        left_buffer.connections,
3463        &proto::UpdateChannelBufferCollaborators {
3464            channel_id: channel_id.to_proto(),
3465            collaborators: left_buffer.collaborators,
3466        },
3467        &session.peer,
3468    );
3469
3470    Ok(())
3471}
3472
3473fn channel_buffer_updated<T: EnvelopedMessage>(
3474    sender_id: ConnectionId,
3475    collaborators: impl IntoIterator<Item = ConnectionId>,
3476    message: &T,
3477    peer: &Peer,
3478) {
3479    broadcast(Some(sender_id), collaborators, |peer_id| {
3480        peer.send(peer_id, message.clone())
3481    });
3482}
3483
3484fn send_notifications(
3485    connection_pool: &ConnectionPool,
3486    peer: &Peer,
3487    notifications: db::NotificationBatch,
3488) {
3489    for (user_id, notification) in notifications {
3490        for connection_id in connection_pool.user_connection_ids(user_id) {
3491            if let Err(error) = peer.send(
3492                connection_id,
3493                proto::AddNotification {
3494                    notification: Some(notification.clone()),
3495                },
3496            ) {
3497                tracing::error!(
3498                    "failed to send notification to {:?} {}",
3499                    connection_id,
3500                    error
3501                );
3502            }
3503        }
3504    }
3505}
3506
3507/// Send a message to the channel
3508async fn send_channel_message(
3509    _request: proto::SendChannelMessage,
3510    _response: Response<proto::SendChannelMessage>,
3511    _session: MessageContext,
3512) -> Result<()> {
3513    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3514}
3515
3516/// Delete a channel message
3517async fn remove_channel_message(
3518    _request: proto::RemoveChannelMessage,
3519    _response: Response<proto::RemoveChannelMessage>,
3520    _session: MessageContext,
3521) -> Result<()> {
3522    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3523}
3524
3525async fn update_channel_message(
3526    _request: proto::UpdateChannelMessage,
3527    _response: Response<proto::UpdateChannelMessage>,
3528    _session: MessageContext,
3529) -> Result<()> {
3530    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3531}
3532
3533/// Mark a channel message as read
3534async fn acknowledge_channel_message(
3535    _request: proto::AckChannelMessage,
3536    _session: MessageContext,
3537) -> Result<()> {
3538    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3539}
3540
3541/// Mark a buffer version as synced
3542async fn acknowledge_buffer_version(
3543    request: proto::AckBufferOperation,
3544    session: MessageContext,
3545) -> Result<()> {
3546    let buffer_id = BufferId::from_proto(request.buffer_id);
3547    session
3548        .db()
3549        .await
3550        .observe_buffer_version(
3551            buffer_id,
3552            session.user_id(),
3553            request.epoch as i32,
3554            &request.version,
3555        )
3556        .await?;
3557    Ok(())
3558}
3559
3560/// Start receiving chat updates for a channel
3561async fn join_channel_chat(
3562    _request: proto::JoinChannelChat,
3563    _response: Response<proto::JoinChannelChat>,
3564    _session: MessageContext,
3565) -> Result<()> {
3566    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3567}
3568
3569/// Stop receiving chat updates for a channel
3570async fn leave_channel_chat(
3571    _request: proto::LeaveChannelChat,
3572    _session: MessageContext,
3573) -> Result<()> {
3574    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3575}
3576
3577/// Retrieve the chat history for a channel
3578async fn get_channel_messages(
3579    _request: proto::GetChannelMessages,
3580    _response: Response<proto::GetChannelMessages>,
3581    _session: MessageContext,
3582) -> Result<()> {
3583    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3584}
3585
3586/// Retrieve specific chat messages
3587async fn get_channel_messages_by_id(
3588    _request: proto::GetChannelMessagesById,
3589    _response: Response<proto::GetChannelMessagesById>,
3590    _session: MessageContext,
3591) -> Result<()> {
3592    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3593}
3594
3595/// Retrieve the current users notifications
3596async fn get_notifications(
3597    request: proto::GetNotifications,
3598    response: Response<proto::GetNotifications>,
3599    session: MessageContext,
3600) -> Result<()> {
3601    let notifications = session
3602        .db()
3603        .await
3604        .get_notifications(
3605            session.user_id(),
3606            NOTIFICATION_COUNT_PER_PAGE,
3607            request.before_id.map(db::NotificationId::from_proto),
3608        )
3609        .await?;
3610    response.send(proto::GetNotificationsResponse {
3611        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3612        notifications,
3613    })?;
3614    Ok(())
3615}
3616
3617/// Mark notifications as read
3618async fn mark_notification_as_read(
3619    request: proto::MarkNotificationRead,
3620    response: Response<proto::MarkNotificationRead>,
3621    session: MessageContext,
3622) -> Result<()> {
3623    let database = &session.db().await;
3624    let notifications = database
3625        .mark_notification_as_read_by_id(
3626            session.user_id(),
3627            NotificationId::from_proto(request.notification_id),
3628        )
3629        .await?;
3630    send_notifications(
3631        &*session.connection_pool().await,
3632        &session.peer,
3633        notifications,
3634    );
3635    response.send(proto::Ack {})?;
3636    Ok(())
3637}
3638
3639fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3640    let message = match message {
3641        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3642        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3643        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3644        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3645        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3646            code: frame.code.into(),
3647            reason: frame.reason.as_str().to_owned().into(),
3648        })),
3649        // We should never receive a frame while reading the message, according
3650        // to the `tungstenite` maintainers:
3651        //
3652        // > It cannot occur when you read messages from the WebSocket, but it
3653        // > can be used when you want to send the raw frames (e.g. you want to
3654        // > send the frames to the WebSocket without composing the full message first).
3655        // >
3656        // > — https://github.com/snapview/tungstenite-rs/issues/268
3657        TungsteniteMessage::Frame(_) => {
3658            bail!("received an unexpected frame while reading the message")
3659        }
3660    };
3661
3662    Ok(message)
3663}
3664
3665fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3666    match message {
3667        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3668        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3669        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3670        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3671        AxumMessage::Close(frame) => {
3672            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3673                code: frame.code.into(),
3674                reason: frame.reason.as_ref().into(),
3675            }))
3676        }
3677    }
3678}
3679
3680fn notify_membership_updated(
3681    connection_pool: &mut ConnectionPool,
3682    result: MembershipUpdated,
3683    user_id: UserId,
3684    peer: &Peer,
3685) {
3686    for membership in &result.new_channels.channel_memberships {
3687        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3688    }
3689    for channel_id in &result.removed_channels {
3690        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3691    }
3692
3693    let user_channels_update = proto::UpdateUserChannels {
3694        channel_memberships: result
3695            .new_channels
3696            .channel_memberships
3697            .iter()
3698            .map(|cm| proto::ChannelMembership {
3699                channel_id: cm.channel_id.to_proto(),
3700                role: cm.role.into(),
3701            })
3702            .collect(),
3703        ..Default::default()
3704    };
3705
3706    let mut update = build_channels_update(result.new_channels);
3707    update.delete_channels = result
3708        .removed_channels
3709        .into_iter()
3710        .map(|id| id.to_proto())
3711        .collect();
3712    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3713
3714    for connection_id in connection_pool.user_connection_ids(user_id) {
3715        peer.send(connection_id, user_channels_update.clone())
3716            .trace_err();
3717        peer.send(connection_id, update.clone()).trace_err();
3718    }
3719}
3720
3721fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3722    proto::UpdateUserChannels {
3723        channel_memberships: channels
3724            .channel_memberships
3725            .iter()
3726            .map(|m| proto::ChannelMembership {
3727                channel_id: m.channel_id.to_proto(),
3728                role: m.role.into(),
3729            })
3730            .collect(),
3731        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3732    }
3733}
3734
3735fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3736    let mut update = proto::UpdateChannels::default();
3737
3738    for channel in channels.channels {
3739        update.channels.push(channel.to_proto());
3740    }
3741
3742    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3743
3744    for (channel_id, participants) in channels.channel_participants {
3745        update
3746            .channel_participants
3747            .push(proto::ChannelParticipants {
3748                channel_id: channel_id.to_proto(),
3749                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3750            });
3751    }
3752
3753    for channel in channels.invited_channels {
3754        update.channel_invitations.push(channel.to_proto());
3755    }
3756
3757    update
3758}
3759
3760fn build_initial_contacts_update(
3761    contacts: Vec<db::Contact>,
3762    pool: &ConnectionPool,
3763) -> proto::UpdateContacts {
3764    let mut update = proto::UpdateContacts::default();
3765
3766    for contact in contacts {
3767        match contact {
3768            db::Contact::Accepted { user_id, busy } => {
3769                update.contacts.push(contact_for_user(user_id, busy, pool));
3770            }
3771            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3772            db::Contact::Incoming { user_id } => {
3773                update
3774                    .incoming_requests
3775                    .push(proto::IncomingContactRequest {
3776                        requester_id: user_id.to_proto(),
3777                    })
3778            }
3779        }
3780    }
3781
3782    update
3783}
3784
3785fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3786    proto::Contact {
3787        user_id: user_id.to_proto(),
3788        online: pool.is_user_online(user_id),
3789        busy,
3790    }
3791}
3792
3793fn room_updated(room: &proto::Room, peer: &Peer) {
3794    broadcast(
3795        None,
3796        room.participants
3797            .iter()
3798            .filter_map(|participant| Some(participant.peer_id?.into())),
3799        |peer_id| {
3800            peer.send(
3801                peer_id,
3802                proto::RoomUpdated {
3803                    room: Some(room.clone()),
3804                },
3805            )
3806        },
3807    );
3808}
3809
3810fn channel_updated(
3811    channel: &db::channel::Model,
3812    room: &proto::Room,
3813    peer: &Peer,
3814    pool: &ConnectionPool,
3815) {
3816    let participants = room
3817        .participants
3818        .iter()
3819        .map(|p| p.user_id)
3820        .collect::<Vec<_>>();
3821
3822    broadcast(
3823        None,
3824        pool.channel_connection_ids(channel.root_id())
3825            .filter_map(|(channel_id, role)| {
3826                role.can_see_channel(channel.visibility)
3827                    .then_some(channel_id)
3828            }),
3829        |peer_id| {
3830            peer.send(
3831                peer_id,
3832                proto::UpdateChannels {
3833                    channel_participants: vec![proto::ChannelParticipants {
3834                        channel_id: channel.id.to_proto(),
3835                        participant_user_ids: participants.clone(),
3836                    }],
3837                    ..Default::default()
3838                },
3839            )
3840        },
3841    );
3842}
3843
3844async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3845    let db = session.db().await;
3846
3847    let contacts = db.get_contacts(user_id).await?;
3848    let busy = db.is_user_busy(user_id).await?;
3849
3850    let pool = session.connection_pool().await;
3851    let updated_contact = contact_for_user(user_id, busy, &pool);
3852    for contact in contacts {
3853        if let db::Contact::Accepted {
3854            user_id: contact_user_id,
3855            ..
3856        } = contact
3857        {
3858            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3859                session
3860                    .peer
3861                    .send(
3862                        contact_conn_id,
3863                        proto::UpdateContacts {
3864                            contacts: vec![updated_contact.clone()],
3865                            remove_contacts: Default::default(),
3866                            incoming_requests: Default::default(),
3867                            remove_incoming_requests: Default::default(),
3868                            outgoing_requests: Default::default(),
3869                            remove_outgoing_requests: Default::default(),
3870                        },
3871                    )
3872                    .trace_err();
3873            }
3874        }
3875    }
3876    Ok(())
3877}
3878
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// If `leave_room` returns `None`, nothing is done and `Ok(())` is returned. Otherwise this:
/// - notifies the collaborators of every project the connection left (`project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - refreshes channel participants if the room belongs to a channel,
/// - sends `CallCanceled` to users whose pending calls into the room were canceled,
/// - refreshes the contact status of this user and of those users,
/// - removes this user from the LiveKit room, and deletes that room if the database reported
///   the room as deleted. LiveKit errors are logged, not propagated.
///
/// `connection_id` is not necessarily `session.connection_id`; for example, it may be a
/// stale connection of the same user that is being cleaned up.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and initialized inside the `if let` below. The temporary returned
    // by `session.db().await` in the `if let` scrutinee lives until that `if let` ends, so
    // it is released before the code further down calls `session.db()` again (e.g. in
    // `update_user_contacts`). NOTE(review): presumably this matters because `db()` hands
    // out a lock guard — confirm before restructuring into `let ... else`.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // NOTE(review): `project_left` reports `session.connection_id` as the removed
        // collaborator, which may differ from `connection_id` during stale cleanup.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the broadcast `room` carries the default
        // (empty) `livekit_room` value.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Channel rooms: tell channel members about the new participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before `update_user_contacts` below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are only logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3952
3953async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3954    let left_channel_buffers = session
3955        .db()
3956        .await
3957        .leave_channel_buffers(session.connection_id)
3958        .await?;
3959
3960    for left_buffer in left_channel_buffers {
3961        channel_buffer_updated(
3962            session.connection_id,
3963            left_buffer.connections,
3964            &proto::UpdateChannelBufferCollaborators {
3965                channel_id: left_buffer.channel_id.to_proto(),
3966                collaborators: left_buffer.collaborators,
3967            },
3968            &session.peer,
3969        );
3970    }
3971
3972    Ok(())
3973}
3974
3975fn project_left(project: &db::LeftProject, session: &Session) {
3976    for connection_id in &project.connection_ids {
3977        if project.should_unshare {
3978            session
3979                .peer
3980                .send(
3981                    *connection_id,
3982                    proto::UnshareProject {
3983                        project_id: project.id.to_proto(),
3984                    },
3985                )
3986                .trace_err();
3987        } else {
3988            session
3989                .peer
3990                .send(
3991                    *connection_id,
3992                    proto::RemoveProjectCollaborator {
3993                        project_id: project.id.to_proto(),
3994                        peer_id: Some(session.connection_id.into()),
3995                    },
3996                )
3997                .trace_err();
3998        }
3999    }
4000}
4001
4002async fn share_agent_thread(
4003    request: proto::ShareAgentThread,
4004    response: Response<proto::ShareAgentThread>,
4005    session: MessageContext,
4006) -> Result<()> {
4007    let user_id = session.user_id();
4008
4009    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4010        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4011
4012    session
4013        .db()
4014        .await
4015        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4016        .await?;
4017
4018    response.send(proto::Ack {})?;
4019
4020    Ok(())
4021}
4022
4023async fn get_shared_agent_thread(
4024    request: proto::GetSharedAgentThread,
4025    response: Response<proto::GetSharedAgentThread>,
4026    session: MessageContext,
4027) -> Result<()> {
4028    let share_id = SharedThreadId::from_proto(request.session_id)
4029        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4030
4031    let result = session.db().await.get_shared_thread(share_id).await?;
4032
4033    match result {
4034        Some((thread, username)) => {
4035            response.send(proto::GetSharedAgentThreadResponse {
4036                title: thread.title,
4037                thread_data: thread.data,
4038                sharer_username: username,
4039                created_at: thread.created_at.and_utc().to_rfc3339(),
4040            })?;
4041        }
4042        None => {
4043            return Err(anyhow!("Shared thread not found").into());
4044        }
4045    }
4046
4047    Ok(())
4048}
4049
4050pub trait ResultExt {
4051    type Ok;
4052
4053    fn trace_err(self) -> Option<Self::Ok>;
4054}
4055
4056impl<T, E> ResultExt for Result<T, E>
4057where
4058    E: std::fmt::Debug,
4059{
4060    type Ok = T;
4061
4062    #[track_caller]
4063    fn trace_err(self) -> Option<T> {
4064        match self {
4065            Ok(value) => Some(value),
4066            Err(error) => {
4067                tracing::error!("{:?}", error);
4068                None
4069            }
4070        }
4071    }
4072}