rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
  75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  76
  77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  79
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  82
  83static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  84
  85const TOTAL_DURATION_MS: &str = "total_duration_ms";
  86const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  87const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  88const HOST_WAITING_MS: &str = "host_waiting_ms";
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
 254pub struct Server {
 255    id: parking_lot::Mutex<ServerId>,
 256    peer: Arc<Peer>,
 257    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 258    app_state: Arc<AppState>,
 259    handlers: HashMap<TypeId, MessageHandler>,
 260    teardown: watch::Sender<bool>,
 261}
 262
 263struct ConnectionPoolGuard<'a> {
 264    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 265    _not_send: PhantomData<Rc<()>>,
 266}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 416            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 417            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 419            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 437            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 438            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 439            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 440            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 441            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 442            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 443            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 444            .add_request_handler(share_agent_thread)
 445            .add_request_handler(get_shared_agent_thread)
 446            .add_request_handler(forward_project_search_chunk);
 447
 448        Arc::new(server)
 449    }
 450
 451    pub async fn start(&self) -> Result<()> {
 452        let server_id = *self.id.lock();
 453        let app_state = self.app_state.clone();
 454        let peer = self.peer.clone();
 455        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 456        let pool = self.connection_pool.clone();
 457        let livekit_client = self.app_state.livekit_client.clone();
 458
 459        let span = info_span!("start server");
 460        self.app_state.executor.spawn_detached(
 461            async move {
 462                tracing::info!("waiting for cleanup timeout");
 463                timeout.await;
 464                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 465
 466                app_state
 467                    .db
 468                    .delete_stale_channel_chat_participants(
 469                        &app_state.config.zed_environment,
 470                        server_id,
 471                    )
 472                    .await
 473                    .trace_err();
 474
 475                if let Some((room_ids, channel_ids)) = app_state
 476                    .db
 477                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 478                    .await
 479                    .trace_err()
 480                {
 481                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 482                    tracing::info!(
 483                        stale_channel_buffer_count = channel_ids.len(),
 484                        "retrieved stale channel buffers"
 485                    );
 486
 487                    for channel_id in channel_ids {
 488                        if let Some(refreshed_channel_buffer) = app_state
 489                            .db
 490                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 491                            .await
 492                            .trace_err()
 493                        {
 494                            for connection_id in refreshed_channel_buffer.connection_ids {
 495                                peer.send(
 496                                    connection_id,
 497                                    proto::UpdateChannelBufferCollaborators {
 498                                        channel_id: channel_id.to_proto(),
 499                                        collaborators: refreshed_channel_buffer
 500                                            .collaborators
 501                                            .clone(),
 502                                    },
 503                                )
 504                                .trace_err();
 505                            }
 506                        }
 507                    }
 508
 509                    for room_id in room_ids {
 510                        let mut contacts_to_update = HashSet::default();
 511                        let mut canceled_calls_to_user_ids = Vec::new();
 512                        let mut livekit_room = String::new();
 513                        let mut delete_livekit_room = false;
 514
 515                        if let Some(mut refreshed_room) = app_state
 516                            .db
 517                            .clear_stale_room_participants(room_id, server_id)
 518                            .await
 519                            .trace_err()
 520                        {
 521                            tracing::info!(
 522                                room_id = room_id.0,
 523                                new_participant_count = refreshed_room.room.participants.len(),
 524                                "refreshed room"
 525                            );
 526                            room_updated(&refreshed_room.room, &peer);
 527                            if let Some(channel) = refreshed_room.channel.as_ref() {
 528                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 529                            }
 530                            contacts_to_update
 531                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 532                            contacts_to_update
 533                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 534                            canceled_calls_to_user_ids =
 535                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 536                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 537                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 538                        }
 539
 540                        {
 541                            let pool = pool.lock();
 542                            for canceled_user_id in canceled_calls_to_user_ids {
 543                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 544                                    peer.send(
 545                                        connection_id,
 546                                        proto::CallCanceled {
 547                                            room_id: room_id.to_proto(),
 548                                        },
 549                                    )
 550                                    .trace_err();
 551                                }
 552                            }
 553                        }
 554
 555                        for user_id in contacts_to_update {
 556                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 557                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 558                            if let Some((busy, contacts)) = busy.zip(contacts) {
 559                                let pool = pool.lock();
 560                                let updated_contact = contact_for_user(user_id, busy, &pool);
 561                                for contact in contacts {
 562                                    if let db::Contact::Accepted {
 563                                        user_id: contact_user_id,
 564                                        ..
 565                                    } = contact
 566                                    {
 567                                        for contact_conn_id in
 568                                            pool.user_connection_ids(contact_user_id)
 569                                        {
 570                                            peer.send(
 571                                                contact_conn_id,
 572                                                proto::UpdateContacts {
 573                                                    contacts: vec![updated_contact.clone()],
 574                                                    remove_contacts: Default::default(),
 575                                                    incoming_requests: Default::default(),
 576                                                    remove_incoming_requests: Default::default(),
 577                                                    outgoing_requests: Default::default(),
 578                                                    remove_outgoing_requests: Default::default(),
 579                                                },
 580                                            )
 581                                            .trace_err();
 582                                        }
 583                                    }
 584                                }
 585                            }
 586                        }
 587
 588                        if let Some(live_kit) = livekit_client.as_ref()
 589                            && delete_livekit_room
 590                        {
 591                            live_kit.delete_room(livekit_room).await.trace_err();
 592                        }
 593                    }
 594                }
 595
 596                app_state
 597                    .db
 598                    .delete_stale_channel_chat_participants(
 599                        &app_state.config.zed_environment,
 600                        server_id,
 601                    )
 602                    .await
 603                    .trace_err();
 604
 605                app_state
 606                    .db
 607                    .clear_old_worktree_entries(server_id)
 608                    .await
 609                    .trace_err();
 610
 611                app_state
 612                    .db
 613                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 614                    .await
 615                    .trace_err();
 616            }
 617            .instrument(span),
 618        );
 619        Ok(())
 620    }
 621
 622    pub fn teardown(&self) {
 623        self.peer.teardown();
 624        self.connection_pool.lock().reset();
 625        let _ = self.teardown.send(true);
 626    }
 627
 628    #[cfg(feature = "test-support")]
 629    pub fn reset(&self, id: ServerId) {
 630        self.teardown();
 631        *self.id.lock() = id;
 632        self.peer.reset(id.0 as u32);
 633        let _ = self.teardown.send(false);
 634    }
 635
 636    #[cfg(feature = "test-support")]
 637    pub fn id(&self) -> ServerId {
 638        *self.id.lock()
 639    }
 640
 641    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 642    where
 643        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 644        Fut: 'static + Send + Future<Output = Result<()>>,
 645        M: EnvelopedMessage,
 646    {
 647        let prev_handler = self.handlers.insert(
 648            TypeId::of::<M>(),
 649            Box::new(move |envelope, session, span| {
 650                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 651                let received_at = envelope.received_at;
 652                tracing::info!("message received");
 653                let start_time = Instant::now();
 654                let future = (handler)(
 655                    *envelope,
 656                    MessageContext {
 657                        session,
 658                        span: span.clone(),
 659                    },
 660                );
 661                async move {
 662                    let result = future.await;
 663                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 664                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 665                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 666                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 667                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 668                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 669                    match result {
 670                        Err(error) => {
 671                            tracing::error!(?error, "error handling message")
 672                        }
 673                        Ok(()) => tracing::info!("finished handling message"),
 674                    }
 675                }
 676                .boxed()
 677            }),
 678        );
 679        if prev_handler.is_some() {
 680            panic!("registered a handler for the same message twice");
 681        }
 682        self
 683    }
 684
 685    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 686    where
 687        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 688        Fut: 'static + Send + Future<Output = Result<()>>,
 689        M: EnvelopedMessage,
 690    {
 691        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 692        self
 693    }
 694
 695    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 696    where
 697        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 698        Fut: Send + Future<Output = Result<()>>,
 699        M: RequestMessage,
 700    {
 701        let handler = Arc::new(handler);
 702        self.add_handler(move |envelope, session| {
 703            let receipt = envelope.receipt();
 704            let handler = handler.clone();
 705            async move {
 706                let peer = session.peer.clone();
 707                let responded = Arc::new(AtomicBool::default());
 708                let response = Response {
 709                    peer: peer.clone(),
 710                    responded: responded.clone(),
 711                    receipt,
 712                };
 713                match (handler)(envelope.payload, response, session).await {
 714                    Ok(()) => {
 715                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 716                            Ok(())
 717                        } else {
 718                            Err(anyhow!("handler did not send a response"))?
 719                        }
 720                    }
 721                    Err(error) => {
 722                        let proto_err = match &error {
 723                            Error::Internal(err) => err.to_proto(),
 724                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 725                        };
 726                        peer.respond_with_error(receipt, proto_err)?;
 727                        Err(error)
 728                    }
 729                }
 730            }
 731        })
 732    }
 733
 734    pub fn handle_connection(
 735        self: &Arc<Self>,
 736        connection: Connection,
 737        address: String,
 738        principal: Principal,
 739        zed_version: ZedVersion,
 740        release_channel: Option<String>,
 741        user_agent: Option<String>,
 742        geoip_country_code: Option<String>,
 743        system_id: Option<String>,
 744        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 745        executor: Executor,
 746        connection_guard: Option<ConnectionGuard>,
 747    ) -> impl Future<Output = ()> + use<> {
 748        let this = self.clone();
 749        let span = info_span!("handle connection", %address,
 750            connection_id=field::Empty,
 751            user_id=field::Empty,
 752            login=field::Empty,
 753            user_agent=field::Empty,
 754            geoip_country_code=field::Empty,
 755            release_channel=field::Empty,
 756        );
 757        principal.update_span(&span);
 758        if let Some(user_agent) = user_agent {
 759            span.record("user_agent", user_agent);
 760        }
 761        if let Some(release_channel) = release_channel {
 762            span.record("release_channel", release_channel);
 763        }
 764
 765        if let Some(country_code) = geoip_country_code.as_ref() {
 766            span.record("geoip_country_code", country_code);
 767        }
 768
 769        let mut teardown = self.teardown.subscribe();
 770        async move {
 771            if *teardown.borrow() {
 772                tracing::error!("server is tearing down");
 773                return;
 774            }
 775
 776            let (connection_id, handle_io, mut incoming_rx) =
 777                this.peer.add_connection(connection, {
 778                    let executor = executor.clone();
 779                    move |duration| executor.sleep(duration)
 780                });
 781            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 782
 783            tracing::info!("connection opened");
 784
 785            let session = Session {
 786                principal: principal.clone(),
 787                connection_id,
 788                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 789                peer: this.peer.clone(),
 790                connection_pool: this.connection_pool.clone(),
 791                app_state: this.app_state.clone(),
 792                geoip_country_code,
 793                system_id,
 794                _executor: executor.clone(),
 795            };
 796
 797            if let Err(error) = this
 798                .send_initial_client_update(
 799                    connection_id,
 800                    zed_version,
 801                    send_connection_id,
 802                    &session,
 803                )
 804                .await
 805            {
 806                tracing::error!(?error, "failed to send initial client update");
 807                return;
 808            }
 809            drop(connection_guard);
 810
 811            let handle_io = handle_io.fuse();
 812            futures::pin_mut!(handle_io);
 813
 814            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 815            // This prevents deadlocks when e.g., client A performs a request to client B and
 816            // client B performs a request to client A. If both clients stop processing further
 817            // messages until their respective request completes, they won't have a chance to
 818            // respond to the other client's request and cause a deadlock.
 819            //
 820            // This arrangement ensures we will attempt to process earlier messages first, but fall
 821            // back to processing messages arrived later in the spirit of making progress.
 822            const MAX_CONCURRENT_HANDLERS: usize = 256;
 823            let mut foreground_message_handlers = FuturesUnordered::new();
 824            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 825            let get_concurrent_handlers = {
 826                let concurrent_handlers = concurrent_handlers.clone();
 827                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 828            };
 829            loop {
 830                let next_message = async {
 831                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 832                    let message = incoming_rx.next().await;
 833                    // Cache the concurrent_handlers here, so that we know what the
 834                    // queue looks like as each handler starts
 835                    (permit, message, get_concurrent_handlers())
 836                }
 837                .fuse();
 838                futures::pin_mut!(next_message);
 839                futures::select_biased! {
 840                    _ = teardown.changed().fuse() => return,
 841                    result = handle_io => {
 842                        if let Err(error) = result {
 843                            tracing::error!(?error, "error handling I/O");
 844                        }
 845                        break;
 846                    }
 847                    _ = foreground_message_handlers.next() => {}
 848                    next_message = next_message => {
 849                        let (permit, message, concurrent_handlers) = next_message;
 850                        if let Some(message) = message {
 851                            let type_name = message.payload_type_name();
 852                            // note: we copy all the fields from the parent span so we can query them in the logs.
 853                            // (https://github.com/tokio-rs/tracing/issues/2670).
 854                            let span = tracing::info_span!("receive message",
 855                                %connection_id,
 856                                %address,
 857                                type_name,
 858                                concurrent_handlers,
 859                                user_id=field::Empty,
 860                                login=field::Empty,
 861                                lsp_query_request=field::Empty,
 862                                release_channel=field::Empty,
 863                                { TOTAL_DURATION_MS }=field::Empty,
 864                                { PROCESSING_DURATION_MS }=field::Empty,
 865                                { QUEUE_DURATION_MS }=field::Empty,
 866                                { HOST_WAITING_MS }=field::Empty
 867                            );
 868                            principal.update_span(&span);
 869                            let span_enter = span.enter();
 870                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 871                                let is_background = message.is_background();
 872                                let handle_message = (handler)(message, session.clone(), span.clone());
 873                                drop(span_enter);
 874
 875                                let handle_message = async move {
 876                                    handle_message.await;
 877                                    drop(permit);
 878                                }.instrument(span);
 879                                if is_background {
 880                                    executor.spawn_detached(handle_message);
 881                                } else {
 882                                    foreground_message_handlers.push(handle_message);
 883                                }
 884                            } else {
 885                                tracing::error!("no message handler");
 886                            }
 887                        } else {
 888                            tracing::info!("connection closed");
 889                            break;
 890                        }
 891                    }
 892                }
 893            }
 894
 895            drop(foreground_message_handlers);
 896            let concurrent_handlers = get_concurrent_handlers();
 897            tracing::info!(concurrent_handlers, "signing out");
 898            if let Err(error) = connection_lost(session, teardown, executor).await {
 899                tracing::error!(?error, "error signing out");
 900            }
 901        }
 902        .instrument(span)
 903    }
 904
 905    async fn send_initial_client_update(
 906        &self,
 907        connection_id: ConnectionId,
 908        zed_version: ZedVersion,
 909        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 910        session: &Session,
 911    ) -> Result<()> {
 912        self.peer.send(
 913            connection_id,
 914            proto::Hello {
 915                peer_id: Some(connection_id.into()),
 916            },
 917        )?;
 918        tracing::info!("sent hello message");
 919        if let Some(send_connection_id) = send_connection_id.take() {
 920            let _ = send_connection_id.send(connection_id);
 921        }
 922
 923        match &session.principal {
 924            Principal::User(user) => {
 925                if !user.connected_once {
 926                    self.peer.send(connection_id, proto::ShowContacts {})?;
 927                    self.app_state
 928                        .db
 929                        .set_user_connected_once(user.id, true)
 930                        .await?;
 931                }
 932
 933                let contacts = self.app_state.db.get_contacts(user.id).await?;
 934
 935                {
 936                    let mut pool = self.connection_pool.lock();
 937                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 938                    self.peer.send(
 939                        connection_id,
 940                        build_initial_contacts_update(contacts, &pool),
 941                    )?;
 942                }
 943
 944                if should_auto_subscribe_to_channels(&zed_version) {
 945                    subscribe_user_to_channels(user.id, session).await?;
 946                }
 947
 948                if let Some(incoming_call) =
 949                    self.app_state.db.incoming_call_for_user(user.id).await?
 950                {
 951                    self.peer.send(connection_id, incoming_call)?;
 952                }
 953
 954                update_user_contacts(user.id, session).await?;
 955            }
 956        }
 957
 958        Ok(())
 959    }
 960}
 961
 962impl Deref for ConnectionPoolGuard<'_> {
 963    type Target = ConnectionPool;
 964
 965    fn deref(&self) -> &Self::Target {
 966        &self.guard
 967    }
 968}
 969
 970impl DerefMut for ConnectionPoolGuard<'_> {
 971    fn deref_mut(&mut self) -> &mut Self::Target {
 972        &mut self.guard
 973    }
 974}
 975
 976impl Drop for ConnectionPoolGuard<'_> {
 977    fn drop(&mut self) {
 978        #[cfg(feature = "test-support")]
 979        self.check_invariants();
 980    }
 981}
 982
 983fn broadcast<F>(
 984    sender_id: Option<ConnectionId>,
 985    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 986    mut f: F,
 987) where
 988    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 989{
 990    for receiver_id in receiver_ids {
 991        if Some(receiver_id) != sender_id
 992            && let Err(error) = f(receiver_id)
 993        {
 994            tracing::error!("failed to send to {:?} {}", receiver_id, error);
 995        }
 996    }
 997}
 998
 999pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002    fn name() -> &'static HeaderName {
