rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
/// Reconnect grace period. NOTE(review): its use site is not visible in this
/// section; presumably how long a disconnected client may take to reconnect
/// before its state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources left behind by previous
/// server instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
/// gone, we can clean up their old resources, so we wait somewhat longer than that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for notifications. NOTE(review): use site not visible here;
/// presumably the number of notifications returned per request — confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of `ConnectionGuard`s that may be held at the same time.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of `ConnectionGuard`s currently held. It is incremented by
/// `ConnectionGuard::try_acquire` and decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of span fields recorded per message. All values are milliseconds
// (computed from microseconds / 1000).
/// Time from the envelope's `received_at` until its handler finished.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
/// Time from invoking the handler until its future completed.
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
/// `TOTAL_DURATION_MS` minus `PROCESSING_DURATION_MS`.
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
/// Time spent waiting for a request forwarded to another connection
/// (see `MessageContext::forward_request`).
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers` (keyed by message
/// `TypeId`). It receives the envelope, the sender's session and the message
/// span, and returns a boxed future that drives the handler to completion.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
/// The collab RPC server. It holds the peer, the connection pool and the
/// per-message-type handlers registered in `Server::new`.
pub struct Server {
    /// Id of this server instance. It sits behind a mutex so the test-only
    /// `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of connections known to this server (reset by `teardown`).
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `true` is sent by `teardown`, `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 262
/// Lock guard over the shared `ConnectionPool`, returned by `Session::connection_pool`.
///
/// `PhantomData<Rc<()>>` makes this type `!Send`. Presumably this keeps the
/// guard from being held across an `.await` inside a `Send` future — confirm.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 416            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 417            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 419            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 437            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 438            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 439            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 440            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 441            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 442            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 443            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 444            .add_request_handler(share_agent_thread)
 445            .add_request_handler(get_shared_agent_thread)
 446            .add_request_handler(forward_project_search_chunk);
 447
 448        Arc::new(server)
 449    }
 450
    /// Schedules a one-off background cleanup of resources left behind by
    /// previous server instances, then returns immediately with `Ok(())`.
    ///
    /// The spawned task runs inside the `start server` span. After
    /// `CLEANUP_TIMEOUT` elapses, it:
    /// 1. deletes stale channel-chat participants;
    /// 2. clears stale collaborators from channel buffers and sends the refreshed
    ///    collaborator lists to the remaining connections;
    /// 3. clears stale participants from each stale room, then broadcasts room
    ///    and channel updates, `CallCanceled` messages and contact updates. It
    ///    deletes the LiveKit room when no participants are left;
    /// 4. repeats the chat-participant cleanup, clears old worktree entries and
    ///    deletes stale server records.
    ///
    /// Failures go through `trace_err` instead of being propagated, so one
    /// failing step does not prevent the later ones from running.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The timer future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        // Detached: `start` does not wait for the cleanup to finish.
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                // Ids of rooms and channel buffers with stale state (presumably
                // entries owned by servers other than `server_id`).
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every remaining connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below. They keep these
                        // empty defaults if clearing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled get their contact entry re-sent below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each affected callee that the call
                        // was canceled. The block is scoped so the pool lock is released
                        // before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, re-send their contact entry (built from
                        // their busy status and the pool) to every connection of each of
                        // their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // Final pass, run after all stale rooms have been processed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 621
 622    pub fn teardown(&self) {
 623        self.peer.teardown();
 624        self.connection_pool.lock().reset();
 625        let _ = self.teardown.send(true);
 626    }
 627
 628    #[cfg(feature = "test-support")]
 629    pub fn reset(&self, id: ServerId) {
 630        self.teardown();
 631        *self.id.lock() = id;
 632        self.peer.reset(id.0 as u32);
 633        let _ = self.teardown.send(false);
 634    }
 635
 636    #[cfg(feature = "test-support")]
 637    pub fn id(&self) -> ServerId {
 638        *self.id.lock()
 639    }
 640
 641    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 642    where
 643        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 644        Fut: 'static + Send + Future<Output = Result<()>>,
 645        M: EnvelopedMessage,
 646    {
 647        let prev_handler = self.handlers.insert(
 648            TypeId::of::<M>(),
 649            Box::new(move |envelope, session, span| {
 650                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 651                let received_at = envelope.received_at;
 652                tracing::info!("message received");
 653                let start_time = Instant::now();
 654                let future = (handler)(
 655                    *envelope,
 656                    MessageContext {
 657                        session,
 658                        span: span.clone(),
 659                    },
 660                );
 661                async move {
 662                    let result = future.await;
 663                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 664                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 665                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 666                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 667                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 668                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 669                    match result {
 670                        Err(error) => {
 671                            tracing::error!(?error, "error handling message")
 672                        }
 673                        Ok(()) => tracing::info!("finished handling message"),
 674                    }
 675                }
 676                .boxed()
 677            }),
 678        );
 679        if prev_handler.is_some() {
 680            panic!("registered a handler for the same message twice");
 681        }
 682        self
 683    }
 684
 685    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 686    where
 687        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 688        Fut: 'static + Send + Future<Output = Result<()>>,
 689        M: EnvelopedMessage,
 690    {
 691        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 692        self
 693    }
 694
 695    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 696    where
 697        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 698        Fut: Send + Future<Output = Result<()>>,
 699        M: RequestMessage,
 700    {
 701        let handler = Arc::new(handler);
 702        self.add_handler(move |envelope, session| {
 703            let receipt = envelope.receipt();
 704            let handler = handler.clone();
 705            async move {
 706                let peer = session.peer.clone();
 707                let responded = Arc::new(AtomicBool::default());
 708                let response = Response {
 709                    peer: peer.clone(),
 710                    responded: responded.clone(),
 711                    receipt,
 712                };
 713                match (handler)(envelope.payload, response, session).await {
 714                    Ok(()) => {
 715                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 716                            Ok(())
 717                        } else {
 718                            Err(anyhow!("handler did not send a response"))?
 719                        }
 720                    }
 721                    Err(error) => {
 722                        let proto_err = match &error {
 723                            Error::Internal(err) => err.to_proto(),
 724                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 725                        };
 726                        peer.respond_with_error(receipt, proto_err)?;
 727                        Err(error)
 728                    }
 729                }
 730            }
 731        })
 732    }
 733
 734    pub fn handle_connection(
 735        self: &Arc<Self>,
 736        connection: Connection,
 737        address: String,
 738        principal: Principal,
 739        zed_version: ZedVersion,
 740        release_channel: Option<String>,
 741        user_agent: Option<String>,
 742        geoip_country_code: Option<String>,
 743        system_id: Option<String>,
 744        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 745        executor: Executor,
 746        connection_guard: Option<ConnectionGuard>,
 747    ) -> impl Future<Output = ()> + use<> {
 748        let this = self.clone();
 749        let span = info_span!("handle connection", %address,
 750            connection_id=field::Empty,
 751            user_id=field::Empty,
 752            login=field::Empty,
 753            user_agent=field::Empty,
 754            geoip_country_code=field::Empty,
 755            release_channel=field::Empty,
 756        );
 757        principal.update_span(&span);
 758        if let Some(user_agent) = user_agent {
 759            span.record("user_agent", user_agent);
 760        }
 761        if let Some(release_channel) = release_channel {
 762            span.record("release_channel", release_channel);
 763        }
 764
 765        if let Some(country_code) = geoip_country_code.as_ref() {
 766            span.record("geoip_country_code", country_code);
 767        }
 768
 769        let mut teardown = self.teardown.subscribe();
 770        async move {
 771            if *teardown.borrow() {
 772                tracing::error!("server is tearing down");
 773                return;
 774            }
 775
 776            let (connection_id, handle_io, mut incoming_rx) =
 777                this.peer.add_connection(connection, {
 778                    let executor = executor.clone();
 779                    move |duration| executor.sleep(duration)
 780                });
 781            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 782
 783            tracing::info!("connection opened");
 784
 785            let session = Session {
 786                principal: principal.clone(),
 787                connection_id,
 788                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 789                peer: this.peer.clone(),
 790                connection_pool: this.connection_pool.clone(),
 791                app_state: this.app_state.clone(),
 792                geoip_country_code,
 793                system_id,
 794                _executor: executor.clone(),
 795            };
 796
 797            if let Err(error) = this
 798                .send_initial_client_update(
 799                    connection_id,
 800                    zed_version,
 801                    send_connection_id,
 802                    &session,
 803                )
 804                .await
 805            {
 806                tracing::error!(?error, "failed to send initial client update");
 807                return;
 808            }
 809            drop(connection_guard);
 810
 811            let handle_io = handle_io.fuse();
 812            futures::pin_mut!(handle_io);
 813
 814            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 815            // This prevents deadlocks when e.g., client A performs a request to client B and
 816            // client B performs a request to client A. If both clients stop processing further
 817            // messages until their respective request completes, they won't have a chance to
 818            // respond to the other client's request and cause a deadlock.
 819            //
 820            // This arrangement ensures we will attempt to process earlier messages first, but fall
 821            // back to processing messages arrived later in the spirit of making progress.
 822            const MAX_CONCURRENT_HANDLERS: usize = 256;
 823            let mut foreground_message_handlers = FuturesUnordered::new();
 824            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 825            let get_concurrent_handlers = {
 826                let concurrent_handlers = concurrent_handlers.clone();
 827                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 828            };
 829            loop {
 830                let next_message = async {
 831                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 832                    let message = incoming_rx.next().await;
 833                    // Cache the concurrent_handlers here, so that we know what the
 834                    // queue looks like as each handler starts
 835                    (permit, message, get_concurrent_handlers())
 836                }
 837                .fuse();
 838                futures::pin_mut!(next_message);
 839                futures::select_biased! {
 840                    _ = teardown.changed().fuse() => return,
 841                    result = handle_io => {
 842                        if let Err(error) = result {
 843                            tracing::error!(?error, "error handling I/O");
 844                        }
 845                        break;
 846                    }
 847                    _ = foreground_message_handlers.next() => {}
 848                    next_message = next_message => {
 849                        let (permit, message, concurrent_handlers) = next_message;
 850                        if let Some(message) = message {
 851                            let type_name = message.payload_type_name();
 852                            // note: we copy all the fields from the parent span so we can query them in the logs.
 853                            // (https://github.com/tokio-rs/tracing/issues/2670).
 854                            let span = tracing::info_span!("receive message",
 855                                %connection_id,
 856                                %address,
 857                                type_name,
 858                                concurrent_handlers,
 859                                user_id=field::Empty,
 860                                login=field::Empty,
 861                                lsp_query_request=field::Empty,
 862                                release_channel=field::Empty,
 863                                { TOTAL_DURATION_MS }=field::Empty,
 864                                { PROCESSING_DURATION_MS }=field::Empty,
 865                                { QUEUE_DURATION_MS }=field::Empty,
 866                                { HOST_WAITING_MS }=field::Empty
 867                            );
 868                            principal.update_span(&span);
 869                            let span_enter = span.enter();
 870                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 871                                let is_background = message.is_background();
 872                                let handle_message = (handler)(message, session.clone(), span.clone());
 873                                drop(span_enter);
 874
 875                                let handle_message = async move {
 876                                    handle_message.await;
 877                                    drop(permit);
 878                                }.instrument(span);
 879                                if is_background {
 880                                    executor.spawn_detached(handle_message);
 881                                } else {
 882                                    foreground_message_handlers.push(handle_message);
 883                                }
 884                            } else {
 885                                tracing::error!("no message handler");
 886                            }
 887                        } else {
 888                            tracing::info!("connection closed");
 889                            break;
 890                        }
 891                    }
 892                }
 893            }
 894
 895            drop(foreground_message_handlers);
 896            let concurrent_handlers = get_concurrent_handlers();
 897            tracing::info!(concurrent_handlers, "signing out");
 898            if let Err(error) = connection_lost(session, teardown, executor).await {
 899                tracing::error!(?error, "error signing out");
 900            }
 901        }
 902        .instrument(span)
 903    }
 904
 905    async fn send_initial_client_update(
 906        &self,
 907        connection_id: ConnectionId,
 908        zed_version: ZedVersion,
 909        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 910        session: &Session,
 911    ) -> Result<()> {
 912        self.peer.send(
 913            connection_id,
 914            proto::Hello {
 915                peer_id: Some(connection_id.into()),
 916            },
 917        )?;
 918        tracing::info!("sent hello message");
 919        if let Some(send_connection_id) = send_connection_id.take() {
 920            let _ = send_connection_id.send(connection_id);
 921        }
 922
 923        match &session.principal {
 924            Principal::User(user) => {
 925                if !user.connected_once {
 926                    self.peer.send(connection_id, proto::ShowContacts {})?;
 927                    self.app_state
 928                        .db
 929                        .set_user_connected_once(user.id, true)
 930                        .await?;
 931                }
 932
 933                let contacts = self.app_state.db.get_contacts(user.id).await?;
 934
 935                {
 936                    let mut pool = self.connection_pool.lock();
 937                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 938                    self.peer.send(
 939                        connection_id,
 940                        build_initial_contacts_update(contacts, &pool),
 941                    )?;
 942                }
 943
 944                if should_auto_subscribe_to_channels(&zed_version) {
 945                    subscribe_user_to_channels(user.id, session).await?;
 946                }
 947
 948                if let Some(incoming_call) =
 949                    self.app_state.db.incoming_call_for_user(user.id).await?
 950                {
 951                    self.peer.send(connection_id, incoming_call)?;
 952                }
 953
 954                update_user_contacts(user.id, session).await?;
 955            }
 956        }
 957
 958        Ok(())
 959    }
 960}
 961
 962impl Deref for ConnectionPoolGuard<'_> {
 963    type Target = ConnectionPool;
 964
 965    fn deref(&self) -> &Self::Target {
 966        &self.guard
 967    }
 968}
 969
 970impl DerefMut for ConnectionPoolGuard<'_> {
 971    fn deref_mut(&mut self) -> &mut Self::Target {
 972        &mut self.guard
 973    }
 974}
 975
 976impl Drop for ConnectionPoolGuard<'_> {
 977    fn drop(&mut self) {
 978        #[cfg(feature = "test-support")]
 979        self.check_invariants();
 980    }
 981}
 982
 983fn broadcast<F>(
 984    sender_id: Option<ConnectionId>,
 985    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 986    mut f: F,
 987) where
 988    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 989{
 990    for receiver_id in receiver_ids {
 991        if Some(receiver_id) != sender_id
 992            && let Err(error) = f(receiver_id)
 993        {
 994            tracing::error!("failed to send to {:?} {}", receiver_id, error);
 995        }
 996    }
 997}
 998
 999pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002    fn name() -> &'static HeaderName {
1003        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005    }
1006
1007    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008    where
1009        Self: Sized,
1010        I: Iterator<Item = &'i axum::http::HeaderValue>,
1011    {
1012        let version = values
1013            .next()
1014            .ok_or_else(axum::headers::Error::invalid)?
1015            .to_str()
1016            .map_err(|_| axum::headers::Error::invalid())?
1017            .parse()
1018            .map_err(|_| axum::headers::Error::invalid())?;
1019        Ok(Self(version))
1020    }
1021
1022    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023        values.extend([self.0.to_string().parse().unwrap()]);
1024    }
1025}
1026
1027pub struct AppVersionHeader(Version);
1028impl Header for AppVersionHeader {
1029    fn name() -> &'static HeaderName {
1030        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032    }
1033
1034    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035    where
1036        Self: Sized,
1037        I: Iterator<Item = &'i axum::http::HeaderValue>,
1038    {
1039        let version = values
1040            .next()
1041            .ok_or_else(axum::headers::Error::invalid)?
1042            .to_str()
1043            .map_err(|_| axum::headers::Error::invalid())?
1044            .parse()
1045            .map_err(|_| axum::headers::Error::invalid())?;
1046        Ok(Self(version))
1047    }
1048
1049    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050        values.extend([self.0.to_string().parse().unwrap()]);
1051    }
1052}
1053
1054#[derive(Debug)]
1055pub struct ReleaseChannelHeader(String);
1056
1057impl Header for ReleaseChannelHeader {
1058    fn name() -> &'static HeaderName {
1059        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1060        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1061    }
1062
1063    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064    where
1065        Self: Sized,
1066        I: Iterator<Item = &'i axum::http::HeaderValue>,
1067    {
1068        Ok(Self(
1069            values
1070                .next()
1071                .ok_or_else(axum::headers::Error::invalid)?
1072                .to_str()
1073                .map_err(|_| axum::headers::Error::invalid())?
1074                .to_owned(),
1075        ))
1076    }
1077
1078    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079        values.extend([self.0.parse().unwrap()]);
1080    }
1081}
1082
1083pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1084    Router::new()
1085        .route("/rpc", get(handle_websocket_request))
1086        .layer(
1087            ServiceBuilder::new()
1088                .layer(Extension(server.app_state.clone()))
1089                .layer(middleware::from_fn(auth::validate_header)),
1090        )
1091        .route("/metrics", get(handle_metrics))
1092        .layer(Extension(server))
1093}
1094
1095pub async fn handle_websocket_request(
1096    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1097    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1098    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1099    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1100    Extension(server): Extension<Arc<Server>>,
1101    Extension(principal): Extension<Principal>,
1102    user_agent: Option<TypedHeader<UserAgent>>,
1103    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1104    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1105    ws: WebSocketUpgrade,
1106) -> axum::response::Response {
1107    if protocol_version != rpc::PROTOCOL_VERSION {
1108        return (
1109            StatusCode::UPGRADE_REQUIRED,
1110            "client must be upgraded".to_string(),
1111        )
1112            .into_response();
1113    }
1114
1115    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1116        return (
1117            StatusCode::UPGRADE_REQUIRED,
1118            "no version header found".to_string(),
1119        )
1120            .into_response();
1121    };
1122
1123    let release_channel = release_channel_header.map(|header| header.0.0);
1124
1125    if !version.can_collaborate() {
1126        return (
1127            StatusCode::UPGRADE_REQUIRED,
1128            "client must be upgraded".to_string(),
1129        )
1130            .into_response();
1131    }
1132
1133    let socket_address = socket_address.to_string();
1134
1135    // Acquire connection guard before WebSocket upgrade
1136    let connection_guard = match ConnectionGuard::try_acquire() {
1137        Ok(guard) => guard,
1138        Err(()) => {
1139            return (
1140                StatusCode::SERVICE_UNAVAILABLE,
1141                "Too many concurrent connections",
1142            )
1143                .into_response();
1144        }
1145    };
1146
1147    ws.on_upgrade(move |socket| {
1148        let socket = socket
1149            .map_ok(to_tungstenite_message)
1150            .err_into()
1151            .with(|message| async move { to_axum_message(message) });
1152        let connection = Connection::new(Box::pin(socket));
1153        async move {
1154            server
1155                .handle_connection(
1156                    connection,
1157                    socket_address,
1158                    principal,
1159                    version,
1160                    release_channel,
1161                    user_agent.map(|header| header.to_string()),
1162                    country_code_header.map(|header| header.to_string()),
1163                    system_id_header.map(|header| header.to_string()),
1164                    None,
1165                    Executor::Production,
1166                    Some(connection_guard),
1167                )
1168                .await;
1169        }
1170    })
1171}
1172
1173pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1174    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175    let connections_metric = CONNECTIONS_METRIC
1176        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1177
1178    let connections = server
1179        .connection_pool
1180        .lock()
1181        .connections()
1182        .filter(|connection| !connection.admin)
1183        .count();
1184    connections_metric.set(connections as _);
1185
1186    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1187    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1188        register_int_gauge!(
1189            "shared_projects",
1190            "number of open projects with one or more guests"
1191        )
1192        .unwrap()
1193    });
1194
1195    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1196    shared_projects_metric.set(shared_projects as _);
1197
1198    let encoder = prometheus::TextEncoder::new();
1199    let metric_families = prometheus::gather();
1200    let encoded_metrics = encoder
1201        .encode_to_string(&metric_families)
1202        .map_err(|err| anyhow!("{err}"))?;
1203    Ok(encoded_metrics)
1204}
1205
1206#[instrument(err, skip(executor))]
1207async fn connection_lost(
1208    session: Session,
1209    mut teardown: watch::Receiver<bool>,
1210    executor: Executor,
1211) -> Result<()> {
1212    session.peer.disconnect(session.connection_id);
1213    session
1214        .connection_pool()
1215        .await
1216        .remove_connection(session.connection_id)?;
1217
1218    session
1219        .db()
1220        .await
1221        .connection_lost(session.connection_id)
1222        .await
1223        .trace_err();
1224
1225    futures::select_biased! {
1226        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1227
1228            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1229            leave_room_for_session(&session, session.connection_id).await.trace_err();
1230            leave_channel_buffers_for_session(&session)
1231                .await
1232                .trace_err();
1233
1234            if !session
1235                .connection_pool()
1236                .await
1237                .is_user_online(session.user_id())
1238            {
1239                let db = session.db().await;
1240                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1241                    room_updated(&room, &session.peer);
1242                }
1243            }
1244
1245            update_user_contacts(session.user_id(), &session).await?;
1246        },
1247        _ = teardown.changed().fuse() => {}
1248    }
1249
1250    Ok(())
1251}
1252
1253/// Acknowledges a ping from a client, used to keep the connection alive.
1254async fn ping(
1255    _: proto::Ping,
1256    response: Response<proto::Ping>,
1257    _session: MessageContext,
1258) -> Result<()> {
1259    response.send(proto::Ack {})?;
1260    Ok(())
1261}
1262
1263/// Creates a new room for calling (outside of channels)
1264async fn create_room(
1265    _request: proto::CreateRoom,
1266    response: Response<proto::CreateRoom>,
1267    session: MessageContext,
1268) -> Result<()> {
1269    let livekit_room = nanoid::nanoid!(30);
1270
1271    let live_kit_connection_info = util::maybe!(async {
1272        let live_kit = session.app_state.livekit_client.as_ref();
1273        let live_kit = live_kit?;
1274        let user_id = session.user_id().to_string();
1275
1276        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1277
1278        Some(proto::LiveKitConnectionInfo {
1279            server_url: live_kit.url().into(),
1280            token,
1281            can_publish: true,
1282        })
1283    })
1284    .await;
1285
1286    let room = session
1287        .db()
1288        .await
1289        .create_room(session.user_id(), session.connection_id, &livekit_room)
1290        .await?;
1291
1292    response.send(proto::CreateRoomResponse {
1293        room: Some(room.clone()),
1294        live_kit_connection_info,
1295    })?;
1296
1297    update_user_contacts(session.user_id(), &session).await?;
1298    Ok(())
1299}
1300
1301/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1302async fn join_room(
1303    request: proto::JoinRoom,
1304    response: Response<proto::JoinRoom>,
1305    session: MessageContext,
1306) -> Result<()> {
1307    let room_id = RoomId::from_proto(request.id);
1308
1309    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1310
1311    if let Some(channel_id) = channel_id {
1312        return join_channel_internal(channel_id, Box::new(response), session).await;
1313    }
1314
1315    let joined_room = {
1316        let room = session
1317            .db()
1318            .await
1319            .join_room(room_id, session.user_id(), session.connection_id)
1320            .await?;
1321        room_updated(&room.room, &session.peer);
1322        room.into_inner()
1323    };
1324
1325    for connection_id in session
1326        .connection_pool()
1327        .await
1328        .user_connection_ids(session.user_id())
1329    {
1330        session
1331            .peer
1332            .send(
1333                connection_id,
1334                proto::CallCanceled {
1335                    room_id: room_id.to_proto(),
1336                },
1337            )
1338            .trace_err();
1339    }
1340
1341    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1342    {
1343        live_kit
1344            .room_token(
1345                &joined_room.room.livekit_room,
1346                &session.user_id().to_string(),
1347            )
1348            .trace_err()
1349            .map(|token| proto::LiveKitConnectionInfo {
1350                server_url: live_kit.url().into(),
1351                token,
1352                can_publish: true,
1353            })
1354    } else {
1355        None
1356    };
1357
1358    response.send(proto::JoinRoomResponse {
1359        room: Some(joined_room.room),
1360        channel_id: None,
1361        live_kit_connection_info,
1362    })?;
1363
1364    update_user_contacts(session.user_id(), &session).await?;
1365    Ok(())
1366}
1367
1368/// Rejoin room is used to reconnect to a room after connection errors.
1369async fn rejoin_room(
1370    request: proto::RejoinRoom,
1371    response: Response<proto::RejoinRoom>,
1372    session: MessageContext,
1373) -> Result<()> {
1374    let room;
1375    let channel;
1376    {
1377        let mut rejoined_room = session
1378            .db()
1379            .await
1380            .rejoin_room(request, session.user_id(), session.connection_id)
1381            .await?;
1382
1383        response.send(proto::RejoinRoomResponse {
1384            room: Some(rejoined_room.room.clone()),
1385            reshared_projects: rejoined_room
1386                .reshared_projects
1387                .iter()
1388                .map(|project| proto::ResharedProject {
1389                    id: project.id.to_proto(),
1390                    collaborators: project
1391                        .collaborators
1392                        .iter()
1393                        .map(|collaborator| collaborator.to_proto())
1394                        .collect(),
1395                })
1396                .collect(),
1397            rejoined_projects: rejoined_room
1398                .rejoined_projects
1399                .iter()
1400                .map(|rejoined_project| rejoined_project.to_proto())
1401                .collect(),
1402        })?;
1403        room_updated(&rejoined_room.room, &session.peer);
1404
1405        for project in &rejoined_room.reshared_projects {
1406            for collaborator in &project.collaborators {
1407                session
1408                    .peer
1409                    .send(
1410                        collaborator.connection_id,
1411                        proto::UpdateProjectCollaborator {
1412                            project_id: project.id.to_proto(),
1413                            old_peer_id: Some(project.old_connection_id.into()),
1414                            new_peer_id: Some(session.connection_id.into()),
1415                        },
1416                    )
1417                    .trace_err();
1418            }
1419
1420            broadcast(
1421                Some(session.connection_id),
1422                project
1423                    .collaborators
1424                    .iter()
1425                    .map(|collaborator| collaborator.connection_id),
1426                |connection_id| {
1427                    session.peer.forward_send(
1428                        session.connection_id,
1429                        connection_id,
1430                        proto::UpdateProject {
1431                            project_id: project.id.to_proto(),
1432                            worktrees: project.worktrees.clone(),
1433                        },
1434                    )
1435                },
1436            );
1437        }
1438
1439        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1440
1441        let rejoined_room = rejoined_room.into_inner();
1442
1443        room = rejoined_room.room;
1444        channel = rejoined_room.channel;
1445    }
1446
1447    if let Some(channel) = channel {
1448        channel_updated(
1449            &channel,
1450            &room,
1451            &session.peer,
1452            &*session.connection_pool().await,
1453        );
1454    }
1455
1456    update_user_contacts(session.user_id(), &session).await?;
1457    Ok(())
1458}
1459
1460fn notify_rejoined_projects(
1461    rejoined_projects: &mut Vec<RejoinedProject>,
1462    session: &Session,
1463) -> Result<()> {
1464    for project in rejoined_projects.iter() {
1465        for collaborator in &project.collaborators {
1466            session
1467                .peer
1468                .send(
1469                    collaborator.connection_id,
1470                    proto::UpdateProjectCollaborator {
1471                        project_id: project.id.to_proto(),
1472                        old_peer_id: Some(project.old_connection_id.into()),
1473                        new_peer_id: Some(session.connection_id.into()),
1474                    },
1475                )
1476                .trace_err();
1477        }
1478    }
1479
1480    for project in rejoined_projects {
1481        for worktree in mem::take(&mut project.worktrees) {
1482            // Stream this worktree's entries.
1483            let message = proto::UpdateWorktree {
1484                project_id: project.id.to_proto(),
1485                worktree_id: worktree.id,
1486                abs_path: worktree.abs_path.clone(),
1487                root_name: worktree.root_name,
1488                updated_entries: worktree.updated_entries,
1489                removed_entries: worktree.removed_entries,
1490                scan_id: worktree.scan_id,
1491                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1492                updated_repositories: worktree.updated_repositories,
1493                removed_repositories: worktree.removed_repositories,
1494            };
1495            for update in proto::split_worktree_update(message) {
1496                session.peer.send(session.connection_id, update)?;
1497            }
1498
1499            // Stream this worktree's diagnostics.
1500            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1501            if let Some(summary) = worktree_diagnostics.next() {
1502                let message = proto::UpdateDiagnosticSummary {
1503                    project_id: project.id.to_proto(),
1504                    worktree_id: worktree.id,
1505                    summary: Some(summary),
1506                    more_summaries: worktree_diagnostics.collect(),
1507                };
1508                session.peer.send(session.connection_id, message)?;
1509            }
1510
1511            for settings_file in worktree.settings_files {
1512                session.peer.send(
1513                    session.connection_id,
1514                    proto::UpdateWorktreeSettings {
1515                        project_id: project.id.to_proto(),
1516                        worktree_id: worktree.id,
1517                        path: settings_file.path,
1518                        content: Some(settings_file.content),
1519                        kind: Some(settings_file.kind.to_proto().into()),
1520                        outside_worktree: Some(settings_file.outside_worktree),
1521                    },
1522                )?;
1523            }
1524        }
1525
1526        for repository in mem::take(&mut project.updated_repositories) {
1527            for update in split_repository_update(repository) {
1528                session.peer.send(session.connection_id, update)?;
1529            }
1530        }
1531
1532        for id in mem::take(&mut project.removed_repositories) {
1533            session.peer.send(
1534                session.connection_id,
1535                proto::RemoveRepository {
1536                    project_id: project.id.to_proto(),
1537                    id,
1538                },
1539            )?;
1540        }
1541    }
1542
1543    Ok(())
1544}
1545
1546/// leave room disconnects from the room.
1547async fn leave_room(
1548    _: proto::LeaveRoom,
1549    response: Response<proto::LeaveRoom>,
1550    session: MessageContext,
1551) -> Result<()> {
1552    leave_room_for_session(&session, session.connection_id).await?;
1553    response.send(proto::Ack {})?;
1554    Ok(())
1555}
1556
1557/// Updates the permissions of someone else in the room.
1558async fn set_room_participant_role(
1559    request: proto::SetRoomParticipantRole,
1560    response: Response<proto::SetRoomParticipantRole>,
1561    session: MessageContext,
1562) -> Result<()> {
1563    let user_id = UserId::from_proto(request.user_id);
1564    let role = ChannelRole::from(request.role());
1565
1566    let (livekit_room, can_publish) = {
1567        let room = session
1568            .db()
1569            .await
1570            .set_room_participant_role(
1571                session.user_id(),
1572                RoomId::from_proto(request.room_id),
1573                user_id,
1574                role,
1575            )
1576            .await?;
1577
1578        let livekit_room = room.livekit_room.clone();
1579        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1580        room_updated(&room, &session.peer);
1581        (livekit_room, can_publish)
1582    };
1583
1584    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1585        live_kit
1586            .update_participant(
1587                livekit_room.clone(),
1588                request.user_id.to_string(),
1589                livekit_api::proto::ParticipantPermission {
1590                    can_subscribe: true,
1591                    can_publish,
1592                    can_publish_data: can_publish,
1593                    hidden: false,
1594                    recorder: false,
1595                },
1596            )
1597            .await
1598            .trace_err();
1599    }
1600
1601    response.send(proto::Ack {})?;
1602    Ok(())
1603}
1604
1605/// Call someone else into the current room
1606async fn call(
1607    request: proto::Call,
1608    response: Response<proto::Call>,
1609    session: MessageContext,
1610) -> Result<()> {
1611    let room_id = RoomId::from_proto(request.room_id);
1612    let calling_user_id = session.user_id();
1613    let calling_connection_id = session.connection_id;
1614    let called_user_id = UserId::from_proto(request.called_user_id);
1615    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1616    if !session
1617        .db()
1618        .await
1619        .has_contact(calling_user_id, called_user_id)
1620        .await?
1621    {
1622        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1623    }
1624
1625    let incoming_call = {
1626        let (room, incoming_call) = &mut *session
1627            .db()
1628            .await
1629            .call(
1630                room_id,
1631                calling_user_id,
1632                calling_connection_id,
1633                called_user_id,
1634                initial_project_id,
1635            )
1636            .await?;
1637        room_updated(room, &session.peer);
1638        mem::take(incoming_call)
1639    };
1640    update_user_contacts(called_user_id, &session).await?;
1641
1642    let mut calls = session
1643        .connection_pool()
1644        .await
1645        .user_connection_ids(called_user_id)
1646        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1647        .collect::<FuturesUnordered<_>>();
1648
1649    while let Some(call_response) = calls.next().await {
1650        match call_response.as_ref() {
1651            Ok(_) => {
1652                response.send(proto::Ack {})?;
1653                return Ok(());
1654            }
1655            Err(_) => {
1656                call_response.trace_err();
1657            }
1658        }
1659    }
1660
1661    {
1662        let room = session
1663            .db()
1664            .await
1665            .call_failed(room_id, called_user_id)
1666            .await?;
1667        room_updated(&room, &session.peer);
1668    }
1669    update_user_contacts(called_user_id, &session).await?;
1670
1671    Err(anyhow!("failed to ring user"))?
1672}
1673
1674/// Cancel an outgoing call.
1675async fn cancel_call(
1676    request: proto::CancelCall,
1677    response: Response<proto::CancelCall>,
1678    session: MessageContext,
1679) -> Result<()> {
1680    let called_user_id = UserId::from_proto(request.called_user_id);
1681    let room_id = RoomId::from_proto(request.room_id);
1682    {
1683        let room = session
1684            .db()
1685            .await
1686            .cancel_call(room_id, session.connection_id, called_user_id)
1687            .await?;
1688        room_updated(&room, &session.peer);
1689    }
1690
1691    for connection_id in session
1692        .connection_pool()
1693        .await
1694        .user_connection_ids(called_user_id)
1695    {
1696        session
1697            .peer
1698            .send(
1699                connection_id,
1700                proto::CallCanceled {
1701                    room_id: room_id.to_proto(),
1702                },
1703            )
1704            .trace_err();
1705    }
1706    response.send(proto::Ack {})?;
1707
1708    update_user_contacts(called_user_id, &session).await?;
1709    Ok(())
1710}
1711
1712/// Decline an incoming call.
1713async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1714    let room_id = RoomId::from_proto(message.room_id);
1715    {
1716        let room = session
1717            .db()
1718            .await
1719            .decline_call(Some(room_id), session.user_id())
1720            .await?
1721            .context("declining call")?;
1722        room_updated(&room, &session.peer);
1723    }
1724
1725    for connection_id in session
1726        .connection_pool()
1727        .await
1728        .user_connection_ids(session.user_id())
1729    {
1730        session
1731            .peer
1732            .send(
1733                connection_id,
1734                proto::CallCanceled {
1735                    room_id: room_id.to_proto(),
1736                },
1737            )
1738            .trace_err();
1739    }
1740    update_user_contacts(session.user_id(), &session).await?;
1741    Ok(())
1742}
1743
1744/// Updates other participants in the room with your current location.
1745async fn update_participant_location(
1746    request: proto::UpdateParticipantLocation,
1747    response: Response<proto::UpdateParticipantLocation>,
1748    session: MessageContext,
1749) -> Result<()> {
1750    let room_id = RoomId::from_proto(request.room_id);
1751    let location = request.location.context("invalid location")?;
1752
1753    let db = session.db().await;
1754    let room = db
1755        .update_room_participant_location(room_id, session.connection_id, location)
1756        .await?;
1757
1758    room_updated(&room, &session.peer);
1759    response.send(proto::Ack {})?;
1760    Ok(())
1761}
1762
1763/// Share a project into the room.
1764async fn share_project(
1765    request: proto::ShareProject,
1766    response: Response<proto::ShareProject>,
1767    session: MessageContext,
1768) -> Result<()> {
1769    let (project_id, room) = &*session
1770        .db()
1771        .await
1772        .share_project(
1773            RoomId::from_proto(request.room_id),
1774            session.connection_id,
1775            &request.worktrees,
1776            request.is_ssh_project,
1777            request.windows_paths.unwrap_or(false),
1778        )
1779        .await?;
1780    response.send(proto::ShareProjectResponse {
1781        project_id: project_id.to_proto(),
1782    })?;
1783    room_updated(room, &session.peer);
1784
1785    Ok(())
1786}
1787
1788/// Unshare a project from the room.
1789async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1790    let project_id = ProjectId::from_proto(message.project_id);
1791    unshare_project_internal(project_id, session.connection_id, &session).await
1792}
1793
1794async fn unshare_project_internal(
1795    project_id: ProjectId,
1796    connection_id: ConnectionId,
1797    session: &Session,
1798) -> Result<()> {
1799    let delete = {
1800        let room_guard = session
1801            .db()
1802            .await
1803            .unshare_project(project_id, connection_id)
1804            .await?;
1805
1806        let (delete, room, guest_connection_ids) = &*room_guard;
1807
1808        let message = proto::UnshareProject {
1809            project_id: project_id.to_proto(),
1810        };
1811
1812        broadcast(
1813            Some(connection_id),
1814            guest_connection_ids.iter().copied(),
1815            |conn_id| session.peer.send(conn_id, message.clone()),
1816        );
1817        if let Some(room) = room {
1818            room_updated(room, &session.peer);
1819        }
1820
1821        *delete
1822    };
1823
1824    if delete {
1825        let db = session.db().await;
1826        db.delete_project(project_id).await?;
1827    }
1828
1829    Ok(())
1830}
1831
1832/// Join someone elses shared project.
1833async fn join_project(
1834    request: proto::JoinProject,
1835    response: Response<proto::JoinProject>,
1836    session: MessageContext,
1837) -> Result<()> {
1838    let project_id = ProjectId::from_proto(request.project_id);
1839
1840    tracing::info!(%project_id, "join project");
1841
1842    let db = session.db().await;
1843    let (project, replica_id) = &mut *db
1844        .join_project(
1845            project_id,
1846            session.connection_id,
1847            session.user_id(),
1848            request.committer_name.clone(),
1849            request.committer_email.clone(),
1850        )
1851        .await?;
1852    drop(db);
1853    tracing::info!(%project_id, "join remote project");
1854    let collaborators = project
1855        .collaborators
1856        .iter()
1857        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1858        .map(|collaborator| collaborator.to_proto())
1859        .collect::<Vec<_>>();
1860    let project_id = project.id;
1861    let guest_user_id = session.user_id();
1862
1863    let worktrees = project
1864        .worktrees
1865        .iter()
1866        .map(|(id, worktree)| proto::WorktreeMetadata {
1867            id: *id,
1868            root_name: worktree.root_name.clone(),
1869            visible: worktree.visible,
1870            abs_path: worktree.abs_path.clone(),
1871        })
1872        .collect::<Vec<_>>();
1873
1874    let add_project_collaborator = proto::AddProjectCollaborator {
1875        project_id: project_id.to_proto(),
1876        collaborator: Some(proto::Collaborator {
1877            peer_id: Some(session.connection_id.into()),
1878            replica_id: replica_id.0 as u32,
1879            user_id: guest_user_id.to_proto(),
1880            is_host: false,
1881            committer_name: request.committer_name.clone(),
1882            committer_email: request.committer_email.clone(),
1883        }),
1884    };
1885
1886    for collaborator in &collaborators {
1887        session
1888            .peer
1889            .send(
1890                collaborator.peer_id.unwrap().into(),
1891                add_project_collaborator.clone(),
1892            )
1893            .trace_err();
1894    }
1895
1896    // First, we send the metadata associated with each worktree.
1897    let (language_servers, language_server_capabilities) = project
1898        .language_servers
1899        .clone()
1900        .into_iter()
1901        .map(|server| (server.server, server.capabilities))
1902        .unzip();
1903    response.send(proto::JoinProjectResponse {
1904        project_id: project.id.0 as u64,
1905        worktrees,
1906        replica_id: replica_id.0 as u32,
1907        collaborators,
1908        language_servers,
1909        language_server_capabilities,
1910        role: project.role.into(),
1911        windows_paths: project.path_style == PathStyle::Windows,
1912    })?;
1913
1914    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1915        // Stream this worktree's entries.
1916        let message = proto::UpdateWorktree {
1917            project_id: project_id.to_proto(),
1918            worktree_id,
1919            abs_path: worktree.abs_path.clone(),
1920            root_name: worktree.root_name,
1921            updated_entries: worktree.entries,
1922            removed_entries: Default::default(),
1923            scan_id: worktree.scan_id,
1924            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1925            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1926            removed_repositories: Default::default(),
1927        };
1928        for update in proto::split_worktree_update(message) {
1929            session.peer.send(session.connection_id, update.clone())?;
1930        }
1931
1932        // Stream this worktree's diagnostics.
1933        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1934        if let Some(summary) = worktree_diagnostics.next() {
1935            let message = proto::UpdateDiagnosticSummary {
1936                project_id: project.id.to_proto(),
1937                worktree_id: worktree.id,
1938                summary: Some(summary),
1939                more_summaries: worktree_diagnostics.collect(),
1940            };
1941            session.peer.send(session.connection_id, message)?;
1942        }
1943
1944        for settings_file in worktree.settings_files {
1945            session.peer.send(
1946                session.connection_id,
1947                proto::UpdateWorktreeSettings {
1948                    project_id: project_id.to_proto(),
1949                    worktree_id: worktree.id,
1950                    path: settings_file.path,
1951                    content: Some(settings_file.content),
1952                    kind: Some(settings_file.kind.to_proto() as i32),
1953                    outside_worktree: Some(settings_file.outside_worktree),
1954                },
1955            )?;
1956        }
1957    }
1958
1959    for repository in mem::take(&mut project.repositories) {
1960        for update in split_repository_update(repository) {
1961            session.peer.send(session.connection_id, update)?;
1962        }
1963    }
1964
1965    for language_server in &project.language_servers {
1966        session.peer.send(
1967            session.connection_id,
1968            proto::UpdateLanguageServer {
1969                project_id: project_id.to_proto(),
1970                server_name: Some(language_server.server.name.clone()),
1971                language_server_id: language_server.server.id,
1972                variant: Some(
1973                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1974                        proto::LspDiskBasedDiagnosticsUpdated {},
1975                    ),
1976                ),
1977            },
1978        )?;
1979    }
1980
1981    Ok(())
1982}
1983
1984/// Leave someone elses shared project.
1985async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
1986    let sender_id = session.connection_id;
1987    let project_id = ProjectId::from_proto(request.project_id);
1988    let db = session.db().await;
1989
1990    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1991    tracing::info!(
1992        %project_id,
1993        "leave project"
1994    );
1995
1996    project_left(project, &session);
1997    if let Some(room) = room {
1998        room_updated(room, &session.peer);
1999    }
2000
2001    Ok(())
2002}
2003
2004/// Updates other participants with changes to the project
2005async fn update_project(
2006    request: proto::UpdateProject,
2007    response: Response<proto::UpdateProject>,
2008    session: MessageContext,
2009) -> Result<()> {
2010    let project_id = ProjectId::from_proto(request.project_id);
2011    let (room, guest_connection_ids) = &*session
2012        .db()
2013        .await
2014        .update_project(project_id, session.connection_id, &request.worktrees)
2015        .await?;
2016    broadcast(
2017        Some(session.connection_id),
2018        guest_connection_ids.iter().copied(),
2019        |connection_id| {
2020            session
2021                .peer
2022                .forward_send(session.connection_id, connection_id, request.clone())
2023        },
2024    );
2025    if let Some(room) = room {
2026        room_updated(room, &session.peer);
2027    }
2028    response.send(proto::Ack {})?;
2029
2030    Ok(())
2031}
2032
2033/// Updates other participants with changes to the worktree
2034async fn update_worktree(
2035    request: proto::UpdateWorktree,
2036    response: Response<proto::UpdateWorktree>,
2037    session: MessageContext,
2038) -> Result<()> {
2039    let guest_connection_ids = session
2040        .db()
2041        .await
2042        .update_worktree(&request, session.connection_id)
2043        .await?;
2044
2045    broadcast(
2046        Some(session.connection_id),
2047        guest_connection_ids.iter().copied(),
2048        |connection_id| {
2049            session
2050                .peer
2051                .forward_send(session.connection_id, connection_id, request.clone())
2052        },
2053    );
2054    response.send(proto::Ack {})?;
2055    Ok(())
2056}
2057
2058async fn update_repository(
2059    request: proto::UpdateRepository,
2060    response: Response<proto::UpdateRepository>,
2061    session: MessageContext,
2062) -> Result<()> {
2063    let guest_connection_ids = session
2064        .db()
2065        .await
2066        .update_repository(&request, session.connection_id)
2067        .await?;
2068
2069    broadcast(
2070        Some(session.connection_id),
2071        guest_connection_ids.iter().copied(),
2072        |connection_id| {
2073            session
2074                .peer
2075                .forward_send(session.connection_id, connection_id, request.clone())
2076        },
2077    );
2078    response.send(proto::Ack {})?;
2079    Ok(())
2080}
2081
2082async fn remove_repository(
2083    request: proto::RemoveRepository,
2084    response: Response<proto::RemoveRepository>,
2085    session: MessageContext,
2086) -> Result<()> {
2087    let guest_connection_ids = session
2088        .db()
2089        .await
2090        .remove_repository(&request, session.connection_id)
2091        .await?;
2092
2093    broadcast(
2094        Some(session.connection_id),
2095        guest_connection_ids.iter().copied(),
2096        |connection_id| {
2097            session
2098                .peer
2099                .forward_send(session.connection_id, connection_id, request.clone())
2100        },
2101    );
2102    response.send(proto::Ack {})?;
2103    Ok(())
2104}
2105
2106/// Updates other participants with changes to the diagnostics
2107async fn update_diagnostic_summary(
2108    message: proto::UpdateDiagnosticSummary,
2109    session: MessageContext,
2110) -> Result<()> {
2111    let guest_connection_ids = session
2112        .db()
2113        .await
2114        .update_diagnostic_summary(&message, session.connection_id)
2115        .await?;
2116
2117    broadcast(
2118        Some(session.connection_id),
2119        guest_connection_ids.iter().copied(),
2120        |connection_id| {
2121            session
2122                .peer
2123                .forward_send(session.connection_id, connection_id, message.clone())
2124        },
2125    );
2126
2127    Ok(())
2128}
2129
2130/// Updates other participants with changes to the worktree settings
2131async fn update_worktree_settings(
2132    message: proto::UpdateWorktreeSettings,
2133    session: MessageContext,
2134) -> Result<()> {
2135    let guest_connection_ids = session
2136        .db()
2137        .await
2138        .update_worktree_settings(&message, session.connection_id)
2139        .await?;
2140
2141    broadcast(
2142        Some(session.connection_id),
2143        guest_connection_ids.iter().copied(),
2144        |connection_id| {
2145            session
2146                .peer
2147                .forward_send(session.connection_id, connection_id, message.clone())
2148        },
2149    );
2150
2151    Ok(())
2152}
2153
2154/// Notify other participants that a language server has started.
2155async fn start_language_server(
2156    request: proto::StartLanguageServer,
2157    session: MessageContext,
2158) -> Result<()> {
2159    let guest_connection_ids = session
2160        .db()
2161        .await
2162        .start_language_server(&request, session.connection_id)
2163        .await?;
2164
2165    broadcast(
2166        Some(session.connection_id),
2167        guest_connection_ids.iter().copied(),
2168        |connection_id| {
2169            session
2170                .peer
2171                .forward_send(session.connection_id, connection_id, request.clone())
2172        },
2173    );
2174    Ok(())
2175}
2176
2177/// Notify other participants that a language server has changed.
2178async fn update_language_server(
2179    request: proto::UpdateLanguageServer,
2180    session: MessageContext,
2181) -> Result<()> {
2182    let project_id = ProjectId::from_proto(request.project_id);
2183    let db = session.db().await;
2184
2185    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2186        && let Some(capabilities) = update.capabilities.clone()
2187    {
2188        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2189            .await?;
2190    }
2191
2192    let project_connection_ids = db
2193        .project_connection_ids(project_id, session.connection_id, true)
2194        .await?;
2195    broadcast(
2196        Some(session.connection_id),
2197        project_connection_ids.iter().copied(),
2198        |connection_id| {
2199            session
2200                .peer
2201                .forward_send(session.connection_id, connection_id, request.clone())
2202        },
2203    );
2204    Ok(())
2205}
2206
2207/// forward a project request to the host. These requests should be read only
2208/// as guests are allowed to send them.
2209async fn forward_read_only_project_request<T>(
2210    request: T,
2211    response: Response<T>,
2212    session: MessageContext,
2213) -> Result<()>
2214where
2215    T: EntityMessage + RequestMessage,
2216{
2217    let project_id = ProjectId::from_proto(request.remote_entity_id());
2218    let host_connection_id = session
2219        .db()
2220        .await
2221        .host_for_read_only_project_request(project_id, session.connection_id)
2222        .await?;
2223    let payload = session.forward_request(host_connection_id, request).await?;
2224    response.send(payload)?;
2225    Ok(())
2226}
2227
2228/// forward a project request to the host. These requests are disallowed
2229/// for guests.
2230async fn forward_mutating_project_request<T>(
2231    request: T,
2232    response: Response<T>,
2233    session: MessageContext,
2234) -> Result<()>
2235where
2236    T: EntityMessage + RequestMessage,
2237{
2238    let project_id = ProjectId::from_proto(request.remote_entity_id());
2239
2240    let host_connection_id = session
2241        .db()
2242        .await
2243        .host_for_mutating_project_request(project_id, session.connection_id)
2244        .await?;
2245    let payload = session.forward_request(host_connection_id, request).await?;
2246    response.send(payload)?;
2247    Ok(())
2248}
2249
2250async fn disallow_guest_request<T>(
2251    _request: T,
2252    response: Response<T>,
2253    _session: MessageContext,
2254) -> Result<()>
2255where
2256    T: RequestMessage,
2257{
2258    response.peer.respond_with_error(
2259        response.receipt,
2260        ErrorCode::Forbidden
2261            .message("request is not allowed for guests".to_string())
2262            .to_proto(),
2263    )?;
2264    response.responded.store(true, SeqCst);
2265    Ok(())
2266}
2267
2268async fn lsp_query(
2269    request: proto::LspQuery,
2270    response: Response<proto::LspQuery>,
2271    session: MessageContext,
2272) -> Result<()> {
2273    let (name, should_write) = request.query_name_and_write_permissions();
2274    tracing::Span::current().record("lsp_query_request", name);
2275    tracing::info!("lsp_query message received");
2276    if should_write {
2277        forward_mutating_project_request(request, response, session).await
2278    } else {
2279        forward_read_only_project_request(request, response, session).await
2280    }
2281}
2282
2283/// Notify other participants that a new buffer has been created
2284async fn create_buffer_for_peer(
2285    request: proto::CreateBufferForPeer,
2286    session: MessageContext,
2287) -> Result<()> {
2288    session
2289        .db()
2290        .await
2291        .check_user_is_project_host(
2292            ProjectId::from_proto(request.project_id),
2293            session.connection_id,
2294        )
2295        .await?;
2296    let peer_id = request.peer_id.context("invalid peer id")?;
2297    session
2298        .peer
2299        .forward_send(session.connection_id, peer_id.into(), request)?;
2300    Ok(())
2301}
2302
2303/// Notify other participants that a new image has been created
2304async fn create_image_for_peer(
2305    request: proto::CreateImageForPeer,
2306    session: MessageContext,
2307) -> Result<()> {
2308    session
2309        .db()
2310        .await
2311        .check_user_is_project_host(
2312            ProjectId::from_proto(request.project_id),
2313            session.connection_id,
2314        )
2315        .await?;
2316    let peer_id = request.peer_id.context("invalid peer id")?;
2317    session
2318        .peer
2319        .forward_send(session.connection_id, peer_id.into(), request)?;
2320    Ok(())
2321}
2322
2323/// Notify other participants that a buffer has been updated. This is
2324/// allowed for guests as long as the update is limited to selections.
2325async fn update_buffer(
2326    request: proto::UpdateBuffer,
2327    response: Response<proto::UpdateBuffer>,
2328    session: MessageContext,
2329) -> Result<()> {
2330    let project_id = ProjectId::from_proto(request.project_id);
2331    let mut capability = Capability::ReadOnly;
2332
2333    for op in request.operations.iter() {
2334        match op.variant {
2335            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2336            Some(_) => capability = Capability::ReadWrite,
2337        }
2338    }
2339
2340    let host = {
2341        let guard = session
2342            .db()
2343            .await
2344            .connections_for_buffer_update(project_id, session.connection_id, capability)
2345            .await?;
2346
2347        let (host, guests) = &*guard;
2348
2349        broadcast(
2350            Some(session.connection_id),
2351            guests.clone(),
2352            |connection_id| {
2353                session
2354                    .peer
2355                    .forward_send(session.connection_id, connection_id, request.clone())
2356            },
2357        );
2358
2359        *host
2360    };
2361
2362    if host != session.connection_id {
2363        session.forward_request(host, request.clone()).await?;
2364    }
2365
2366    response.send(proto::Ack {})?;
2367    Ok(())
2368}
2369
2370async fn forward_project_search_chunk(
2371    message: proto::FindSearchCandidatesChunk,
2372    response: Response<proto::FindSearchCandidatesChunk>,
2373    session: MessageContext,
2374) -> Result<()> {
2375    let peer_id = message.peer_id.context("missing peer_id")?;
2376    let payload = session
2377        .peer
2378        .forward_request(session.connection_id, peer_id.into(), message)
2379        .await?;
2380    response.send(payload)?;
2381    Ok(())
2382}
2383
2384/// Notify other participants that a project has been updated.
2385async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2386    request: T,
2387    session: MessageContext,
2388) -> Result<()> {
2389    let project_id = ProjectId::from_proto(request.remote_entity_id());
2390    let project_connection_ids = session
2391        .db()
2392        .await
2393        .project_connection_ids(project_id, session.connection_id, false)
2394        .await?;
2395
2396    broadcast(
2397        Some(session.connection_id),
2398        project_connection_ids.iter().copied(),
2399        |connection_id| {
2400            session
2401                .peer
2402                .forward_send(session.connection_id, connection_id, request.clone())
2403        },
2404    );
2405    Ok(())
2406}
2407
2408/// Start following another user in a call.
2409async fn follow(
2410    request: proto::Follow,
2411    response: Response<proto::Follow>,
2412    session: MessageContext,
2413) -> Result<()> {
2414    let room_id = RoomId::from_proto(request.room_id);
2415    let project_id = request.project_id.map(ProjectId::from_proto);
2416    let leader_id = request.leader_id.context("invalid leader id")?.into();
2417    let follower_id = session.connection_id;
2418
2419    session
2420        .db()
2421        .await
2422        .check_room_participants(room_id, leader_id, session.connection_id)
2423        .await?;
2424
2425    let response_payload = session.forward_request(leader_id, request).await?;
2426    response.send(response_payload)?;
2427
2428    if let Some(project_id) = project_id {
2429        let room = session
2430            .db()
2431            .await
2432            .follow(room_id, project_id, leader_id, follower_id)
2433            .await?;
2434        room_updated(&room, &session.peer);
2435    }
2436
2437    Ok(())
2438}
2439
2440/// Stop following another user in a call.
2441async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2442    let room_id = RoomId::from_proto(request.room_id);
2443    let project_id = request.project_id.map(ProjectId::from_proto);
2444    let leader_id = request.leader_id.context("invalid leader id")?.into();
2445    let follower_id = session.connection_id;
2446
2447    session
2448        .db()
2449        .await
2450        .check_room_participants(room_id, leader_id, session.connection_id)
2451        .await?;
2452
2453    session
2454        .peer
2455        .forward_send(session.connection_id, leader_id, request)?;
2456
2457    if let Some(project_id) = project_id {
2458        let room = session
2459            .db()
2460            .await
2461            .unfollow(room_id, project_id, leader_id, follower_id)
2462            .await?;
2463        room_updated(&room, &session.peer);
2464    }
2465
2466    Ok(())
2467}
2468
2469/// Notify everyone following you of your current location.
2470async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2471    let room_id = RoomId::from_proto(request.room_id);
2472    let database = session.db.lock().await;
2473
2474    let connection_ids = if let Some(project_id) = request.project_id {
2475        let project_id = ProjectId::from_proto(project_id);
2476        database
2477            .project_connection_ids(project_id, session.connection_id, true)
2478            .await?
2479    } else {
2480        database
2481            .room_connection_ids(room_id, session.connection_id)
2482            .await?
2483    };
2484
2485    // For now, don't send view update messages back to that view's current leader.
2486    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2487        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2488        _ => None,
2489    });
2490
2491    for connection_id in connection_ids.iter().cloned() {
2492        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2493            session
2494                .peer
2495                .forward_send(session.connection_id, connection_id, request.clone())?;
2496        }
2497    }
2498    Ok(())
2499}
2500
2501/// Get public data about users.
2502async fn get_users(
2503    request: proto::GetUsers,
2504    response: Response<proto::GetUsers>,
2505    session: MessageContext,
2506) -> Result<()> {
2507    let user_ids = request
2508        .user_ids
2509        .into_iter()
2510        .map(UserId::from_proto)
2511        .collect();
2512    let users = session
2513        .db()
2514        .await
2515        .get_users_by_ids(user_ids)
2516        .await?
2517        .into_iter()
2518        .map(|user| proto::User {
2519            id: user.id.to_proto(),
2520            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2521            github_login: user.github_login,
2522            name: user.name,
2523        })
2524        .collect();
2525    response.send(proto::UsersResponse { users })?;
2526    Ok(())
2527}
2528
2529/// Search for users (to invite) buy Github login
2530async fn fuzzy_search_users(
2531    request: proto::FuzzySearchUsers,
2532    response: Response<proto::FuzzySearchUsers>,
2533    session: MessageContext,
2534) -> Result<()> {
2535    let query = request.query;
2536    let users = match query.len() {
2537        0 => vec![],
2538        1 | 2 => session
2539            .db()
2540            .await
2541            .get_user_by_github_login(&query)
2542            .await?
2543            .into_iter()
2544            .collect(),
2545        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2546    };
2547    let users = users
2548        .into_iter()
2549        .filter(|user| user.id != session.user_id())
2550        .map(|user| proto::User {
2551            id: user.id.to_proto(),
2552            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2553            github_login: user.github_login,
2554            name: user.name,
2555        })
2556        .collect();
2557    response.send(proto::UsersResponse { users })?;
2558    Ok(())
2559}
2560
2561/// Send a contact request to another user.
2562async fn request_contact(
2563    request: proto::RequestContact,
2564    response: Response<proto::RequestContact>,
2565    session: MessageContext,
2566) -> Result<()> {
2567    let requester_id = session.user_id();
2568    let responder_id = UserId::from_proto(request.responder_id);
2569    if requester_id == responder_id {
2570        return Err(anyhow!("cannot add yourself as a contact"))?;
2571    }
2572
2573    let notifications = session
2574        .db()
2575        .await
2576        .send_contact_request(requester_id, responder_id)
2577        .await?;
2578
2579    // Update outgoing contact requests of requester
2580    let mut update = proto::UpdateContacts::default();
2581    update.outgoing_requests.push(responder_id.to_proto());
2582    for connection_id in session
2583        .connection_pool()
2584        .await
2585        .user_connection_ids(requester_id)
2586    {
2587        session.peer.send(connection_id, update.clone())?;
2588    }
2589
2590    // Update incoming contact requests of responder
2591    let mut update = proto::UpdateContacts::default();
2592    update
2593        .incoming_requests
2594        .push(proto::IncomingContactRequest {
2595            requester_id: requester_id.to_proto(),
2596        });
2597    let connection_pool = session.connection_pool().await;
2598    for connection_id in connection_pool.user_connection_ids(responder_id) {
2599        session.peer.send(connection_id, update.clone())?;
2600    }
2601
2602    send_notifications(&connection_pool, &session.peer, notifications);
2603
2604    response.send(proto::Ack {})?;
2605    Ok(())
2606}
2607
2608/// Accept or decline a contact request
2609async fn respond_to_contact_request(
2610    request: proto::RespondToContactRequest,
2611    response: Response<proto::RespondToContactRequest>,
2612    session: MessageContext,
2613) -> Result<()> {
2614    let responder_id = session.user_id();
2615    let requester_id = UserId::from_proto(request.requester_id);
2616    let db = session.db().await;
2617    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2618        db.dismiss_contact_notification(responder_id, requester_id)
2619            .await?;
2620    } else {
2621        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2622
2623        let notifications = db
2624            .respond_to_contact_request(responder_id, requester_id, accept)
2625            .await?;
2626        let requester_busy = db.is_user_busy(requester_id).await?;
2627        let responder_busy = db.is_user_busy(responder_id).await?;
2628
2629        let pool = session.connection_pool().await;
2630        // Update responder with new contact
2631        let mut update = proto::UpdateContacts::default();
2632        if accept {
2633            update
2634                .contacts
2635                .push(contact_for_user(requester_id, requester_busy, &pool));
2636        }
2637        update
2638            .remove_incoming_requests
2639            .push(requester_id.to_proto());
2640        for connection_id in pool.user_connection_ids(responder_id) {
2641            session.peer.send(connection_id, update.clone())?;
2642        }
2643
2644        // Update requester with new contact
2645        let mut update = proto::UpdateContacts::default();
2646        if accept {
2647            update
2648                .contacts
2649                .push(contact_for_user(responder_id, responder_busy, &pool));
2650        }
2651        update
2652            .remove_outgoing_requests
2653            .push(responder_id.to_proto());
2654
2655        for connection_id in pool.user_connection_ids(requester_id) {
2656            session.peer.send(connection_id, update.clone())?;
2657        }
2658
2659        send_notifications(&pool, &session.peer, notifications);
2660    }
2661
2662    response.send(proto::Ack {})?;
2663    Ok(())
2664}
2665
2666/// Remove a contact.
2667async fn remove_contact(
2668    request: proto::RemoveContact,
2669    response: Response<proto::RemoveContact>,
2670    session: MessageContext,
2671) -> Result<()> {
2672    let requester_id = session.user_id();
2673    let responder_id = UserId::from_proto(request.user_id);
2674    let db = session.db().await;
2675    let (contact_accepted, deleted_notification_id) =
2676        db.remove_contact(requester_id, responder_id).await?;
2677
2678    let pool = session.connection_pool().await;
2679    // Update outgoing contact requests of requester
2680    let mut update = proto::UpdateContacts::default();
2681    if contact_accepted {
2682        update.remove_contacts.push(responder_id.to_proto());
2683    } else {
2684        update
2685            .remove_outgoing_requests
2686            .push(responder_id.to_proto());
2687    }
2688    for connection_id in pool.user_connection_ids(requester_id) {
2689        session.peer.send(connection_id, update.clone())?;
2690    }
2691
2692    // Update incoming contact requests of responder
2693    let mut update = proto::UpdateContacts::default();
2694    if contact_accepted {
2695        update.remove_contacts.push(requester_id.to_proto());
2696    } else {
2697        update
2698            .remove_incoming_requests
2699            .push(requester_id.to_proto());
2700    }
2701    for connection_id in pool.user_connection_ids(responder_id) {
2702        session.peer.send(connection_id, update.clone())?;
2703        if let Some(notification_id) = deleted_notification_id {
2704            session.peer.send(
2705                connection_id,
2706                proto::DeleteNotification {
2707                    notification_id: notification_id.to_proto(),
2708                },
2709            )?;
2710        }
2711    }
2712
2713    response.send(proto::Ack {})?;
2714    Ok(())
2715}
2716
2717fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2718    version.0.minor < 139
2719}
2720
2721async fn subscribe_to_channels(
2722    _: proto::SubscribeToChannels,
2723    session: MessageContext,
2724) -> Result<()> {
2725    subscribe_user_to_channels(session.user_id(), &session).await?;
2726    Ok(())
2727}
2728
2729async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2730    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2731    let mut pool = session.connection_pool().await;
2732    for membership in &channels_for_user.channel_memberships {
2733        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2734    }
2735    session.peer.send(
2736        session.connection_id,
2737        build_update_user_channels(&channels_for_user),
2738    )?;
2739    session.peer.send(
2740        session.connection_id,
2741        build_channels_update(channels_for_user),
2742    )?;
2743    Ok(())
2744}
2745
2746/// Creates a new channel.
2747async fn create_channel(
2748    request: proto::CreateChannel,
2749    response: Response<proto::CreateChannel>,
2750    session: MessageContext,
2751) -> Result<()> {
2752    let db = session.db().await;
2753
2754    let parent_id = request.parent_id.map(ChannelId::from_proto);
2755    let (channel, membership) = db
2756        .create_channel(&request.name, parent_id, session.user_id())
2757        .await?;
2758
2759    let root_id = channel.root_id();
2760    let channel = Channel::from_model(channel);
2761
2762    response.send(proto::CreateChannelResponse {
2763        channel: Some(channel.to_proto()),
2764        parent_id: request.parent_id,
2765    })?;
2766
2767    let mut connection_pool = session.connection_pool().await;
2768    if let Some(membership) = membership {
2769        connection_pool.subscribe_to_channel(
2770            membership.user_id,
2771            membership.channel_id,
2772            membership.role,
2773        );
2774        let update = proto::UpdateUserChannels {
2775            channel_memberships: vec![proto::ChannelMembership {
2776                channel_id: membership.channel_id.to_proto(),
2777                role: membership.role.into(),
2778            }],
2779            ..Default::default()
2780        };
2781        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2782            session.peer.send(connection_id, update.clone())?;
2783        }
2784    }
2785
2786    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2787        if !role.can_see_channel(channel.visibility) {
2788            continue;
2789        }
2790
2791        let update = proto::UpdateChannels {
2792            channels: vec![channel.to_proto()],
2793            ..Default::default()
2794        };
2795        session.peer.send(connection_id, update.clone())?;
2796    }
2797
2798    Ok(())
2799}
2800
2801/// Delete a channel
2802async fn delete_channel(
2803    request: proto::DeleteChannel,
2804    response: Response<proto::DeleteChannel>,
2805    session: MessageContext,
2806) -> Result<()> {
2807    let db = session.db().await;
2808
2809    let channel_id = request.channel_id;
2810    let (root_channel, removed_channels) = db
2811        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2812        .await?;
2813    response.send(proto::Ack {})?;
2814
2815    // Notify members of removed channels
2816    let mut update = proto::UpdateChannels::default();
2817    update
2818        .delete_channels
2819        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2820
2821    let connection_pool = session.connection_pool().await;
2822    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2823        session.peer.send(connection_id, update.clone())?;
2824    }
2825
2826    Ok(())
2827}
2828
2829/// Invite someone to join a channel.
2830async fn invite_channel_member(
2831    request: proto::InviteChannelMember,
2832    response: Response<proto::InviteChannelMember>,
2833    session: MessageContext,
2834) -> Result<()> {
2835    let db = session.db().await;
2836    let channel_id = ChannelId::from_proto(request.channel_id);
2837    let invitee_id = UserId::from_proto(request.user_id);
2838    let InviteMemberResult {
2839        channel,
2840        notifications,
2841    } = db
2842        .invite_channel_member(
2843            channel_id,
2844            invitee_id,
2845            session.user_id(),
2846            request.role().into(),
2847        )
2848        .await?;
2849
2850    let update = proto::UpdateChannels {
2851        channel_invitations: vec![channel.to_proto()],
2852        ..Default::default()
2853    };
2854
2855    let connection_pool = session.connection_pool().await;
2856    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2857        session.peer.send(connection_id, update.clone())?;
2858    }
2859
2860    send_notifications(&connection_pool, &session.peer, notifications);
2861
2862    response.send(proto::Ack {})?;
2863    Ok(())
2864}
2865
2866/// remove someone from a channel
2867async fn remove_channel_member(
2868    request: proto::RemoveChannelMember,
2869    response: Response<proto::RemoveChannelMember>,
2870    session: MessageContext,
2871) -> Result<()> {
2872    let db = session.db().await;
2873    let channel_id = ChannelId::from_proto(request.channel_id);
2874    let member_id = UserId::from_proto(request.user_id);
2875
2876    let RemoveChannelMemberResult {
2877        membership_update,
2878        notification_id,
2879    } = db
2880        .remove_channel_member(channel_id, member_id, session.user_id())
2881        .await?;
2882
2883    let mut connection_pool = session.connection_pool().await;
2884    notify_membership_updated(
2885        &mut connection_pool,
2886        membership_update,
2887        member_id,
2888        &session.peer,
2889    );
2890    for connection_id in connection_pool.user_connection_ids(member_id) {
2891        if let Some(notification_id) = notification_id {
2892            session
2893                .peer
2894                .send(
2895                    connection_id,
2896                    proto::DeleteNotification {
2897                        notification_id: notification_id.to_proto(),
2898                    },
2899                )
2900                .trace_err();
2901        }
2902    }
2903
2904    response.send(proto::Ack {})?;
2905    Ok(())
2906}
2907
2908/// Toggle the channel between public and private.
2909/// Care is taken to maintain the invariant that public channels only descend from public channels,
2910/// (though members-only channels can appear at any point in the hierarchy).
2911async fn set_channel_visibility(
2912    request: proto::SetChannelVisibility,
2913    response: Response<proto::SetChannelVisibility>,
2914    session: MessageContext,
2915) -> Result<()> {
2916    let db = session.db().await;
2917    let channel_id = ChannelId::from_proto(request.channel_id);
2918    let visibility = request.visibility().into();
2919
2920    let channel_model = db
2921        .set_channel_visibility(channel_id, visibility, session.user_id())
2922        .await?;
2923    let root_id = channel_model.root_id();
2924    let channel = Channel::from_model(channel_model);
2925
2926    let mut connection_pool = session.connection_pool().await;
2927    for (user_id, role) in connection_pool
2928        .channel_user_ids(root_id)
2929        .collect::<Vec<_>>()
2930        .into_iter()
2931    {
2932        let update = if role.can_see_channel(channel.visibility) {
2933            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2934            proto::UpdateChannels {
2935                channels: vec![channel.to_proto()],
2936                ..Default::default()
2937            }
2938        } else {
2939            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2940            proto::UpdateChannels {
2941                delete_channels: vec![channel.id.to_proto()],
2942                ..Default::default()
2943            }
2944        };
2945
2946        for connection_id in connection_pool.user_connection_ids(user_id) {
2947            session.peer.send(connection_id, update.clone())?;
2948        }
2949    }
2950
2951    response.send(proto::Ack {})?;
2952    Ok(())
2953}
2954
2955/// Alter the role for a user in the channel.
2956async fn set_channel_member_role(
2957    request: proto::SetChannelMemberRole,
2958    response: Response<proto::SetChannelMemberRole>,
2959    session: MessageContext,
2960) -> Result<()> {
2961    let db = session.db().await;
2962    let channel_id = ChannelId::from_proto(request.channel_id);
2963    let member_id = UserId::from_proto(request.user_id);
2964    let result = db
2965        .set_channel_member_role(
2966            channel_id,
2967            session.user_id(),
2968            member_id,
2969            request.role().into(),
2970        )
2971        .await?;
2972
2973    match result {
2974        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2975            let mut connection_pool = session.connection_pool().await;
2976            notify_membership_updated(
2977                &mut connection_pool,
2978                membership_update,
2979                member_id,
2980                &session.peer,
2981            )
2982        }
2983        db::SetMemberRoleResult::InviteUpdated(channel) => {
2984            let update = proto::UpdateChannels {
2985                channel_invitations: vec![channel.to_proto()],
2986                ..Default::default()
2987            };
2988
2989            for connection_id in session
2990                .connection_pool()
2991                .await
2992                .user_connection_ids(member_id)
2993            {
2994                session.peer.send(connection_id, update.clone())?;
2995            }
2996        }
2997    }
2998
2999    response.send(proto::Ack {})?;
3000    Ok(())
3001}
3002
3003/// Change the name of a channel
3004async fn rename_channel(
3005    request: proto::RenameChannel,
3006    response: Response<proto::RenameChannel>,
3007    session: MessageContext,
3008) -> Result<()> {
3009    let db = session.db().await;
3010    let channel_id = ChannelId::from_proto(request.channel_id);
3011    let channel_model = db
3012        .rename_channel(channel_id, session.user_id(), &request.name)
3013        .await?;
3014    let root_id = channel_model.root_id();
3015    let channel = Channel::from_model(channel_model);
3016
3017    response.send(proto::RenameChannelResponse {
3018        channel: Some(channel.to_proto()),
3019    })?;
3020
3021    let connection_pool = session.connection_pool().await;
3022    let update = proto::UpdateChannels {
3023        channels: vec![channel.to_proto()],
3024        ..Default::default()
3025    };
3026    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3027        if role.can_see_channel(channel.visibility) {
3028            session.peer.send(connection_id, update.clone())?;
3029        }
3030    }
3031
3032    Ok(())
3033}
3034
3035/// Move a channel to a new parent.
3036async fn move_channel(
3037    request: proto::MoveChannel,
3038    response: Response<proto::MoveChannel>,
3039    session: MessageContext,
3040) -> Result<()> {
3041    let channel_id = ChannelId::from_proto(request.channel_id);
3042    let to = ChannelId::from_proto(request.to);
3043
3044    let (root_id, channels) = session
3045        .db()
3046        .await
3047        .move_channel(channel_id, to, session.user_id())
3048        .await?;
3049
3050    let connection_pool = session.connection_pool().await;
3051    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3052        let channels = channels
3053            .iter()
3054            .filter_map(|channel| {
3055                if role.can_see_channel(channel.visibility) {
3056                    Some(channel.to_proto())
3057                } else {
3058                    None
3059                }
3060            })
3061            .collect::<Vec<_>>();
3062        if channels.is_empty() {
3063            continue;
3064        }
3065
3066        let update = proto::UpdateChannels {
3067            channels,
3068            ..Default::default()
3069        };
3070
3071        session.peer.send(connection_id, update.clone())?;
3072    }
3073
3074    response.send(Ack {})?;
3075    Ok(())
3076}
3077
3078async fn reorder_channel(
3079    request: proto::ReorderChannel,
3080    response: Response<proto::ReorderChannel>,
3081    session: MessageContext,
3082) -> Result<()> {
3083    let channel_id = ChannelId::from_proto(request.channel_id);
3084    let direction = request.direction();
3085
3086    let updated_channels = session
3087        .db()
3088        .await
3089        .reorder_channel(channel_id, direction, session.user_id())
3090        .await?;
3091
3092    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3093        let connection_pool = session.connection_pool().await;
3094        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3095            let channels = updated_channels
3096                .iter()
3097                .filter_map(|channel| {
3098                    if role.can_see_channel(channel.visibility) {
3099                        Some(channel.to_proto())
3100                    } else {
3101                        None
3102                    }
3103                })
3104                .collect::<Vec<_>>();
3105
3106            if channels.is_empty() {
3107                continue;
3108            }
3109
3110            let update = proto::UpdateChannels {
3111                channels,
3112                ..Default::default()
3113            };
3114
3115            session.peer.send(connection_id, update.clone())?;
3116        }
3117    }
3118
3119    response.send(Ack {})?;
3120    Ok(())
3121}
3122
3123/// Get the list of channel members
3124async fn get_channel_members(
3125    request: proto::GetChannelMembers,
3126    response: Response<proto::GetChannelMembers>,
3127    session: MessageContext,
3128) -> Result<()> {
3129    let db = session.db().await;
3130    let channel_id = ChannelId::from_proto(request.channel_id);
3131    let limit = if request.limit == 0 {
3132        u16::MAX as u64
3133    } else {
3134        request.limit
3135    };
3136    let (members, users) = db
3137        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3138        .await?;
3139    response.send(proto::GetChannelMembersResponse { members, users })?;
3140    Ok(())
3141}
3142
3143/// Accept or decline a channel invitation.
3144async fn respond_to_channel_invite(
3145    request: proto::RespondToChannelInvite,
3146    response: Response<proto::RespondToChannelInvite>,
3147    session: MessageContext,
3148) -> Result<()> {
3149    let db = session.db().await;
3150    let channel_id = ChannelId::from_proto(request.channel_id);
3151    let RespondToChannelInvite {
3152        membership_update,
3153        notifications,
3154    } = db
3155        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3156        .await?;
3157
3158    let mut connection_pool = session.connection_pool().await;
3159    if let Some(membership_update) = membership_update {
3160        notify_membership_updated(
3161            &mut connection_pool,
3162            membership_update,
3163            session.user_id(),
3164            &session.peer,
3165        );
3166    } else {
3167        let update = proto::UpdateChannels {
3168            remove_channel_invitations: vec![channel_id.to_proto()],
3169            ..Default::default()
3170        };
3171
3172        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3173            session.peer.send(connection_id, update.clone())?;
3174        }
3175    };
3176
3177    send_notifications(&connection_pool, &session.peer, notifications);
3178
3179    response.send(proto::Ack {})?;
3180
3181    Ok(())
3182}
3183
3184/// Join the channels' room
3185async fn join_channel(
3186    request: proto::JoinChannel,
3187    response: Response<proto::JoinChannel>,
3188    session: MessageContext,
3189) -> Result<()> {
3190    let channel_id = ChannelId::from_proto(request.channel_id);
3191    join_channel_internal(channel_id, Box::new(response), session).await
3192}
3193
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`, so
/// `join_channel_internal` can reply without knowing which request it is serving.
/// (Presumably the `JoinRoom` handler routes channel rooms through it as well — confirm at the
/// call site.)
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-path call to `Response`'s own `send` (inherent methods take precedence here),
        // not a recursive call into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for `JoinRoom` handles.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3207