1003        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005    }
1006
1007    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008    where
1009        Self: Sized,
1010        I: Iterator<Item = &'i axum::http::HeaderValue>,
1011    {
1012        let version = values
1013            .next()
1014            .ok_or_else(axum::headers::Error::invalid)?
1015            .to_str()
1016            .map_err(|_| axum::headers::Error::invalid())?
1017            .parse()
1018            .map_err(|_| axum::headers::Error::invalid())?;
1019        Ok(Self(version))
1020    }
1021
1022    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023        values.extend([self.0.to_string().parse().unwrap()]);
1024    }
1025}
1026
1027pub struct AppVersionHeader(Version);
1028impl Header for AppVersionHeader {
1029    fn name() -> &'static HeaderName {
1030        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032    }
1033
1034    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035    where
1036        Self: Sized,
1037        I: Iterator<Item = &'i axum::http::HeaderValue>,
1038    {
1039        let version = values
1040            .next()
1041            .ok_or_else(axum::headers::Error::invalid)?
1042            .to_str()
1043            .map_err(|_| axum::headers::Error::invalid())?
1044            .parse()
1045            .map_err(|_| axum::headers::Error::invalid())?;
1046        Ok(Self(version))
1047    }
1048
1049    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050        values.extend([self.0.to_string().parse().unwrap()]);
1051    }
1052}
1053
1054#[derive(Debug)]
1055pub struct ReleaseChannelHeader(String);
1056
1057impl Header for ReleaseChannelHeader {
1058    fn name() -> &'static HeaderName {
1059        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1060        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1061    }
1062
1063    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064    where
1065        Self: Sized,
1066        I: Iterator<Item = &'i axum::http::HeaderValue>,
1067    {
1068        Ok(Self(
1069            values
1070                .next()
1071                .ok_or_else(axum::headers::Error::invalid)?
1072                .to_str()
1073                .map_err(|_| axum::headers::Error::invalid())?
1074                .to_owned(),
1075        ))
1076    }
1077
1078    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079        values.extend([self.0.parse().unwrap()]);
1080    }
1081}
1082
1083pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1084    Router::new()
1085        .route("/rpc", get(handle_websocket_request))
1086        .layer(
1087            ServiceBuilder::new()
1088                .layer(Extension(server.app_state.clone()))
1089                .layer(middleware::from_fn(auth::validate_header)),
1090        )
1091        .route("/metrics", get(handle_metrics))
1092        .layer(Extension(server))
1093}
1094
1095pub async fn handle_websocket_request(
1096    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1097    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1098    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1099    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1100    Extension(server): Extension<Arc<Server>>,
1101    Extension(principal): Extension<Principal>,
1102    user_agent: Option<TypedHeader<UserAgent>>,
1103    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1104    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1105    ws: WebSocketUpgrade,
1106) -> axum::response::Response {
1107    if protocol_version != rpc::PROTOCOL_VERSION {
1108        return (
1109            StatusCode::UPGRADE_REQUIRED,
1110            "client must be upgraded".to_string(),
1111        )
1112            .into_response();
1113    }
1114
1115    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1116        return (
1117            StatusCode::UPGRADE_REQUIRED,
1118            "no version header found".to_string(),
1119        )
1120            .into_response();
1121    };
1122
1123    let release_channel = release_channel_header.map(|header| header.0.0);
1124
1125    if !version.can_collaborate() {
1126        return (
1127            StatusCode::UPGRADE_REQUIRED,
1128            "client must be upgraded".to_string(),
1129        )
1130            .into_response();
1131    }
1132
1133    let socket_address = socket_address.to_string();
1134
1135    // Acquire connection guard before WebSocket upgrade
1136    let connection_guard = match ConnectionGuard::try_acquire() {
1137        Ok(guard) => guard,
1138        Err(()) => {
1139            return (
1140                StatusCode::SERVICE_UNAVAILABLE,
1141                "Too many concurrent connections",
1142            )
1143                .into_response();
1144        }
1145    };
1146
1147    ws.on_upgrade(move |socket| {
1148        let socket = socket
1149            .map_ok(to_tungstenite_message)
1150            .err_into()
1151            .with(|message| async move { to_axum_message(message) });
1152        let connection = Connection::new(Box::pin(socket));
1153        async move {
1154            server
1155                .handle_connection(
1156                    connection,
1157                    socket_address,
1158                    principal,
1159                    version,
1160                    release_channel,
1161                    user_agent.map(|header| header.to_string()),
1162                    country_code_header.map(|header| header.to_string()),
1163                    system_id_header.map(|header| header.to_string()),
1164                    None,
1165                    Executor::Production,
1166                    Some(connection_guard),
1167                )
1168                .await;
1169        }
1170    })
1171}
1172
1173pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1174    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175    let connections_metric = CONNECTIONS_METRIC
1176        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1177
1178    let connections = server
1179        .connection_pool
1180        .lock()
1181        .connections()
1182        .filter(|connection| !connection.admin)
1183        .count();
1184    connections_metric.set(connections as _);
1185
1186    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1187    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1188        register_int_gauge!(
1189            "shared_projects",
1190            "number of open projects with one or more guests"
1191        )
1192        .unwrap()
1193    });
1194
1195    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1196    shared_projects_metric.set(shared_projects as _);
1197
1198    let encoder = prometheus::TextEncoder::new();
1199    let metric_families = prometheus::gather();
1200    let encoded_metrics = encoder
1201        .encode_to_string(&metric_families)
1202        .map_err(|err| anyhow!("{err}"))?;
1203    Ok(encoded_metrics)
1204}
1205
/// Cleans up after a connection's message loop has ended.
///
/// Immediate steps:
/// - disconnect the connection from the peer;
/// - remove it from the connection pool;
/// - tell the database it was lost (failures are only logged).
///
/// It then races a `RECONNECT_TIMEOUT` sleep against the server's teardown signal.
/// NOTE(review): the timeout presumably gives the client a window to reconnect;
/// confirm. If the sleep wins, the session:
/// - leaves its room and channel buffers;
/// - declines any pending call when the user has no other online connection;
/// - refreshes the user's contacts.
///
/// If teardown fires first, those steps are skipped.
///
/// # Errors
/// Returns an error if removing the connection from the pool fails, or if updating
/// contacts fails. Room, buffer, and call cleanup failures are only logged.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: the timer branch is polled first, then teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server is shutting down: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1252
1253/// Acknowledges a ping from a client, used to keep the connection alive.
1254async fn ping(
1255    _: proto::Ping,
1256    response: Response<proto::Ping>,
1257    _session: MessageContext,
1258) -> Result<()> {
1259    response.send(proto::Ack {})?;
1260    Ok(())
1261}
1262
1263/// Creates a new room for calling (outside of channels)
1264async fn create_room(
1265    _request: proto::CreateRoom,
1266    response: Response<proto::CreateRoom>,
1267    session: MessageContext,
1268) -> Result<()> {
1269    let livekit_room = nanoid::nanoid!(30);
1270
1271    let live_kit_connection_info = util::maybe!(async {
1272        let live_kit = session.app_state.livekit_client.as_ref();
1273        let live_kit = live_kit?;
1274        let user_id = session.user_id().to_string();
1275
1276        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1277
1278        Some(proto::LiveKitConnectionInfo {
1279            server_url: live_kit.url().into(),
1280            token,
1281            can_publish: true,
1282        })
1283    })
1284    .await;
1285
1286    let room = session
1287        .db()
1288        .await
1289        .create_room(session.user_id(), session.connection_id, &livekit_room)
1290        .await?;
1291
1292    response.send(proto::CreateRoomResponse {
1293        room: Some(room.clone()),
1294        live_kit_connection_info,
1295    })?;
1296
1297    update_user_contacts(session.user_id(), &session).await?;
1298    Ok(())
1299}
1300
1301/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1302async fn join_room(
1303    request: proto::JoinRoom,
1304    response: Response<proto::JoinRoom>,
1305    session: MessageContext,
1306) -> Result<()> {
1307    let room_id = RoomId::from_proto(request.id);
1308
1309    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1310
1311    if let Some(channel_id) = channel_id {
1312        return join_channel_internal(channel_id, Box::new(response), session).await;
1313    }
1314
1315    let joined_room = {
1316        let room = session
1317            .db()
1318            .await
1319            .join_room(room_id, session.user_id(), session.connection_id)
1320            .await?;
1321        room_updated(&room.room, &session.peer);
1322        room.into_inner()
1323    };
1324
1325    for connection_id in session
1326        .connection_pool()
1327        .await
1328        .user_connection_ids(session.user_id())
1329    {
1330        session
1331            .peer
1332            .send(
1333                connection_id,
1334                proto::CallCanceled {
1335                    room_id: room_id.to_proto(),
1336                },
1337            )
1338            .trace_err();
1339    }
1340
1341    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1342    {
1343        live_kit
1344            .room_token(
1345                &joined_room.room.livekit_room,
1346                &session.user_id().to_string(),
1347            )
1348            .trace_err()
1349            .map(|token| proto::LiveKitConnectionInfo {
1350                server_url: live_kit.url().into(),
1351                token,
1352                can_publish: true,
1353            })
1354    } else {
1355        None
1356    };
1357
1358    response.send(proto::JoinRoomResponse {
1359        room: Some(joined_room.room),
1360        channel_id: None,
1361        live_kit_connection_info,
1362    })?;
1363
1364    update_user_contacts(session.user_id(), &session).await?;
1365    Ok(())
1366}
1367
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The function:
/// 1. replies to the caller with the rejoined room, its reshared projects, and its
///    rejoined projects, then broadcasts the room update;
/// 2. tells collaborators of every reshared and rejoined project that the caller's
///    peer id moved from the old connection to this one;
/// 3. forwards reshared projects' worktrees to their collaborators;
/// 4. streams rejoined projects' state back to the caller;
/// 5. refreshes the room's channel (if any) and the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Assigned at the end of the scope below. Only the room and channel escape it;
    // `rejoined_room` itself is consumed there via `into_inner()`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the caller before broadcasting the room update to others.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: point every collaborator at the caller's new peer id,
        // then forward the project's worktrees to all collaborators except the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: update collaborators and stream project state to the caller.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1459
1460fn notify_rejoined_projects(
1461    rejoined_projects: &mut Vec<RejoinedProject>,
1462    session: &Session,
1463) -> Result<()> {
1464    for project in rejoined_projects.iter() {
1465        for collaborator in &project.collaborators {
1466            session
1467                .peer
1468                .send(
1469                    collaborator.connection_id,
1470                    proto::UpdateProjectCollaborator {
1471                        project_id: project.id.to_proto(),
1472                        old_peer_id: Some(project.old_connection_id.into()),
1473                        new_peer_id: Some(session.connection_id.into()),
1474                    },
1475                )
1476                .trace_err();
1477        }
1478    }
1479
1480    for project in rejoined_projects {
1481        for worktree in mem::take(&mut project.worktrees) {
1482            // Stream this worktree's entries.
1483            let message = proto::UpdateWorktree {
1484                project_id: project.id.to_proto(),
1485                worktree_id: worktree.id,
1486                abs_path: worktree.abs_path.clone(),
1487                root_name: worktree.root_name,
1488                root_repo_common_dir: worktree.root_repo_common_dir,
1489                updated_entries: worktree.updated_entries,
1490                removed_entries: worktree.removed_entries,
1491                scan_id: worktree.scan_id,
1492                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1493                updated_repositories: worktree.updated_repositories,
1494                removed_repositories: worktree.removed_repositories,
1495            };
1496            for update in proto::split_worktree_update(message) {
1497                session.peer.send(session.connection_id, update)?;
1498            }
1499
1500            // Stream this worktree's diagnostics.
1501            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1502            if let Some(summary) = worktree_diagnostics.next() {
1503                let message = proto::UpdateDiagnosticSummary {
1504                    project_id: project.id.to_proto(),
1505                    worktree_id: worktree.id,
1506                    summary: Some(summary),
1507                    more_summaries: worktree_diagnostics.collect(),
1508                };
1509                session.peer.send(session.connection_id, message)?;
1510            }
1511
1512            for settings_file in worktree.settings_files {
1513                session.peer.send(
1514                    session.connection_id,
1515                    proto::UpdateWorktreeSettings {
1516                        project_id: project.id.to_proto(),
1517                        worktree_id: worktree.id,
1518                        path: settings_file.path,
1519                        content: Some(settings_file.content),
1520                        kind: Some(settings_file.kind.to_proto().into()),
1521                        outside_worktree: Some(settings_file.outside_worktree),
1522                    },
1523                )?;
1524            }
1525        }
1526
1527        for repository in mem::take(&mut project.updated_repositories) {
1528            for update in split_repository_update(repository) {
1529                session.peer.send(session.connection_id, update)?;
1530            }
1531        }
1532
1533        for id in mem::take(&mut project.removed_repositories) {
1534            session.peer.send(
1535                session.connection_id,
1536                proto::RemoveRepository {
1537                    project_id: project.id.to_proto(),
1538                    id,
1539                },
1540            )?;
1541        }
1542    }
1543
1544    Ok(())
1545}
1546
1547/// leave room disconnects from the room.
1548async fn leave_room(
1549    _: proto::LeaveRoom,
1550    response: Response<proto::LeaveRoom>,
1551    session: MessageContext,
1552) -> Result<()> {
1553    leave_room_for_session(&session, session.connection_id).await?;
1554    response.send(proto::Ack {})?;
1555    Ok(())
1556}
1557
1558/// Updates the permissions of someone else in the room.
1559async fn set_room_participant_role(
1560    request: proto::SetRoomParticipantRole,
1561    response: Response<proto::SetRoomParticipantRole>,
1562    session: MessageContext,
1563) -> Result<()> {
1564    let user_id = UserId::from_proto(request.user_id);
1565    let role = ChannelRole::from(request.role());
1566
1567    let (livekit_room, can_publish) = {
1568        let room = session
1569            .db()
1570            .await
1571            .set_room_participant_role(
1572                session.user_id(),
1573                RoomId::from_proto(request.room_id),
1574                user_id,
1575                role,
1576            )
1577            .await?;
1578
1579        let livekit_room = room.livekit_room.clone();
1580        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1581        room_updated(&room, &session.peer);
1582        (livekit_room, can_publish)
1583    };
1584
1585    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1586        live_kit
1587            .update_participant(
1588                livekit_room.clone(),
1589                request.user_id.to_string(),
1590                livekit_api::proto::ParticipantPermission {
1591                    can_subscribe: true,
1592                    can_publish,
1593                    can_publish_data: can_publish,
1594                    hidden: false,
1595                    recorder: false,
1596                },
1597            )
1598            .await
1599            .trace_err();
1600    }
1601
1602    response.send(proto::Ack {})?;
1603    Ok(())
1604}
1605
1606/// Call someone else into the current room
1607async fn call(
1608    request: proto::Call,
1609    response: Response<proto::Call>,
1610    session: MessageContext,
1611) -> Result<()> {
1612    let room_id = RoomId::from_proto(request.room_id);
1613    let calling_user_id = session.user_id();
1614    let calling_connection_id = session.connection_id;
1615    let called_user_id = UserId::from_proto(request.called_user_id);
1616    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1617    if !session
1618        .db()
1619        .await
1620        .has_contact(calling_user_id, called_user_id)
1621        .await?
1622    {
1623        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1624    }
1625
1626    let incoming_call = {
1627        let (room, incoming_call) = &mut *session
1628            .db()
1629            .await
1630            .call(
1631                room_id,
1632                calling_user_id,
1633                calling_connection_id,
1634                called_user_id,
1635                initial_project_id,
1636            )
1637            .await?;
1638        room_updated(room, &session.peer);
1639        mem::take(incoming_call)
1640    };
1641    update_user_contacts(called_user_id, &session).await?;
1642
1643    let mut calls = session
1644        .connection_pool()
1645        .await
1646        .user_connection_ids(called_user_id)
1647        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1648        .collect::<FuturesUnordered<_>>();
1649
1650    while let Some(call_response) = calls.next().await {
1651        match call_response.as_ref() {
1652            Ok(_) => {
1653                response.send(proto::Ack {})?;
1654                return Ok(());
1655            }
1656            Err(_) => {
1657                call_response.trace_err();
1658            }
1659        }
1660    }
1661
1662    {
1663        let room = session
1664            .db()
1665            .await
1666            .call_failed(room_id, called_user_id)
1667            .await?;
1668        room_updated(&room, &session.peer);
1669    }
1670    update_user_contacts(called_user_id, &session).await?;
1671
1672    Err(anyhow!("failed to ring user"))?
1673}
1674
1675/// Cancel an outgoing call.
1676async fn cancel_call(
1677    request: proto::CancelCall,
1678    response: Response<proto::CancelCall>,
1679    session: MessageContext,
1680) -> Result<()> {
1681    let called_user_id = UserId::from_proto(request.called_user_id);
1682    let room_id = RoomId::from_proto(request.room_id);
1683    {
1684        let room = session
1685            .db()
1686            .await
1687            .cancel_call(room_id, session.connection_id, called_user_id)
1688            .await?;
1689        room_updated(&room, &session.peer);
1690    }
1691
1692    for connection_id in session
1693        .connection_pool()
1694        .await
1695        .user_connection_ids(called_user_id)
1696    {
1697        session
1698            .peer
1699            .send(
1700                connection_id,
1701                proto::CallCanceled {
1702                    room_id: room_id.to_proto(),
1703                },
1704            )
1705            .trace_err();
1706    }
1707    response.send(proto::Ack {})?;
1708
1709    update_user_contacts(called_user_id, &session).await?;
1710    Ok(())
1711}
1712
1713/// Decline an incoming call.
1714async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1715    let room_id = RoomId::from_proto(message.room_id);
1716    {
1717        let room = session
1718            .db()
1719            .await
1720            .decline_call(Some(room_id), session.user_id())
1721            .await?
1722            .context("declining call")?;
1723        room_updated(&room, &session.peer);
1724    }
1725
1726    for connection_id in session
1727        .connection_pool()
1728        .await
1729        .user_connection_ids(session.user_id())
1730    {
1731        session
1732            .peer
1733            .send(
1734                connection_id,
1735                proto::CallCanceled {
1736                    room_id: room_id.to_proto(),
1737                },
1738            )
1739            .trace_err();
1740    }
1741    update_user_contacts(session.user_id(), &session).await?;
1742    Ok(())
1743}
1744
1745/// Updates other participants in the room with your current location.
1746async fn update_participant_location(
1747    request: proto::UpdateParticipantLocation,
1748    response: Response<proto::UpdateParticipantLocation>,
1749    session: MessageContext,
1750) -> Result<()> {
1751    let room_id = RoomId::from_proto(request.room_id);
1752    let location = request.location.context("invalid location")?;
1753
1754    let db = session.db().await;
1755    let room = db
1756        .update_room_participant_location(room_id, session.connection_id, location)
1757        .await?;
1758
1759    room_updated(&room, &session.peer);
1760    response.send(proto::Ack {})?;
1761    Ok(())
1762}
1763
1764/// Share a project into the room.
1765async fn share_project(
1766    request: proto::ShareProject,
1767    response: Response<proto::ShareProject>,
1768    session: MessageContext,
1769) -> Result<()> {
1770    let (project_id, room) = &*session
1771        .db()
1772        .await
1773        .share_project(
1774            RoomId::from_proto(request.room_id),
1775            session.connection_id,
1776            &request.worktrees,
1777            request.is_ssh_project,
1778            request.windows_paths.unwrap_or(false),
1779            &request.features,
1780        )
1781        .await?;
1782    response.send(proto::ShareProjectResponse {
1783        project_id: project_id.to_proto(),
1784    })?;
1785    room_updated(room, &session.peer);
1786
1787    Ok(())
1788}
1789
1790/// Unshare a project from the room.
1791async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1792    let project_id = ProjectId::from_proto(message.project_id);
1793    unshare_project_internal(project_id, session.connection_id, &session).await
1794}
1795
1796async fn unshare_project_internal(
1797    project_id: ProjectId,
1798    connection_id: ConnectionId,
1799    session: &Session,
1800) -> Result<()> {
1801    let delete = {
1802        let room_guard = session
1803            .db()
1804            .await
1805            .unshare_project(project_id, connection_id)
1806            .await?;
1807
1808        let (delete, room, guest_connection_ids) = &*room_guard;
1809
1810        let message = proto::UnshareProject {
1811            project_id: project_id.to_proto(),
1812        };
1813
1814        broadcast(
1815            Some(connection_id),
1816            guest_connection_ids.iter().copied(),
1817            |conn_id| session.peer.send(conn_id, message.clone()),
1818        );
1819        if let Some(room) = room {
1820            room_updated(room, &session.peer);
1821        }
1822
1823        *delete
1824    };
1825
1826    if delete {
1827        let db = session.db().await;
1828        db.delete_project(project_id).await?;
1829    }
1830
1831    Ok(())
1832}
1833
1834/// Join someone elses shared project.
1835async fn join_project(
1836    request: proto::JoinProject,
1837    response: Response<proto::JoinProject>,
1838    session: MessageContext,
1839) -> Result<()> {
1840    let project_id = ProjectId::from_proto(request.project_id);
1841
1842    tracing::info!(%project_id, "join project");
1843
1844    let db = session.db().await;
1845    let project_model = db.get_project(project_id).await?;
1846    let host_features: Vec<String> =
1847        serde_json::from_str(&project_model.features).unwrap_or_default();
1848    let guest_features: HashSet<_> = request.features.iter().collect();
1849    let host_features_set: HashSet<_> = host_features.iter().collect();
1850    if guest_features != host_features_set {
1851        let host_connection_id = project_model.host_connection()?;
1852        let mut pool = session.connection_pool().await;
1853        let host_version = pool
1854            .connection(host_connection_id)
1855            .map(|c| c.zed_version.to_string());
1856        let guest_version = pool
1857            .connection(session.connection_id)
1858            .map(|c| c.zed_version.to_string());
1859        drop(pool);
1860        Err(anyhow!(
1861            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1862            host_version.as_deref().unwrap_or("unknown"),
1863            guest_version.as_deref().unwrap_or("unknown"),
1864        ))?;
1865    }
1866
1867    let (project, replica_id) = &mut *db
1868        .join_project(
1869            project_id,
1870            session.connection_id,
1871            session.user_id(),
1872            request.committer_name.clone(),
1873            request.committer_email.clone(),
1874        )
1875        .await?;
1876    drop(db);
1877
1878    tracing::info!(%project_id, "join remote project");
1879    let collaborators = project
1880        .collaborators
1881        .iter()
1882        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1883        .map(|collaborator| collaborator.to_proto())
1884        .collect::<Vec<_>>();
1885    let project_id = project.id;
1886    let guest_user_id = session.user_id();
1887
1888    let worktrees = project
1889        .worktrees
1890        .iter()
1891        .map(|(id, worktree)| proto::WorktreeMetadata {
1892            id: *id,
1893            root_name: worktree.root_name.clone(),
1894            visible: worktree.visible,
1895            abs_path: worktree.abs_path.clone(),
1896        })
1897        .collect::<Vec<_>>();
1898
1899    let add_project_collaborator = proto::AddProjectCollaborator {
1900        project_id: project_id.to_proto(),
1901        collaborator: Some(proto::Collaborator {
1902            peer_id: Some(session.connection_id.into()),
1903            replica_id: replica_id.0 as u32,
1904            user_id: guest_user_id.to_proto(),
1905            is_host: false,
1906            committer_name: request.committer_name.clone(),
1907            committer_email: request.committer_email.clone(),
1908        }),
1909    };
1910
1911    for collaborator in &collaborators {
1912        session
1913            .peer
1914            .send(
1915                collaborator.peer_id.unwrap().into(),
1916                add_project_collaborator.clone(),
1917            )
1918            .trace_err();
1919    }
1920
1921    // First, we send the metadata associated with each worktree.
1922    let (language_servers, language_server_capabilities) = project
1923        .language_servers
1924        .clone()
1925        .into_iter()
1926        .map(|server| (server.server, server.capabilities))
1927        .unzip();
1928    response.send(proto::JoinProjectResponse {
1929        project_id: project.id.0 as u64,
1930        worktrees,
1931        replica_id: replica_id.0 as u32,
1932        collaborators,
1933        language_servers,
1934        language_server_capabilities,
1935        role: project.role.into(),
1936        windows_paths: project.path_style == PathStyle::Windows,
1937        features: project.features.clone(),
1938    })?;
1939
1940    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1941        // Stream this worktree's entries.
1942        let message = proto::UpdateWorktree {
1943            project_id: project_id.to_proto(),
1944            worktree_id,
1945            abs_path: worktree.abs_path.clone(),
1946            root_name: worktree.root_name,
1947            root_repo_common_dir: worktree.root_repo_common_dir,
1948            updated_entries: worktree.entries,
1949            removed_entries: Default::default(),
1950            scan_id: worktree.scan_id,
1951            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1952            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1953            removed_repositories: Default::default(),
1954        };
1955        for update in proto::split_worktree_update(message) {
1956            session.peer.send(session.connection_id, update.clone())?;
1957        }
1958
1959        // Stream this worktree's diagnostics.
1960        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1961        if let Some(summary) = worktree_diagnostics.next() {
1962            let message = proto::UpdateDiagnosticSummary {
1963                project_id: project.id.to_proto(),
1964                worktree_id: worktree.id,
1965                summary: Some(summary),
1966                more_summaries: worktree_diagnostics.collect(),
1967            };
1968            session.peer.send(session.connection_id, message)?;
1969        }
1970
1971        for settings_file in worktree.settings_files {
1972            session.peer.send(
1973                session.connection_id,
1974                proto::UpdateWorktreeSettings {
1975                    project_id: project_id.to_proto(),
1976                    worktree_id: worktree.id,
1977                    path: settings_file.path,
1978                    content: Some(settings_file.content),
1979                    kind: Some(settings_file.kind.to_proto() as i32),
1980                    outside_worktree: Some(settings_file.outside_worktree),
1981                },
1982            )?;
1983        }
1984    }
1985
1986    for repository in mem::take(&mut project.repositories) {
1987        for update in split_repository_update(repository) {
1988            session.peer.send(session.connection_id, update)?;
1989        }
1990    }
1991
1992    for language_server in &project.language_servers {
1993        session.peer.send(
1994            session.connection_id,
1995            proto::UpdateLanguageServer {
1996                project_id: project_id.to_proto(),
1997                server_name: Some(language_server.server.name.clone()),
1998                language_server_id: language_server.server.id,
1999                variant: Some(
2000                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2001                        proto::LspDiskBasedDiagnosticsUpdated {},
2002                    ),
2003                ),
2004            },
2005        )?;
2006    }
2007
2008    Ok(())
2009}
2010
2011/// Leave someone elses shared project.
2012async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2013    let sender_id = session.connection_id;
2014    let project_id = ProjectId::from_proto(request.project_id);
2015    let db = session.db().await;
2016
2017    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2018    tracing::info!(
2019        %project_id,
2020        "leave project"
2021    );
2022
2023    project_left(project, &session);
2024    if let Some(room) = room {
2025        room_updated(room, &session.peer);
2026    }
2027
2028    Ok(())
2029}
2030
2031/// Updates other participants with changes to the project
2032async fn update_project(
2033    request: proto::UpdateProject,
2034    response: Response<proto::UpdateProject>,
2035    session: MessageContext,
2036) -> Result<()> {
2037    let project_id = ProjectId::from_proto(request.project_id);
2038    let (room, guest_connection_ids) = &*session
2039        .db()
2040        .await
2041        .update_project(project_id, session.connection_id, &request.worktrees)
2042        .await?;
2043    broadcast(
2044        Some(session.connection_id),
2045        guest_connection_ids.iter().copied(),
2046        |connection_id| {
2047            session
2048                .peer
2049                .forward_send(session.connection_id, connection_id, request.clone())
2050        },
2051    );
2052    if let Some(room) = room {
2053        room_updated(room, &session.peer);
2054    }
2055    response.send(proto::Ack {})?;
2056
2057    Ok(())
2058}
2059
2060/// Updates other participants with changes to the worktree
2061async fn update_worktree(
2062    request: proto::UpdateWorktree,
2063    response: Response<proto::UpdateWorktree>,
2064    session: MessageContext,
2065) -> Result<()> {
2066    let guest_connection_ids = session
2067        .db()
2068        .await
2069        .update_worktree(&request, session.connection_id)
2070        .await?;
2071
2072    broadcast(
2073        Some(session.connection_id),
2074        guest_connection_ids.iter().copied(),
2075        |connection_id| {
2076            session
2077                .peer
2078                .forward_send(session.connection_id, connection_id, request.clone())
2079        },
2080    );
2081    response.send(proto::Ack {})?;
2082    Ok(())
2083}
2084
2085async fn update_repository(
2086    request: proto::UpdateRepository,
2087    response: Response<proto::UpdateRepository>,
2088    session: MessageContext,
2089) -> Result<()> {
2090    let guest_connection_ids = session
2091        .db()
2092        .await
2093        .update_repository(&request, session.connection_id)
2094        .await?;
2095
2096    broadcast(
2097        Some(session.connection_id),
2098        guest_connection_ids.iter().copied(),
2099        |connection_id| {
2100            session
2101                .peer
2102                .forward_send(session.connection_id, connection_id, request.clone())
2103        },
2104    );
2105    response.send(proto::Ack {})?;
2106    Ok(())
2107}
2108
2109async fn remove_repository(
2110    request: proto::RemoveRepository,
2111    response: Response<proto::RemoveRepository>,
2112    session: MessageContext,
2113) -> Result<()> {
2114    let guest_connection_ids = session
2115        .db()
2116        .await
2117        .remove_repository(&request, session.connection_id)
2118        .await?;
2119
2120    broadcast(
2121        Some(session.connection_id),
2122        guest_connection_ids.iter().copied(),
2123        |connection_id| {
2124            session
2125                .peer
2126                .forward_send(session.connection_id, connection_id, request.clone())
2127        },
2128    );
2129    response.send(proto::Ack {})?;
2130    Ok(())
2131}
2132
2133/// Updates other participants with changes to the diagnostics
2134async fn update_diagnostic_summary(
2135    message: proto::UpdateDiagnosticSummary,
2136    session: MessageContext,
2137) -> Result<()> {
2138    let guest_connection_ids = session
2139        .db()
2140        .await
2141        .update_diagnostic_summary(&message, session.connection_id)
2142        .await?;
2143
2144    broadcast(
2145        Some(session.connection_id),
2146        guest_connection_ids.iter().copied(),
2147        |connection_id| {
2148            session
2149                .peer
2150                .forward_send(session.connection_id, connection_id, message.clone())
2151        },
2152    );
2153
2154    Ok(())
2155}
2156
2157/// Updates other participants with changes to the worktree settings
2158async fn update_worktree_settings(
2159    message: proto::UpdateWorktreeSettings,
2160    session: MessageContext,
2161) -> Result<()> {
2162    let guest_connection_ids = session
2163        .db()
2164        .await
2165        .update_worktree_settings(&message, session.connection_id)
2166        .await?;
2167
2168    broadcast(
2169        Some(session.connection_id),
2170        guest_connection_ids.iter().copied(),
2171        |connection_id| {
2172            session
2173                .peer
2174                .forward_send(session.connection_id, connection_id, message.clone())
2175        },
2176    );
2177
2178    Ok(())
2179}
2180
2181/// Notify other participants that a language server has started.
2182async fn start_language_server(
2183    request: proto::StartLanguageServer,
2184    session: MessageContext,
2185) -> Result<()> {
2186    let guest_connection_ids = session
2187        .db()
2188        .await
2189        .start_language_server(&request, session.connection_id)
2190        .await?;
2191
2192    broadcast(
2193        Some(session.connection_id),
2194        guest_connection_ids.iter().copied(),
2195        |connection_id| {
2196            session
2197                .peer
2198                .forward_send(session.connection_id, connection_id, request.clone())
2199        },
2200    );
2201    Ok(())
2202}
2203
2204/// Notify other participants that a language server has changed.
2205async fn update_language_server(
2206    request: proto::UpdateLanguageServer,
2207    session: MessageContext,
2208) -> Result<()> {
2209    let project_id = ProjectId::from_proto(request.project_id);
2210    let db = session.db().await;
2211
2212    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2213        && let Some(capabilities) = update.capabilities.clone()
2214    {
2215        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2216            .await?;
2217    }
2218
2219    let project_connection_ids = db
2220        .project_connection_ids(project_id, session.connection_id, true)
2221        .await?;
2222    broadcast(
2223        Some(session.connection_id),
2224        project_connection_ids.iter().copied(),
2225        |connection_id| {
2226            session
2227                .peer
2228                .forward_send(session.connection_id, connection_id, request.clone())
2229        },
2230    );
2231    Ok(())
2232}
2233
2234/// forward a project request to the host. These requests should be read only
2235/// as guests are allowed to send them.
2236async fn forward_read_only_project_request<T>(
2237    request: T,
2238    response: Response<T>,
2239    session: MessageContext,
2240) -> Result<()>
2241where
2242    T: EntityMessage + RequestMessage,
2243{
2244    let project_id = ProjectId::from_proto(request.remote_entity_id());
2245    let host_connection_id = session
2246        .db()
2247        .await
2248        .host_for_read_only_project_request(project_id, session.connection_id)
2249        .await?;
2250    let payload = session.forward_request(host_connection_id, request).await?;
2251    response.send(payload)?;
2252    Ok(())
2253}
2254
2255/// forward a project request to the host. These requests are disallowed
2256/// for guests.
2257async fn forward_mutating_project_request<T>(
2258    request: T,
2259    response: Response<T>,
2260    session: MessageContext,
2261) -> Result<()>
2262where
2263    T: EntityMessage + RequestMessage,
2264{
2265    let project_id = ProjectId::from_proto(request.remote_entity_id());
2266
2267    let host_connection_id = session
2268        .db()
2269        .await
2270        .host_for_mutating_project_request(project_id, session.connection_id)
2271        .await?;
2272    let payload = session.forward_request(host_connection_id, request).await?;
2273    response.send(payload)?;
2274    Ok(())
2275}
2276
2277async fn disallow_guest_request<T>(
2278    _request: T,
2279    response: Response<T>,
2280    _session: MessageContext,
2281) -> Result<()>
2282where
2283    T: RequestMessage,
2284{
2285    response.peer.respond_with_error(
2286        response.receipt,
2287        ErrorCode::Forbidden
2288            .message("request is not allowed for guests".to_string())
2289            .to_proto(),
2290    )?;
2291    response.responded.store(true, SeqCst);
2292    Ok(())
2293}
2294
2295async fn lsp_query(
2296    request: proto::LspQuery,
2297    response: Response<proto::LspQuery>,
2298    session: MessageContext,
2299) -> Result<()> {
2300    let (name, should_write) = request.query_name_and_write_permissions();
2301    tracing::Span::current().record("lsp_query_request", name);
2302    tracing::info!("lsp_query message received");
2303    if should_write {
2304        forward_mutating_project_request(request, response, session).await
2305    } else {
2306        forward_read_only_project_request(request, response, session).await
2307    }
2308}
2309
2310/// Notify other participants that a new buffer has been created
2311async fn create_buffer_for_peer(
2312    request: proto::CreateBufferForPeer,
2313    session: MessageContext,
2314) -> Result<()> {
2315    session
2316        .db()
2317        .await
2318        .check_user_is_project_host(
2319            ProjectId::from_proto(request.project_id),
2320            session.connection_id,
2321        )
2322        .await?;
2323    let peer_id = request.peer_id.context("invalid peer id")?;
2324    session
2325        .peer
2326        .forward_send(session.connection_id, peer_id.into(), request)?;
2327    Ok(())
2328}
2329
2330/// Notify other participants that a new image has been created
2331async fn create_image_for_peer(
2332    request: proto::CreateImageForPeer,
2333    session: MessageContext,
2334) -> Result<()> {
2335    session
2336        .db()
2337        .await
2338        .check_user_is_project_host(
2339            ProjectId::from_proto(request.project_id),
2340            session.connection_id,
2341        )
2342        .await?;
2343    let peer_id = request.peer_id.context("invalid peer id")?;
2344    session
2345        .peer
2346        .forward_send(session.connection_id, peer_id.into(), request)?;
2347    Ok(())
2348}
2349
2350/// Notify other participants that a buffer has been updated. This is
2351/// allowed for guests as long as the update is limited to selections.
2352async fn update_buffer(
2353    request: proto::UpdateBuffer,
2354    response: Response<proto::UpdateBuffer>,
2355    session: MessageContext,
2356) -> Result<()> {
2357    let project_id = ProjectId::from_proto(request.project_id);
2358    let mut capability = Capability::ReadOnly;
2359
2360    for op in request.operations.iter() {
2361        match op.variant {
2362            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2363            Some(_) => capability = Capability::ReadWrite,
2364        }
2365    }
2366
2367    let host = {
2368        let guard = session
2369            .db()
2370            .await
2371            .connections_for_buffer_update(project_id, session.connection_id, capability)
2372            .await?;
2373
2374        let (host, guests) = &*guard;
2375
2376        broadcast(
2377            Some(session.connection_id),
2378            guests.clone(),
2379            |connection_id| {
2380                session
2381                    .peer
2382                    .forward_send(session.connection_id, connection_id, request.clone())
2383            },
2384        );
2385
2386        *host
2387    };
2388
2389    if host != session.connection_id {
2390        session.forward_request(host, request.clone()).await?;
2391    }
2392
2393    response.send(proto::Ack {})?;
2394    Ok(())
2395}
2396
2397async fn forward_project_search_chunk(
2398    message: proto::FindSearchCandidatesChunk,
2399    response: Response<proto::FindSearchCandidatesChunk>,
2400    session: MessageContext,
2401) -> Result<()> {
2402    let peer_id = message.peer_id.context("missing peer_id")?;
2403    let payload = session
2404        .peer
2405        .forward_request(session.connection_id, peer_id.into(), message)
2406        .await?;
2407    response.send(payload)?;
2408    Ok(())
2409}
2410
2411/// Notify other participants that a project has been updated.
2412async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2413    request: T,
2414    session: MessageContext,
2415) -> Result<()> {
2416    let project_id = ProjectId::from_proto(request.remote_entity_id());
2417    let project_connection_ids = session
2418        .db()
2419        .await
2420        .project_connection_ids(project_id, session.connection_id, false)
2421        .await?;
2422
2423    broadcast(
2424        Some(session.connection_id),
2425        project_connection_ids.iter().copied(),
2426        |connection_id| {
2427            session
2428                .peer
2429                .forward_send(session.connection_id, connection_id, request.clone())
2430        },
2431    );
2432    Ok(())
2433}
2434
2435/// Start following another user in a call.
2436async fn follow(
2437    request: proto::Follow,
2438    response: Response<proto::Follow>,
2439    session: MessageContext,
2440) -> Result<()> {
2441    let room_id = RoomId::from_proto(request.room_id);
2442    let project_id = request.project_id.map(ProjectId::from_proto);
2443    let leader_id = request.leader_id.context("invalid leader id")?.into();
2444    let follower_id = session.connection_id;
2445
2446    session
2447        .db()
2448        .await
2449        .check_room_participants(room_id, leader_id, session.connection_id)
2450        .await?;
2451
2452    let response_payload = session.forward_request(leader_id, request).await?;
2453    response.send(response_payload)?;
2454
2455    if let Some(project_id) = project_id {
2456        let room = session
2457            .db()
2458            .await
2459            .follow(room_id, project_id, leader_id, follower_id)
2460            .await?;
2461        room_updated(&room, &session.peer);
2462    }
2463
2464    Ok(())
2465}
2466
2467/// Stop following another user in a call.
2468async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2469    let room_id = RoomId::from_proto(request.room_id);
2470    let project_id = request.project_id.map(ProjectId::from_proto);
2471    let leader_id = request.leader_id.context("invalid leader id")?.into();
2472    let follower_id = session.connection_id;
2473
2474    session
2475        .db()
2476        .await
2477        .check_room_participants(room_id, leader_id, session.connection_id)
2478        .await?;
2479
2480    session
2481        .peer
2482        .forward_send(session.connection_id, leader_id, request)?;
2483
2484    if let Some(project_id) = project_id {
2485        let room = session
2486            .db()
2487            .await
2488            .unfollow(room_id, project_id, leader_id, follower_id)
2489            .await?;
2490        room_updated(&room, &session.peer);
2491    }
2492
2493    Ok(())
2494}
2495
2496/// Notify everyone following you of your current location.
2497async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2498    let room_id = RoomId::from_proto(request.room_id);
2499    let database = session.db.lock().await;
2500
2501    let connection_ids = if let Some(project_id) = request.project_id {
2502        let project_id = ProjectId::from_proto(project_id);
2503        database
2504            .project_connection_ids(project_id, session.connection_id, true)
2505            .await?
2506    } else {
2507        database
2508            .room_connection_ids(room_id, session.connection_id)
2509            .await?
2510    };
2511
2512    // For now, don't send view update messages back to that view's current leader.
2513    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2514        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2515        _ => None,
2516    });
2517
2518    for connection_id in connection_ids.iter().cloned() {
2519        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2520            session
2521                .peer
2522                .forward_send(session.connection_id, connection_id, request.clone())?;
2523        }
2524    }
2525    Ok(())
2526}
2527
2528/// Get public data about users.
2529async fn get_users(
2530    request: proto::GetUsers,
2531    response: Response<proto::GetUsers>,
2532    session: MessageContext,
2533) -> Result<()> {
2534    let user_ids = request
2535        .user_ids
2536        .into_iter()
2537        .map(UserId::from_proto)
2538        .collect();
2539    let users = session
2540        .db()
2541        .await
2542        .get_users_by_ids(user_ids)
2543        .await?
2544        .into_iter()
2545        .map(|user| proto::User {
2546            id: user.id.to_proto(),
2547            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2548            github_login: user.github_login,
2549            name: user.name,
2550        })
2551        .collect();
2552    response.send(proto::UsersResponse { users })?;
2553    Ok(())
2554}
2555
2556/// Search for users (to invite) buy Github login
2557async fn fuzzy_search_users(
2558    request: proto::FuzzySearchUsers,
2559    response: Response<proto::FuzzySearchUsers>,
2560    session: MessageContext,
2561) -> Result<()> {
2562    let query = request.query;
2563    let users = match query.len() {
2564        0 => vec![],
2565        1 | 2 => session
2566            .db()
2567            .await
2568            .get_user_by_github_login(&query)
2569            .await?
2570            .into_iter()
2571            .collect(),
2572        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2573    };
2574    let users = users
2575        .into_iter()
2576        .filter(|user| user.id != session.user_id())
2577        .map(|user| proto::User {
2578            id: user.id.to_proto(),
2579            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2580            github_login: user.github_login,
2581            name: user.name,
2582        })
2583        .collect();
2584    response.send(proto::UsersResponse { users })?;
2585    Ok(())
2586}
2587
2588/// Send a contact request to another user.
2589async fn request_contact(
2590    request: proto::RequestContact,
2591    response: Response<proto::RequestContact>,
2592    session: MessageContext,
2593) -> Result<()> {
2594    let requester_id = session.user_id();
2595    let responder_id = UserId::from_proto(request.responder_id);
2596    if requester_id == responder_id {
2597        return Err(anyhow!("cannot add yourself as a contact"))?;
2598    }
2599
2600    let notifications = session
2601        .db()
2602        .await
2603        .send_contact_request(requester_id, responder_id)
2604        .await?;
2605
2606    // Update outgoing contact requests of requester
2607    let mut update = proto::UpdateContacts::default();
2608    update.outgoing_requests.push(responder_id.to_proto());
2609    for connection_id in session
2610        .connection_pool()
2611        .await
2612        .user_connection_ids(requester_id)
2613    {
2614        session.peer.send(connection_id, update.clone())?;
2615    }
2616
2617    // Update incoming contact requests of responder
2618    let mut update = proto::UpdateContacts::default();
2619    update
2620        .incoming_requests
2621        .push(proto::IncomingContactRequest {
2622            requester_id: requester_id.to_proto(),
2623        });
2624    let connection_pool = session.connection_pool().await;
2625    for connection_id in connection_pool.user_connection_ids(responder_id) {
2626        session.peer.send(connection_id, update.clone())?;
2627    }
2628
2629    send_notifications(&connection_pool, &session.peer, notifications);
2630
2631    response.send(proto::Ack {})?;
2632    Ok(())
2633}
2634
2635/// Accept or decline a contact request
2636async fn respond_to_contact_request(
2637    request: proto::RespondToContactRequest,
2638    response: Response<proto::RespondToContactRequest>,
2639    session: MessageContext,
2640) -> Result<()> {
2641    let responder_id = session.user_id();
2642    let requester_id = UserId::from_proto(request.requester_id);
2643    let db = session.db().await;
2644    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2645        db.dismiss_contact_notification(responder_id, requester_id)
2646            .await?;
2647    } else {
2648        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2649
2650        let notifications = db
2651            .respond_to_contact_request(responder_id, requester_id, accept)
2652            .await?;
2653        let requester_busy = db.is_user_busy(requester_id).await?;
2654        let responder_busy = db.is_user_busy(responder_id).await?;
2655
2656        let pool = session.connection_pool().await;
2657        // Update responder with new contact
2658        let mut update = proto::UpdateContacts::default();
2659        if accept {
2660            update
2661                .contacts
2662                .push(contact_for_user(requester_id, requester_busy, &pool));
2663        }
2664        update
2665            .remove_incoming_requests
2666            .push(requester_id.to_proto());
2667        for connection_id in pool.user_connection_ids(responder_id) {
2668            session.peer.send(connection_id, update.clone())?;
2669        }
2670
2671        // Update requester with new contact
2672        let mut update = proto::UpdateContacts::default();
2673        if accept {
2674            update
2675                .contacts
2676                .push(contact_for_user(responder_id, responder_busy, &pool));
2677        }
2678        update
2679            .remove_outgoing_requests
2680            .push(responder_id.to_proto());
2681
2682        for connection_id in pool.user_connection_ids(requester_id) {
2683            session.peer.send(connection_id, update.clone())?;
2684        }
2685
2686        send_notifications(&pool, &session.peer, notifications);
2687    }
2688
2689    response.send(proto::Ack {})?;
2690    Ok(())
2691}
2692
2693/// Remove a contact.
2694async fn remove_contact(
2695    request: proto::RemoveContact,
2696    response: Response<proto::RemoveContact>,
2697    session: MessageContext,
2698) -> Result<()> {
2699    let requester_id = session.user_id();
2700    let responder_id = UserId::from_proto(request.user_id);
2701    let db = session.db().await;
2702    let (contact_accepted, deleted_notification_id) =
2703        db.remove_contact(requester_id, responder_id).await?;
2704
2705    let pool = session.connection_pool().await;
2706    // Update outgoing contact requests of requester
2707    let mut update = proto::UpdateContacts::default();
2708    if contact_accepted {
2709        update.remove_contacts.push(responder_id.to_proto());
2710    } else {
2711        update
2712            .remove_outgoing_requests
2713            .push(responder_id.to_proto());
2714    }
2715    for connection_id in pool.user_connection_ids(requester_id) {
2716        session.peer.send(connection_id, update.clone())?;
2717    }
2718
2719    // Update incoming contact requests of responder
2720    let mut update = proto::UpdateContacts::default();
2721    if contact_accepted {
2722        update.remove_contacts.push(requester_id.to_proto());
2723    } else {
2724        update
2725            .remove_incoming_requests
2726            .push(requester_id.to_proto());
2727    }
2728    for connection_id in pool.user_connection_ids(responder_id) {
2729        session.peer.send(connection_id, update.clone())?;
2730        if let Some(notification_id) = deleted_notification_id {
2731            session.peer.send(
2732                connection_id,
2733                proto::DeleteNotification {
2734                    notification_id: notification_id.to_proto(),
2735                },
2736            )?;
2737        }
2738    }
2739
2740    response.send(proto::Ack {})?;
2741    Ok(())
2742}
2743
2744fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2745    version.0.minor < 139
2746}
2747
2748async fn subscribe_to_channels(
2749    _: proto::SubscribeToChannels,
2750    session: MessageContext,
2751) -> Result<()> {
2752    subscribe_user_to_channels(session.user_id(), &session).await?;
2753    Ok(())
2754}
2755
2756async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2757    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2758    let mut pool = session.connection_pool().await;
2759    for membership in &channels_for_user.channel_memberships {
2760        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2761    }
2762    session.peer.send(
2763        session.connection_id,
2764        build_update_user_channels(&channels_for_user),
2765    )?;
2766    session.peer.send(
2767        session.connection_id,
2768        build_channels_update(channels_for_user),
2769    )?;
2770    Ok(())
2771}
2772
2773/// Creates a new channel.
2774async fn create_channel(
2775    request: proto::CreateChannel,
2776    response: Response<proto::CreateChannel>,
2777    session: MessageContext,
2778) -> Result<()> {
2779    let db = session.db().await;
2780
2781    let parent_id = request.parent_id.map(ChannelId::from_proto);
2782    let (channel, membership) = db
2783        .create_channel(&request.name, parent_id, session.user_id())
2784        .await?;
2785
2786    let root_id = channel.root_id();
2787    let channel = Channel::from_model(channel);
2788
2789    response.send(proto::CreateChannelResponse {
2790        channel: Some(channel.to_proto()),
2791        parent_id: request.parent_id,
2792    })?;
2793
2794    let mut connection_pool = session.connection_pool().await;
2795    if let Some(membership) = membership {
2796        connection_pool.subscribe_to_channel(
2797            membership.user_id,
2798            membership.channel_id,
2799            membership.role,
2800        );
2801        let update = proto::UpdateUserChannels {
2802            channel_memberships: vec![proto::ChannelMembership {
2803                channel_id: membership.channel_id.to_proto(),
2804                role: membership.role.into(),
2805            }],
2806            ..Default::default()
2807        };
2808        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2809            session.peer.send(connection_id, update.clone())?;
2810        }
2811    }
2812
2813    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2814        if !role.can_see_channel(channel.visibility) {
2815            continue;
2816        }
2817
2818        let update = proto::UpdateChannels {
2819            channels: vec![channel.to_proto()],
2820            ..Default::default()
2821        };
2822        session.peer.send(connection_id, update.clone())?;
2823    }
2824
2825    Ok(())
2826}
2827
2828/// Delete a channel
2829async fn delete_channel(
2830    request: proto::DeleteChannel,
2831    response: Response<proto::DeleteChannel>,
2832    session: MessageContext,
2833) -> Result<()> {
2834    let db = session.db().await;
2835
2836    let channel_id = request.channel_id;
2837    let (root_channel, removed_channels) = db
2838        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2839        .await?;
2840    response.send(proto::Ack {})?;
2841
2842    // Notify members of removed channels
2843    let mut update = proto::UpdateChannels::default();
2844    update
2845        .delete_channels
2846        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2847
2848    let connection_pool = session.connection_pool().await;
2849    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2850        session.peer.send(connection_id, update.clone())?;
2851    }
2852
2853    Ok(())
2854}
2855
2856/// Invite someone to join a channel.
2857async fn invite_channel_member(
2858    request: proto::InviteChannelMember,
2859    response: Response<proto::InviteChannelMember>,
2860    session: MessageContext,
2861) -> Result<()> {
2862    let db = session.db().await;
2863    let channel_id = ChannelId::from_proto(request.channel_id);
2864    let invitee_id = UserId::from_proto(request.user_id);
2865    let InviteMemberResult {
2866        channel,
2867        notifications,
2868    } = db
2869        .invite_channel_member(
2870            channel_id,
2871            invitee_id,
2872            session.user_id(),
2873            request.role().into(),
2874        )
2875        .await?;
2876
2877    let update = proto::UpdateChannels {
2878        channel_invitations: vec![channel.to_proto()],
2879        ..Default::default()
2880    };
2881
2882    let connection_pool = session.connection_pool().await;
2883    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2884        session.peer.send(connection_id, update.clone())?;
2885    }
2886
2887    send_notifications(&connection_pool, &session.peer, notifications);
2888
2889    response.send(proto::Ack {})?;
2890    Ok(())
2891}
2892
2893/// remove someone from a channel
2894async fn remove_channel_member(
2895    request: proto::RemoveChannelMember,
2896    response: Response<proto::RemoveChannelMember>,
2897    session: MessageContext,
2898) -> Result<()> {
2899    let db = session.db().await;
2900    let channel_id = ChannelId::from_proto(request.channel_id);
2901    let member_id = UserId::from_proto(request.user_id);
2902
2903    let RemoveChannelMemberResult {
2904        membership_update,
2905        notification_id,
2906    } = db
2907        .remove_channel_member(channel_id, member_id, session.user_id())
2908        .await?;
2909
2910    let mut connection_pool = session.connection_pool().await;
2911    notify_membership_updated(
2912        &mut connection_pool,
2913        membership_update,
2914        member_id,
2915        &session.peer,
2916    );
2917    for connection_id in connection_pool.user_connection_ids(member_id) {
2918        if let Some(notification_id) = notification_id {
2919            session
2920                .peer
2921                .send(
2922                    connection_id,
2923                    proto::DeleteNotification {
2924                        notification_id: notification_id.to_proto(),
2925                    },
2926                )
2927                .trace_err();
2928        }
2929    }
2930
2931    response.send(proto::Ack {})?;
2932    Ok(())
2933}
2934
2935/// Toggle the channel between public and private.
2936/// Care is taken to maintain the invariant that public channels only descend from public channels,
2937/// (though members-only channels can appear at any point in the hierarchy).
2938async fn set_channel_visibility(
2939    request: proto::SetChannelVisibility,
2940    response: Response<proto::SetChannelVisibility>,
2941    session: MessageContext,
2942) -> Result<()> {
2943    let db = session.db().await;
2944    let channel_id = ChannelId::from_proto(request.channel_id);
2945    let visibility = request.visibility().into();
2946
2947    let channel_model = db
2948        .set_channel_visibility(channel_id, visibility, session.user_id())
2949        .await?;
2950    let root_id = channel_model.root_id();
2951    let channel = Channel::from_model(channel_model);
2952
2953    let mut connection_pool = session.connection_pool().await;
2954    for (user_id, role) in connection_pool
2955        .channel_user_ids(root_id)
2956        .collect::<Vec<_>>()
2957        .into_iter()
2958    {
2959        let update = if role.can_see_channel(channel.visibility) {
2960            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2961            proto::UpdateChannels {
2962                channels: vec![channel.to_proto()],
2963                ..Default::default()
2964            }
2965        } else {
2966            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2967            proto::UpdateChannels {
2968                delete_channels: vec![channel.id.to_proto()],
2969                ..Default::default()
2970            }
2971        };
2972
2973        for connection_id in connection_pool.user_connection_ids(user_id) {
2974            session.peer.send(connection_id, update.clone())?;
2975        }
2976    }
2977
2978    response.send(proto::Ack {})?;
2979    Ok(())
2980}
2981
2982/// Alter the role for a user in the channel.
2983async fn set_channel_member_role(
2984    request: proto::SetChannelMemberRole,
2985    response: Response<proto::SetChannelMemberRole>,
2986    session: MessageContext,
2987) -> Result<()> {
2988    let db = session.db().await;
2989    let channel_id = ChannelId::from_proto(request.channel_id);
2990    let member_id = UserId::from_proto(request.user_id);
2991    let result = db
2992        .set_channel_member_role(
2993            channel_id,
2994            session.user_id(),
2995            member_id,
2996            request.role().into(),
2997        )
2998        .await?;
2999
3000    match result {
3001        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3002            let mut connection_pool = session.connection_pool().await;
3003            notify_membership_updated(
3004                &mut connection_pool,
3005                membership_update,
3006                member_id,
3007                &session.peer,
3008            )
3009        }
3010        db::SetMemberRoleResult::InviteUpdated(channel) => {
3011            let update = proto::UpdateChannels {
3012                channel_invitations: vec![channel.to_proto()],
3013                ..Default::default()
3014            };
3015
3016            for connection_id in session
3017                .connection_pool()
3018                .await
3019                .user_connection_ids(member_id)
3020            {
3021                session.peer.send(connection_id, update.clone())?;
3022            }
3023        }
3024    }
3025
3026    response.send(proto::Ack {})?;
3027    Ok(())
3028}
3029
3030/// Change the name of a channel
3031async fn rename_channel(
3032    request: proto::RenameChannel,
3033    response: Response<proto::RenameChannel>,
3034    session: MessageContext,
3035) -> Result<()> {
3036    let db = session.db().await;
3037    let channel_id = ChannelId::from_proto(request.channel_id);
3038    let channel_model = db
3039        .rename_channel(channel_id, session.user_id(), &request.name)
3040        .await?;
3041    let root_id = channel_model.root_id();
3042    let channel = Channel::from_model(channel_model);
3043
3044    response.send(proto::RenameChannelResponse {
3045        channel: Some(channel.to_proto()),
3046    })?;
3047
3048    let connection_pool = session.connection_pool().await;
3049    let update = proto::UpdateChannels {
3050        channels: vec![channel.to_proto()],
3051        ..Default::default()
3052    };
3053    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3054        if role.can_see_channel(channel.visibility) {
3055            session.peer.send(connection_id, update.clone())?;
3056        }
3057    }
3058
3059    Ok(())
3060}
3061
3062/// Move a channel to a new parent.
3063async fn move_channel(
3064    request: proto::MoveChannel,
3065    response: Response<proto::MoveChannel>,
3066    session: MessageContext,
3067) -> Result<()> {
3068    let channel_id = ChannelId::from_proto(request.channel_id);
3069    let to = ChannelId::from_proto(request.to);
3070
3071    let (root_id, channels) = session
3072        .db()
3073        .await
3074        .move_channel(channel_id, to, session.user_id())
3075        .await?;
3076
3077    let connection_pool = session.connection_pool().await;
3078    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3079        let channels = channels
3080            .iter()
3081            .filter_map(|channel| {
3082                if role.can_see_channel(channel.visibility) {
3083                    Some(channel.to_proto())
3084                } else {
3085                    None
3086                }
3087            })
3088            .collect::<Vec<_>>();
3089        if channels.is_empty() {
3090            continue;
3091        }
3092
3093        let update = proto::UpdateChannels {
3094            channels,
3095            ..Default::default()
3096        };
3097
3098        session.peer.send(connection_id, update.clone())?;
3099    }
3100
3101    response.send(Ack {})?;
3102    Ok(())
3103}
3104
3105async fn reorder_channel(
3106    request: proto::ReorderChannel,
3107    response: Response<proto::ReorderChannel>,
3108    session: MessageContext,
3109) -> Result<()> {
3110    let channel_id = ChannelId::from_proto(request.channel_id);
3111    let direction = request.direction();
3112
3113    let updated_channels = session
3114        .db()
3115        .await
3116        .reorder_channel(channel_id, direction, session.user_id())
3117        .await?;
3118
3119    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3120        let connection_pool = session.connection_pool().await;
3121        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3122            let channels = updated_channels
3123                .iter()
3124                .filter_map(|channel| {
3125                    if role.can_see_channel(channel.visibility) {
3126                        Some(channel.to_proto())
3127                    } else {
3128                        None
3129                    }
3130                })
3131                .collect::<Vec<_>>();
3132
3133            if channels.is_empty() {
3134                continue;
3135            }
3136
3137            let update = proto::UpdateChannels {
3138                channels,
3139                ..Default::default()
3140            };
3141
3142            session.peer.send(connection_id, update.clone())?;
3143        }
3144    }
3145
3146    response.send(Ack {})?;
3147    Ok(())
3148}
3149
3150/// Get the list of channel members
3151async fn get_channel_members(
3152    request: proto::GetChannelMembers,
3153    response: Response<proto::GetChannelMembers>,
3154    session: MessageContext,
3155) -> Result<()> {
3156    let db = session.db().await;
3157    let channel_id = ChannelId::from_proto(request.channel_id);
3158    let limit = if request.limit == 0 {
3159        u16::MAX as u64
3160    } else {
3161        request.limit
3162    };
3163    let (members, users) = db
3164        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3165        .await?;
3166    response.send(proto::GetChannelMembersResponse { members, users })?;
3167    Ok(())
3168}
3169
3170/// Accept or decline a channel invitation.
3171async fn respond_to_channel_invite(
3172    request: proto::RespondToChannelInvite,
3173    response: Response<proto::RespondToChannelInvite>,
3174    session: MessageContext,
3175) -> Result<()> {
3176    let db = session.db().await;
3177    let channel_id = ChannelId::from_proto(request.channel_id);
3178    let RespondToChannelInvite {
3179        membership_update,
3180        notifications,
3181    } = db
3182        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3183        .await?;
3184
3185    let mut connection_pool = session.connection_pool().await;
3186    if let Some(membership_update) = membership_update {
3187        notify_membership_updated(
3188            &mut connection_pool,
3189            membership_update,
3190            session.user_id(),
3191            &session.peer,
3192        );
3193    } else {
3194        let update = proto::UpdateChannels {
3195            remove_channel_invitations: vec![channel_id.to_proto()],
3196            ..Default::default()
3197        };
3198
3199        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3200            session.peer.send(connection_id, update.clone())?;
3201        }
3202    };
3203
3204    send_notifications(&connection_pool, &session.peer, notifications);
3205
3206    response.send(proto::Ack {})?;
3207
3208    Ok(())
3209}
3210
3211/// Join the channels' room
3212async fn join_channel(
3213    request: proto::JoinChannel,
3214    response: Response<proto::JoinChannel>,
3215    session: MessageContext,
3216) -> Result<()> {
3217    let channel_id = ChannelId::from_proto(request.channel_id);
3218    join_channel_internal(channel_id, Box::new(response), session).await
3219}
3220
/// A responder that can answer with a `proto::JoinRoomResponse`.
///
/// Implemented for the `JoinChannel` and `JoinRoom` responders, so that
/// `join_channel_internal` can reply to either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl calls `Response::<...>::send` by its fully qualified path.
// NOTE(review): presumably this resolves to `Response`'s own inherent `send` rather
// than this trait method (which would recurse) — confirm before changing the call form.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3234
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. Clean up any stale room connection for the user.
/// 2. Join the channel through the database.
/// 3. Reply with the room, channel id, and (if a LiveKit client is configured)
///    LiveKit connection info.
/// 4. Apply any membership change and broadcast the updated room.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// Returns an error if the database did not return a channel. That check happens
/// after the join response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Guards taken inside this block (`db`, `connection_pool`) are dropped at its end,
    // before the connection pool is locked again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before leaving the stale room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` takes the db lock itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a `guest_token` and cannot publish; every other role gets a
        // `room_token` and can publish. If token creation fails, the error is passed to
        // `trace_err` and no connection info is sent.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3328