/// Join the room belonging to `channel_id` on behalf of the session's user.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database. This also yields an optional membership update
///    and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token: a guest token without publish rights for
///    `ChannelRole::Guest`, a full room token otherwise. Then reply with the joined room.
/// 4. Broadcast the membership change (if any), the room update, the channel update and the
///    user's contact update.
///
/// `response` is any handle implementing `JoinChannelInternalResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Everything that needs the database guard happens inside this block, so the guard (and
    // the connection-pool guard taken below) is released before the block ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our database guard before `leave_room_for_session`, which receives the
            // whole session (presumably so it can take the database itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token fails
        // (`.trace_err()?` turns the error into `None` inside this closure), rather than
        // failing the whole join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; everyone else gets a room
                    // token with publishing enabled.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Errors if the database did not return the channel alongside the joined room.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3301
3302/// Start editing the channel notes
3303async fn join_channel_buffer(
3304    request: proto::JoinChannelBuffer,
3305    response: Response<proto::JoinChannelBuffer>,
3306    session: MessageContext,
3307) -> Result<()> {
3308    let db = session.db().await;
3309    let channel_id = ChannelId::from_proto(request.channel_id);
3310
3311    let open_response = db
3312        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3313        .await?;
3314
3315    let collaborators = open_response.collaborators.clone();
3316    response.send(open_response)?;
3317
3318    let update = UpdateChannelBufferCollaborators {
3319        channel_id: channel_id.to_proto(),
3320        collaborators: collaborators.clone(),
3321    };
3322    channel_buffer_updated(
3323        session.connection_id,
3324        collaborators
3325            .iter()
3326            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3327        &update,
3328        &session.peer,
3329    );
3330
3331    Ok(())
3332}
3333
3334/// Edit the channel notes
3335async fn update_channel_buffer(
3336    request: proto::UpdateChannelBuffer,
3337    session: MessageContext,
3338) -> Result<()> {
3339    let db = session.db().await;
3340    let channel_id = ChannelId::from_proto(request.channel_id);
3341
3342    let (collaborators, epoch, version) = db
3343        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3344        .await?;
3345
3346    channel_buffer_updated(
3347        session.connection_id,
3348        collaborators.clone(),
3349        &proto::UpdateChannelBuffer {
3350            channel_id: channel_id.to_proto(),
3351            operations: request.operations,
3352        },
3353        &session.peer,
3354    );
3355
3356    let pool = &*session.connection_pool().await;
3357
3358    let non_collaborators =
3359        pool.channel_connection_ids(channel_id)
3360            .filter_map(|(connection_id, _)| {
3361                if collaborators.contains(&connection_id) {
3362                    None
3363                } else {
3364                    Some(connection_id)
3365                }
3366            });
3367
3368    broadcast(None, non_collaborators, |peer_id| {
3369        session.peer.send(
3370            peer_id,
3371            proto::UpdateChannels {
3372                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3373                    channel_id: channel_id.to_proto(),
3374                    epoch: epoch as u64,
3375                    version: version.clone(),
3376                }],
3377                ..Default::default()
3378            },
3379        )
3380    });
3381
3382    Ok(())
3383}
3384
3385/// Rejoin the channel notes after a connection blip
3386async fn rejoin_channel_buffers(
3387    request: proto::RejoinChannelBuffers,
3388    response: Response<proto::RejoinChannelBuffers>,
3389    session: MessageContext,
3390) -> Result<()> {
3391    let db = session.db().await;
3392    let buffers = db
3393        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3394        .await?;
3395
3396    for rejoined_buffer in &buffers {
3397        let collaborators_to_notify = rejoined_buffer
3398            .buffer
3399            .collaborators
3400            .iter()
3401            .filter_map(|c| Some(c.peer_id?.into()));
3402        channel_buffer_updated(
3403            session.connection_id,
3404            collaborators_to_notify,
3405            &proto::UpdateChannelBufferCollaborators {
3406                channel_id: rejoined_buffer.buffer.channel_id,
3407                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3408            },
3409            &session.peer,
3410        );
3411    }
3412
3413    response.send(proto::RejoinChannelBuffersResponse {
3414        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3415    })?;
3416
3417    Ok(())
3418}
3419
3420/// Stop editing the channel notes
3421async fn leave_channel_buffer(
3422    request: proto::LeaveChannelBuffer,
3423    response: Response<proto::LeaveChannelBuffer>,
3424    session: MessageContext,
3425) -> Result<()> {
3426    let db = session.db().await;
3427    let channel_id = ChannelId::from_proto(request.channel_id);
3428
3429    let left_buffer = db
3430        .leave_channel_buffer(channel_id, session.connection_id)
3431        .await?;
3432
3433    response.send(Ack {})?;
3434
3435    channel_buffer_updated(
3436        session.connection_id,
3437        left_buffer.connections,
3438        &proto::UpdateChannelBufferCollaborators {
3439            channel_id: channel_id.to_proto(),
3440            collaborators: left_buffer.collaborators,
3441        },
3442        &session.peer,
3443    );
3444
3445    Ok(())
3446}
3447
3448fn channel_buffer_updated<T: EnvelopedMessage>(
3449    sender_id: ConnectionId,
3450    collaborators: impl IntoIterator<Item = ConnectionId>,
3451    message: &T,
3452    peer: &Peer,
3453) {
3454    broadcast(Some(sender_id), collaborators, |peer_id| {
3455        peer.send(peer_id, message.clone())
3456    });
3457}
3458
3459fn send_notifications(
3460    connection_pool: &ConnectionPool,
3461    peer: &Peer,
3462    notifications: db::NotificationBatch,
3463) {
3464    for (user_id, notification) in notifications {
3465        for connection_id in connection_pool.user_connection_ids(user_id) {
3466            if let Err(error) = peer.send(
3467                connection_id,
3468                proto::AddNotification {
3469                    notification: Some(notification.clone()),
3470                },
3471            ) {
3472                tracing::error!(
3473                    "failed to send notification to {:?} {}",
3474                    connection_id,
3475                    error
3476                );
3477            }
3478        }
3479    }
3480}
3481
3482/// Send a message to the channel
3483async fn send_channel_message(
3484    _request: proto::SendChannelMessage,
3485    _response: Response<proto::SendChannelMessage>,
3486    _session: MessageContext,
3487) -> Result<()> {
3488    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3489}
3490
3491/// Delete a channel message
3492async fn remove_channel_message(
3493    _request: proto::RemoveChannelMessage,
3494    _response: Response<proto::RemoveChannelMessage>,
3495    _session: MessageContext,
3496) -> Result<()> {
3497    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3498}
3499
3500async fn update_channel_message(
3501    _request: proto::UpdateChannelMessage,
3502    _response: Response<proto::UpdateChannelMessage>,
3503    _session: MessageContext,
3504) -> Result<()> {
3505    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3506}
3507
3508/// Mark a channel message as read
3509async fn acknowledge_channel_message(
3510    _request: proto::AckChannelMessage,
3511    _session: MessageContext,
3512) -> Result<()> {
3513    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3514}
3515
3516/// Mark a buffer version as synced
3517async fn acknowledge_buffer_version(
3518    request: proto::AckBufferOperation,
3519    session: MessageContext,
3520) -> Result<()> {
3521    let buffer_id = BufferId::from_proto(request.buffer_id);
3522    session
3523        .db()
3524        .await
3525        .observe_buffer_version(
3526            buffer_id,
3527            session.user_id(),
3528            request.epoch as i32,
3529            &request.version,
3530        )
3531        .await?;
3532    Ok(())
3533}
3534
3535/// Start receiving chat updates for a channel
3536async fn join_channel_chat(
3537    _request: proto::JoinChannelChat,
3538    _response: Response<proto::JoinChannelChat>,
3539    _session: MessageContext,
3540) -> Result<()> {
3541    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3542}
3543
3544/// Stop receiving chat updates for a channel
3545async fn leave_channel_chat(
3546    _request: proto::LeaveChannelChat,
3547    _session: MessageContext,
3548) -> Result<()> {
3549    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3550}
3551
3552/// Retrieve the chat history for a channel
3553async fn get_channel_messages(
3554    _request: proto::GetChannelMessages,
3555    _response: Response<proto::GetChannelMessages>,
3556    _session: MessageContext,
3557) -> Result<()> {
3558    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3559}
3560
3561/// Retrieve specific chat messages
3562async fn get_channel_messages_by_id(
3563    _request: proto::GetChannelMessagesById,
3564    _response: Response<proto::GetChannelMessagesById>,
3565    _session: MessageContext,
3566) -> Result<()> {
3567    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3568}
3569
3570/// Retrieve the current users notifications
3571async fn get_notifications(
3572    request: proto::GetNotifications,
3573    response: Response<proto::GetNotifications>,
3574    session: MessageContext,
3575) -> Result<()> {
3576    let notifications = session
3577        .db()
3578        .await
3579        .get_notifications(
3580            session.user_id(),
3581            NOTIFICATION_COUNT_PER_PAGE,
3582            request.before_id.map(db::NotificationId::from_proto),
3583        )
3584        .await?;
3585    response.send(proto::GetNotificationsResponse {
3586        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3587        notifications,
3588    })?;
3589    Ok(())
3590}
3591
3592/// Mark notifications as read
3593async fn mark_notification_as_read(
3594    request: proto::MarkNotificationRead,
3595    response: Response<proto::MarkNotificationRead>,
3596    session: MessageContext,
3597) -> Result<()> {
3598    let database = &session.db().await;
3599    let notifications = database
3600        .mark_notification_as_read_by_id(
3601            session.user_id(),
3602            NotificationId::from_proto(request.notification_id),
3603        )
3604        .await?;
3605    send_notifications(
3606        &*session.connection_pool().await,
3607        &session.peer,
3608        notifications,
3609    );
3610    response.send(proto::Ack {})?;
3611    Ok(())
3612}
3613
3614fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3615    let message = match message {
3616        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3617        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3618        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3619        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3620        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3621            code: frame.code.into(),
3622            reason: frame.reason.as_str().to_owned().into(),
3623        })),
3624        // We should never receive a frame while reading the message, according
3625        // to the `tungstenite` maintainers:
3626        //
3627        // > It cannot occur when you read messages from the WebSocket, but it
3628        // > can be used when you want to send the raw frames (e.g. you want to
3629        // > send the frames to the WebSocket without composing the full message first).
3630        // >
3631        // > — https://github.com/snapview/tungstenite-rs/issues/268
3632        TungsteniteMessage::Frame(_) => {
3633            bail!("received an unexpected frame while reading the message")
3634        }
3635    };
3636
3637    Ok(message)
3638}
3639
3640fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3641    match message {
3642        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3643        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3644        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3645        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3646        AxumMessage::Close(frame) => {
3647            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3648                code: frame.code.into(),
3649                reason: frame.reason.as_ref().into(),
3650            }))
3651        }
3652    }
3653}
3654
3655fn notify_membership_updated(
3656    connection_pool: &mut ConnectionPool,
3657    result: MembershipUpdated,
3658    user_id: UserId,
3659    peer: &Peer,
3660) {
3661    for membership in &result.new_channels.channel_memberships {
3662        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3663    }
3664    for channel_id in &result.removed_channels {
3665        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3666    }
3667
3668    let user_channels_update = proto::UpdateUserChannels {
3669        channel_memberships: result
3670            .new_channels
3671            .channel_memberships
3672            .iter()
3673            .map(|cm| proto::ChannelMembership {
3674                channel_id: cm.channel_id.to_proto(),
3675                role: cm.role.into(),
3676            })
3677            .collect(),
3678        ..Default::default()
3679    };
3680
3681    let mut update = build_channels_update(result.new_channels);
3682    update.delete_channels = result
3683        .removed_channels
3684        .into_iter()
3685        .map(|id| id.to_proto())
3686        .collect();
3687    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3688
3689    for connection_id in connection_pool.user_connection_ids(user_id) {
3690        peer.send(connection_id, user_channels_update.clone())
3691            .trace_err();
3692        peer.send(connection_id, update.clone()).trace_err();
3693    }
3694}
3695
3696fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3697    proto::UpdateUserChannels {
3698        channel_memberships: channels
3699            .channel_memberships
3700            .iter()
3701            .map(|m| proto::ChannelMembership {
3702                channel_id: m.channel_id.to_proto(),
3703                role: m.role.into(),
3704            })
3705            .collect(),
3706        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3707    }
3708}
3709
3710fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3711    let mut update = proto::UpdateChannels::default();
3712
3713    for channel in channels.channels {
3714        update.channels.push(channel.to_proto());
3715    }
3716
3717    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3718
3719    for (channel_id, participants) in channels.channel_participants {
3720        update
3721            .channel_participants
3722            .push(proto::ChannelParticipants {
3723                channel_id: channel_id.to_proto(),
3724                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3725            });
3726    }
3727
3728    for channel in channels.invited_channels {
3729        update.channel_invitations.push(channel.to_proto());
3730    }
3731
3732    update
3733}
3734
3735fn build_initial_contacts_update(
3736    contacts: Vec<db::Contact>,
3737    pool: &ConnectionPool,
3738) -> proto::UpdateContacts {
3739    let mut update = proto::UpdateContacts::default();
3740
3741    for contact in contacts {
3742        match contact {
3743            db::Contact::Accepted { user_id, busy } => {
3744                update.contacts.push(contact_for_user(user_id, busy, pool));
3745            }
3746            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3747            db::Contact::Incoming { user_id } => {
3748                update
3749                    .incoming_requests
3750                    .push(proto::IncomingContactRequest {
3751                        requester_id: user_id.to_proto(),
3752                    })
3753            }
3754        }
3755    }
3756
3757    update
3758}
3759
3760fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3761    proto::Contact {
3762        user_id: user_id.to_proto(),
3763        online: pool.is_user_online(user_id),
3764        busy,
3765    }
3766}
3767
3768fn room_updated(room: &proto::Room, peer: &Peer) {
3769    broadcast(
3770        None,
3771        room.participants
3772            .iter()
3773            .filter_map(|participant| Some(participant.peer_id?.into())),
3774        |peer_id| {
3775            peer.send(
3776                peer_id,
3777                proto::RoomUpdated {
3778                    room: Some(room.clone()),
3779                },
3780            )
3781        },
3782    );
3783}
3784
3785fn channel_updated(
3786    channel: &db::channel::Model,
3787    room: &proto::Room,
3788    peer: &Peer,
3789    pool: &ConnectionPool,
3790) {
3791    let participants = room
3792        .participants
3793        .iter()
3794        .map(|p| p.user_id)
3795        .collect::<Vec<_>>();
3796
3797    broadcast(
3798        None,
3799        pool.channel_connection_ids(channel.root_id())
3800            .filter_map(|(channel_id, role)| {
3801                role.can_see_channel(channel.visibility)
3802                    .then_some(channel_id)
3803            }),
3804        |peer_id| {
3805            peer.send(
3806                peer_id,
3807                proto::UpdateChannels {
3808                    channel_participants: vec![proto::ChannelParticipants {
3809                        channel_id: channel.id.to_proto(),
3810                        participant_user_ids: participants.clone(),
3811                    }],
3812                    ..Default::default()
3813                },
3814            )
3815        },
3816    );
3817}
3818
3819async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3820    let db = session.db().await;
3821
3822    let contacts = db.get_contacts(user_id).await?;
3823    let busy = db.is_user_busy(user_id).await?;
3824
3825    let pool = session.connection_pool().await;
3826    let updated_contact = contact_for_user(user_id, busy, &pool);
3827    for contact in contacts {
3828        if let db::Contact::Accepted {
3829            user_id: contact_user_id,
3830            ..
3831        } = contact
3832        {
3833            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3834                session
3835                    .peer
3836                    .send(
3837                        contact_conn_id,
3838                        proto::UpdateContacts {
3839                            contacts: vec![updated_contact.clone()],
3840                            remove_contacts: Default::default(),
3841                            incoming_requests: Default::default(),
3842                            remove_incoming_requests: Default::default(),
3843                            outgoing_requests: Default::default(),
3844                            remove_outgoing_requests: Default::default(),
3845                        },
3846                    )
3847                    .trace_err();
3848            }
3849        }
3850    }
3851    Ok(())
3852}
3853
/// Remove `connection_id` from the room it is in and notify everyone affected.
///
/// Returns `Ok(())` without doing anything when `leave_room` reports that the
/// connection was not in a room. Otherwise, in this order, it:
/// - notifies the connections of every project the user left (via `project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - if the room belonged to a channel, broadcasts the channel's new participants,
/// - sends `CallCanceled` to every user in `canceled_calls_to_user_ids`,
/// - refreshes the contact entries of the leaving user and of those callees,
/// - removes the user from the LiveKit room, and deletes that LiveKit room when
///   `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. LiveKit and
/// message-send failures are only logged (`trace_err`).
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front so they can be initialized inside the `if let` below
    // and used after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // Note: temporaries in this `if let` scrutinee (including whatever
    // `session.db().await` returns) live until the end of the then-branch, so a
    // rewrite as `let ... else` would shorten that lifetime.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): this take runs before `room` is moved out, so the room
        // broadcast by `room_updated` below has `livekit_room` reset to its
        // default value. Confirm that is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts` below, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3927