3329/// Start editing the channel notes
3330async fn join_channel_buffer(
3331    request: proto::JoinChannelBuffer,
3332    response: Response<proto::JoinChannelBuffer>,
3333    session: MessageContext,
3334) -> Result<()> {
3335    let db = session.db().await;
3336    let channel_id = ChannelId::from_proto(request.channel_id);
3337
3338    let open_response = db
3339        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3340        .await?;
3341
3342    let collaborators = open_response.collaborators.clone();
3343    response.send(open_response)?;
3344
3345    let update = UpdateChannelBufferCollaborators {
3346        channel_id: channel_id.to_proto(),
3347        collaborators: collaborators.clone(),
3348    };
3349    channel_buffer_updated(
3350        session.connection_id,
3351        collaborators
3352            .iter()
3353            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3354        &update,
3355        &session.peer,
3356    );
3357
3358    Ok(())
3359}
3360
3361/// Edit the channel notes
3362async fn update_channel_buffer(
3363    request: proto::UpdateChannelBuffer,
3364    session: MessageContext,
3365) -> Result<()> {
3366    let db = session.db().await;
3367    let channel_id = ChannelId::from_proto(request.channel_id);
3368
3369    let (collaborators, epoch, version) = db
3370        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3371        .await?;
3372
3373    channel_buffer_updated(
3374        session.connection_id,
3375        collaborators.clone(),
3376        &proto::UpdateChannelBuffer {
3377            channel_id: channel_id.to_proto(),
3378            operations: request.operations,
3379        },
3380        &session.peer,
3381    );
3382
3383    let pool = &*session.connection_pool().await;
3384
3385    let non_collaborators =
3386        pool.channel_connection_ids(channel_id)
3387            .filter_map(|(connection_id, _)| {
3388                if collaborators.contains(&connection_id) {
3389                    None
3390                } else {
3391                    Some(connection_id)
3392                }
3393            });
3394
3395    broadcast(None, non_collaborators, |peer_id| {
3396        session.peer.send(
3397            peer_id,
3398            proto::UpdateChannels {
3399                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3400                    channel_id: channel_id.to_proto(),
3401                    epoch: epoch as u64,
3402                    version: version.clone(),
3403                }],
3404                ..Default::default()
3405            },
3406        )
3407    });
3408
3409    Ok(())
3410}
3411
3412/// Rejoin the channel notes after a connection blip
3413async fn rejoin_channel_buffers(
3414    request: proto::RejoinChannelBuffers,
3415    response: Response<proto::RejoinChannelBuffers>,
3416    session: MessageContext,
3417) -> Result<()> {
3418    let db = session.db().await;
3419    let buffers = db
3420        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3421        .await?;
3422
3423    for rejoined_buffer in &buffers {
3424        let collaborators_to_notify = rejoined_buffer
3425            .buffer
3426            .collaborators
3427            .iter()
3428            .filter_map(|c| Some(c.peer_id?.into()));
3429        channel_buffer_updated(
3430            session.connection_id,
3431            collaborators_to_notify,
3432            &proto::UpdateChannelBufferCollaborators {
3433                channel_id: rejoined_buffer.buffer.channel_id,
3434                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3435            },
3436            &session.peer,
3437        );
3438    }
3439
3440    response.send(proto::RejoinChannelBuffersResponse {
3441        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3442    })?;
3443
3444    Ok(())
3445}
3446
3447/// Stop editing the channel notes
3448async fn leave_channel_buffer(
3449    request: proto::LeaveChannelBuffer,
3450    response: Response<proto::LeaveChannelBuffer>,
3451    session: MessageContext,
3452) -> Result<()> {
3453    let db = session.db().await;
3454    let channel_id = ChannelId::from_proto(request.channel_id);
3455
3456    let left_buffer = db
3457        .leave_channel_buffer(channel_id, session.connection_id)
3458        .await?;
3459
3460    response.send(Ack {})?;
3461
3462    channel_buffer_updated(
3463        session.connection_id,
3464        left_buffer.connections,
3465        &proto::UpdateChannelBufferCollaborators {
3466            channel_id: channel_id.to_proto(),
3467            collaborators: left_buffer.collaborators,
3468        },
3469        &session.peer,
3470    );
3471
3472    Ok(())
3473}
3474
3475fn channel_buffer_updated<T: EnvelopedMessage>(
3476    sender_id: ConnectionId,
3477    collaborators: impl IntoIterator<Item = ConnectionId>,
3478    message: &T,
3479    peer: &Peer,
3480) {
3481    broadcast(Some(sender_id), collaborators, |peer_id| {
3482        peer.send(peer_id, message.clone())
3483    });
3484}
3485
3486fn send_notifications(
3487    connection_pool: &ConnectionPool,
3488    peer: &Peer,
3489    notifications: db::NotificationBatch,
3490) {
3491    for (user_id, notification) in notifications {
3492        for connection_id in connection_pool.user_connection_ids(user_id) {
3493            if let Err(error) = peer.send(
3494                connection_id,
3495                proto::AddNotification {
3496                    notification: Some(notification.clone()),
3497                },
3498            ) {
3499                tracing::error!(
3500                    "failed to send notification to {:?} {}",
3501                    connection_id,
3502                    error
3503                );
3504            }
3505        }
3506    }
3507}
3508
3509/// Send a message to the channel
3510async fn send_channel_message(
3511    _request: proto::SendChannelMessage,
3512    _response: Response<proto::SendChannelMessage>,
3513    _session: MessageContext,
3514) -> Result<()> {
3515    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3516}
3517
3518/// Delete a channel message
3519async fn remove_channel_message(
3520    _request: proto::RemoveChannelMessage,
3521    _response: Response<proto::RemoveChannelMessage>,
3522    _session: MessageContext,
3523) -> Result<()> {
3524    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3525}
3526
3527async fn update_channel_message(
3528    _request: proto::UpdateChannelMessage,
3529    _response: Response<proto::UpdateChannelMessage>,
3530    _session: MessageContext,
3531) -> Result<()> {
3532    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3533}
3534
3535/// Mark a channel message as read
3536async fn acknowledge_channel_message(
3537    _request: proto::AckChannelMessage,
3538    _session: MessageContext,
3539) -> Result<()> {
3540    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3541}
3542
3543/// Mark a buffer version as synced
3544async fn acknowledge_buffer_version(
3545    request: proto::AckBufferOperation,
3546    session: MessageContext,
3547) -> Result<()> {
3548    let buffer_id = BufferId::from_proto(request.buffer_id);
3549    session
3550        .db()
3551        .await
3552        .observe_buffer_version(
3553            buffer_id,
3554            session.user_id(),
3555            request.epoch as i32,
3556            &request.version,
3557        )
3558        .await?;
3559    Ok(())
3560}
3561
3562/// Start receiving chat updates for a channel
3563async fn join_channel_chat(
3564    _request: proto::JoinChannelChat,
3565    _response: Response<proto::JoinChannelChat>,
3566    _session: MessageContext,
3567) -> Result<()> {
3568    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3569}
3570
3571/// Stop receiving chat updates for a channel
3572async fn leave_channel_chat(
3573    _request: proto::LeaveChannelChat,
3574    _session: MessageContext,
3575) -> Result<()> {
3576    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3577}
3578
3579/// Retrieve the chat history for a channel
3580async fn get_channel_messages(
3581    _request: proto::GetChannelMessages,
3582    _response: Response<proto::GetChannelMessages>,
3583    _session: MessageContext,
3584) -> Result<()> {
3585    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3586}
3587
3588/// Retrieve specific chat messages
3589async fn get_channel_messages_by_id(
3590    _request: proto::GetChannelMessagesById,
3591    _response: Response<proto::GetChannelMessagesById>,
3592    _session: MessageContext,
3593) -> Result<()> {
3594    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3595}
3596
3597/// Retrieve the current users notifications
3598async fn get_notifications(
3599    request: proto::GetNotifications,
3600    response: Response<proto::GetNotifications>,
3601    session: MessageContext,
3602) -> Result<()> {
3603    let notifications = session
3604        .db()
3605        .await
3606        .get_notifications(
3607            session.user_id(),
3608            NOTIFICATION_COUNT_PER_PAGE,
3609            request.before_id.map(db::NotificationId::from_proto),
3610        )
3611        .await?;
3612    response.send(proto::GetNotificationsResponse {
3613        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3614        notifications,
3615    })?;
3616    Ok(())
3617}
3618
3619/// Mark notifications as read
3620async fn mark_notification_as_read(
3621    request: proto::MarkNotificationRead,
3622    response: Response<proto::MarkNotificationRead>,
3623    session: MessageContext,
3624) -> Result<()> {
3625    let database = &session.db().await;
3626    let notifications = database
3627        .mark_notification_as_read_by_id(
3628            session.user_id(),
3629            NotificationId::from_proto(request.notification_id),
3630        )
3631        .await?;
3632    send_notifications(
3633        &*session.connection_pool().await,
3634        &session.peer,
3635        notifications,
3636    );
3637    response.send(proto::Ack {})?;
3638    Ok(())
3639}
3640
3641fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3642    let message = match message {
3643        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3644        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3645        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3646        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3647        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3648            code: frame.code.into(),
3649            reason: frame.reason.as_str().to_owned().into(),
3650        })),
3651        // We should never receive a frame while reading the message, according
3652        // to the `tungstenite` maintainers:
3653        //
3654        // > It cannot occur when you read messages from the WebSocket, but it
3655        // > can be used when you want to send the raw frames (e.g. you want to
3656        // > send the frames to the WebSocket without composing the full message first).
3657        // >
3658        // > — https://github.com/snapview/tungstenite-rs/issues/268
3659        TungsteniteMessage::Frame(_) => {
3660            bail!("received an unexpected frame while reading the message")
3661        }
3662    };
3663
3664    Ok(message)
3665}
3666
3667fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3668    match message {
3669        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3670        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3671        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3672        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3673        AxumMessage::Close(frame) => {
3674            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3675                code: frame.code.into(),
3676                reason: frame.reason.as_ref().into(),
3677            }))
3678        }
3679    }
3680}
3681
3682fn notify_membership_updated(
3683    connection_pool: &mut ConnectionPool,
3684    result: MembershipUpdated,
3685    user_id: UserId,
3686    peer: &Peer,
3687) {
3688    for membership in &result.new_channels.channel_memberships {
3689        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3690    }
3691    for channel_id in &result.removed_channels {
3692        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3693    }
3694
3695    let user_channels_update = proto::UpdateUserChannels {
3696        channel_memberships: result
3697            .new_channels
3698            .channel_memberships
3699            .iter()
3700            .map(|cm| proto::ChannelMembership {
3701                channel_id: cm.channel_id.to_proto(),
3702                role: cm.role.into(),
3703            })
3704            .collect(),
3705        ..Default::default()
3706    };
3707
3708    let mut update = build_channels_update(result.new_channels);
3709    update.delete_channels = result
3710        .removed_channels
3711        .into_iter()
3712        .map(|id| id.to_proto())
3713        .collect();
3714    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3715
3716    for connection_id in connection_pool.user_connection_ids(user_id) {
3717        peer.send(connection_id, user_channels_update.clone())
3718            .trace_err();
3719        peer.send(connection_id, update.clone()).trace_err();
3720    }
3721}
3722
3723fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3724    proto::UpdateUserChannels {
3725        channel_memberships: channels
3726            .channel_memberships
3727            .iter()
3728            .map(|m| proto::ChannelMembership {
3729                channel_id: m.channel_id.to_proto(),
3730                role: m.role.into(),
3731            })
3732            .collect(),
3733        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3734    }
3735}
3736
3737fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3738    let mut update = proto::UpdateChannels::default();
3739
3740    for channel in channels.channels {
3741        update.channels.push(channel.to_proto());
3742    }
3743
3744    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3745
3746    for (channel_id, participants) in channels.channel_participants {
3747        update
3748            .channel_participants
3749            .push(proto::ChannelParticipants {
3750                channel_id: channel_id.to_proto(),
3751                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3752            });
3753    }
3754
3755    for channel in channels.invited_channels {
3756        update.channel_invitations.push(channel.to_proto());
3757    }
3758
3759    update
3760}
3761
3762fn build_initial_contacts_update(
3763    contacts: Vec<db::Contact>,
3764    pool: &ConnectionPool,
3765) -> proto::UpdateContacts {
3766    let mut update = proto::UpdateContacts::default();
3767
3768    for contact in contacts {
3769        match contact {
3770            db::Contact::Accepted { user_id, busy } => {
3771                update.contacts.push(contact_for_user(user_id, busy, pool));
3772            }
3773            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3774            db::Contact::Incoming { user_id } => {
3775                update
3776                    .incoming_requests
3777                    .push(proto::IncomingContactRequest {
3778                        requester_id: user_id.to_proto(),
3779                    })
3780            }
3781        }
3782    }
3783
3784    update
3785}
3786
3787fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3788    proto::Contact {
3789        user_id: user_id.to_proto(),
3790        online: pool.is_user_online(user_id),
3791        busy,
3792    }
3793}
3794
3795fn room_updated(room: &proto::Room, peer: &Peer) {
3796    broadcast(
3797        None,
3798        room.participants
3799            .iter()
3800            .filter_map(|participant| Some(participant.peer_id?.into())),
3801        |peer_id| {
3802            peer.send(
3803                peer_id,
3804                proto::RoomUpdated {
3805                    room: Some(room.clone()),
3806                },
3807            )
3808        },
3809    );
3810}
3811
3812fn channel_updated(
3813    channel: &db::channel::Model,
3814    room: &proto::Room,
3815    peer: &Peer,
3816    pool: &ConnectionPool,
3817) {
3818    let participants = room
3819        .participants
3820        .iter()
3821        .map(|p| p.user_id)
3822        .collect::<Vec<_>>();
3823
3824    broadcast(
3825        None,
3826        pool.channel_connection_ids(channel.root_id())
3827            .filter_map(|(channel_id, role)| {
3828                role.can_see_channel(channel.visibility)
3829                    .then_some(channel_id)
3830            }),
3831        |peer_id| {
3832            peer.send(
3833                peer_id,
3834                proto::UpdateChannels {
3835                    channel_participants: vec![proto::ChannelParticipants {
3836                        channel_id: channel.id.to_proto(),
3837                        participant_user_ids: participants.clone(),
3838                    }],
3839                    ..Default::default()
3840                },
3841            )
3842        },
3843    );
3844}
3845
3846async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3847    let db = session.db().await;
3848
3849    let contacts = db.get_contacts(user_id).await?;
3850    let busy = db.is_user_busy(user_id).await?;
3851
3852    let pool = session.connection_pool().await;
3853    let updated_contact = contact_for_user(user_id, busy, &pool);
3854    for contact in contacts {
3855        if let db::Contact::Accepted {
3856            user_id: contact_user_id,
3857            ..
3858        } = contact
3859        {
3860            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3861                session
3862                    .peer
3863                    .send(
3864                        contact_conn_id,
3865                        proto::UpdateContacts {
3866                            contacts: vec![updated_contact.clone()],
3867                            remove_contacts: Default::default(),
3868                            incoming_requests: Default::default(),
3869                            remove_incoming_requests: Default::default(),
3870                            outgoing_requests: Default::default(),
3871                            remove_outgoing_requests: Default::default(),
3872                        },
3873                    )
3874                    .trace_err();
3875            }
3876        }
3877    }
3878    Ok(())
3879}
3880
/// Removes `connection_id` from the room it is in, if any, and propagates the
/// consequences.
///
/// When the database reports that a room was left, this:
/// - notifies the connections of every project left along with the room
///   (via `project_left`);
/// - broadcasts the updated room to its participants and, when the room
///   belongs to a channel, the updated participant list to the channel;
/// - sends `CallCanceled` to each user in `canceled_calls_to_user_ids`;
/// - refreshes the contact entries of this session's user and of those users;
/// - removes this session's user from the LiveKit room, and deletes that room
///   when `left_room.deleted` is set.
///
/// Returns `Ok(())` without doing anything if `leave_room` returns `None`.
/// Send and LiveKit failures are logged via `trace_err` rather than returned;
/// database errors are propagated.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-broadcast once the room is left.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: all of these are assigned in the `if let`
    // branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` before the room itself
        // is taken, so the `room` broadcast below carries a defaulted
        // `livekit_room`. NOTE(review): presumably intentional — confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before the awaits below;
    // `update_user_contacts` acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3954
3955async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3956    let left_channel_buffers = session
3957        .db()
3958        .await
3959        .leave_channel_buffers(session.connection_id)
3960        .await?;
3961
3962    for left_buffer in left_channel_buffers {
3963        channel_buffer_updated(
3964            session.connection_id,
3965            left_buffer.connections,
3966            &proto::UpdateChannelBufferCollaborators {
3967                channel_id: left_buffer.channel_id.to_proto(),
3968                collaborators: left_buffer.collaborators,
3969            },
3970            &session.peer,
3971        );
3972    }
3973
3974    Ok(())
3975}
3976
3977fn project_left(project: &db::LeftProject, session: &Session) {
3978    for connection_id in &project.connection_ids {
3979        if project.should_unshare {
3980            session
3981                .peer
3982                .send(
3983                    *connection_id,
3984                    proto::UnshareProject {
3985                        project_id: project.id.to_proto(),
3986                    },
3987                )
3988                .trace_err();
3989        } else {
3990            session
3991                .peer
3992                .send(
3993                    *connection_id,
3994                    proto::RemoveProjectCollaborator {
3995                        project_id: project.id.to_proto(),
3996                        peer_id: Some(session.connection_id.into()),
3997                    },
3998                )
3999                .trace_err();
4000        }
4001    }
4002}
4003
4004async fn share_agent_thread(
4005    request: proto::ShareAgentThread,
4006    response: Response<proto::ShareAgentThread>,
4007    session: MessageContext,
4008) -> Result<()> {
4009    let user_id = session.user_id();
4010
4011    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4012        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4013
4014    session
4015        .db()
4016        .await
4017        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4018        .await?;
4019
4020    response.send(proto::Ack {})?;
4021
4022    Ok(())
4023}
4024
4025async fn get_shared_agent_thread(
4026    request: proto::GetSharedAgentThread,
4027    response: Response<proto::GetSharedAgentThread>,
4028    session: MessageContext,
4029) -> Result<()> {
4030    let share_id = SharedThreadId::from_proto(request.session_id)
4031        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4032
4033    let result = session.db().await.get_shared_thread(share_id).await?;
4034
4035    match result {
4036        Some((thread, username)) => {
4037            response.send(proto::GetSharedAgentThreadResponse {
4038                title: thread.title,
4039                thread_data: thread.data,
4040                sharer_username: username,
4041                created_at: thread.created_at.and_utc().to_rfc3339(),
4042            })?;
4043        }
4044        None => {
4045            return Err(anyhow!("Shared thread not found").into());
4046        }
4047    }
4048
4049    Ok(())
4050}
4051
4052pub trait ResultExt {
4053    type Ok;
4054
4055    fn trace_err(self) -> Option<Self::Ok>;
4056}
4057
4058impl<T, E> ResultExt for Result<T, E>
4059where
4060    E: std::fmt::Debug,
4061{
4062    type Ok = T;
4063
4064    #[track_caller]
4065    fn trace_err(self) -> Option<T> {
4066        match self {
4067            Ok(value) => Some(value),
4068            Err(error) => {
4069                tracing::error!("{:?}", error);
4070                None
4071            }
4072        }
4073    }
4074}