3928async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3929    let left_channel_buffers = session
3930        .db()
3931        .await
3932        .leave_channel_buffers(session.connection_id)
3933        .await?;
3934
3935    for left_buffer in left_channel_buffers {
3936        channel_buffer_updated(
3937            session.connection_id,
3938            left_buffer.connections,
3939            &proto::UpdateChannelBufferCollaborators {
3940                channel_id: left_buffer.channel_id.to_proto(),
3941                collaborators: left_buffer.collaborators,
3942            },
3943            &session.peer,
3944        );
3945    }
3946
3947    Ok(())
3948}
3949
3950fn project_left(project: &db::LeftProject, session: &Session) {
3951    for connection_id in &project.connection_ids {
3952        if project.should_unshare {
3953            session
3954                .peer
3955                .send(
3956                    *connection_id,
3957                    proto::UnshareProject {
3958                        project_id: project.id.to_proto(),
3959                    },
3960                )
3961                .trace_err();
3962        } else {
3963            session
3964                .peer
3965                .send(
3966                    *connection_id,
3967                    proto::RemoveProjectCollaborator {
3968                        project_id: project.id.to_proto(),
3969                        peer_id: Some(session.connection_id.into()),
3970                    },
3971                )
3972                .trace_err();
3973        }
3974    }
3975}
3976
3977async fn share_agent_thread(
3978    request: proto::ShareAgentThread,
3979    response: Response<proto::ShareAgentThread>,
3980    session: MessageContext,
3981) -> Result<()> {
3982    let user_id = session.user_id();
3983
3984    let share_id = SharedThreadId::from_proto(request.session_id.clone())
3985        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
3986
3987    session
3988        .db()
3989        .await
3990        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
3991        .await?;
3992
3993    response.send(proto::Ack {})?;
3994
3995    Ok(())
3996}
3997
3998async fn get_shared_agent_thread(
3999    request: proto::GetSharedAgentThread,
4000    response: Response<proto::GetSharedAgentThread>,
4001    session: MessageContext,
4002) -> Result<()> {
4003    let share_id = SharedThreadId::from_proto(request.session_id)
4004        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4005
4006    let result = session.db().await.get_shared_thread(share_id).await?;
4007
4008    match result {
4009        Some((thread, username)) => {
4010            response.send(proto::GetSharedAgentThreadResponse {
4011                title: thread.title,
4012                thread_data: thread.data,
4013                sharer_username: username,
4014                created_at: thread.created_at.and_utc().to_rfc3339(),
4015            })?;
4016        }
4017        None => {
4018            return Err(anyhow!("Shared thread not found").into());
4019        }
4020    }
4021
4022    Ok(())
4023}
4024
/// Extension trait for results whose errors should be logged rather than
/// propagated.
pub trait ResultExt {
    /// The success value carried by the extended type.
    type Ok;

    /// Convert `self` into an `Option`, discarding any error. The `Result`
    /// implementation in this module also logs the error via `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4030
4031impl<T, E> ResultExt for Result<T, E>
4032where
4033    E: std::fmt::Debug,
4034{
4035    type Ok = T;
4036
4037    #[track_caller]
4038    fn trace_err(self) -> Option<T> {
4039        match self {
4040            Ok(value) => Some(value),
4041            Err(error) => {
4042                tracing::error!("{:?}", error);
4043                None
4044            }
4045        }
4046    }
4047}