rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
  75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  76
  77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  79
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  82
  83static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  84
  85const TOTAL_DURATION_MS: &str = "total_duration_ms";
  86const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  87const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  88const HOST_WAITING_MS: &str = "host_waiting_ms";
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
 182#[derive(Clone)]
 183struct Session {
 184    principal: Principal,
 185    connection_id: ConnectionId,
 186    db: Arc<tokio::sync::Mutex<DbHandle>>,
 187    peer: Arc<Peer>,
 188    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 189    app_state: Arc<AppState>,
 190    /// The GeoIP country code for the user.
 191    #[allow(unused)]
 192    geoip_country_code: Option<String>,
 193    #[allow(unused)]
 194    system_id: Option<String>,
 195    _executor: Executor,
 196}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
 254pub struct Server {
 255    id: parking_lot::Mutex<ServerId>,
 256    peer: Arc<Peer>,
 257    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 258    app_state: Arc<AppState>,
 259    handlers: HashMap<TypeId, MessageHandler>,
 260    teardown: watch::Sender<bool>,
 261}
 262
 263struct ConnectionPoolGuard<'a> {
 264    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 265    _not_send: PhantomData<Rc<()>>,
 266}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 414            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 415            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 416            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 417            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 419            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 420            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 421            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 422            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 423            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 426            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 427            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 428            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 429            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 430            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 431            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 432            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 437            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 438            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 439            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 440            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 441            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 442            .add_message_handler(update_context)
 443            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 444            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 445            .add_request_handler(share_agent_thread)
 446            .add_request_handler(get_shared_agent_thread)
 447            .add_request_handler(forward_project_search_chunk);
 448
 449        Arc::new(server)
 450    }
 451
 452    pub async fn start(&self) -> Result<()> {
 453        let server_id = *self.id.lock();
 454        let app_state = self.app_state.clone();
 455        let peer = self.peer.clone();
 456        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 457        let pool = self.connection_pool.clone();
 458        let livekit_client = self.app_state.livekit_client.clone();
 459
 460        let span = info_span!("start server");
 461        self.app_state.executor.spawn_detached(
 462            async move {
 463                tracing::info!("waiting for cleanup timeout");
 464                timeout.await;
 465                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 466
 467                app_state
 468                    .db
 469                    .delete_stale_channel_chat_participants(
 470                        &app_state.config.zed_environment,
 471                        server_id,
 472                    )
 473                    .await
 474                    .trace_err();
 475
 476                if let Some((room_ids, channel_ids)) = app_state
 477                    .db
 478                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 479                    .await
 480                    .trace_err()
 481                {
 482                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 483                    tracing::info!(
 484                        stale_channel_buffer_count = channel_ids.len(),
 485                        "retrieved stale channel buffers"
 486                    );
 487
 488                    for channel_id in channel_ids {
 489                        if let Some(refreshed_channel_buffer) = app_state
 490                            .db
 491                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 492                            .await
 493                            .trace_err()
 494                        {
 495                            for connection_id in refreshed_channel_buffer.connection_ids {
 496                                peer.send(
 497                                    connection_id,
 498                                    proto::UpdateChannelBufferCollaborators {
 499                                        channel_id: channel_id.to_proto(),
 500                                        collaborators: refreshed_channel_buffer
 501                                            .collaborators
 502                                            .clone(),
 503                                    },
 504                                )
 505                                .trace_err();
 506                            }
 507                        }
 508                    }
 509
 510                    for room_id in room_ids {
 511                        let mut contacts_to_update = HashSet::default();
 512                        let mut canceled_calls_to_user_ids = Vec::new();
 513                        let mut livekit_room = String::new();
 514                        let mut delete_livekit_room = false;
 515
 516                        if let Some(mut refreshed_room) = app_state
 517                            .db
 518                            .clear_stale_room_participants(room_id, server_id)
 519                            .await
 520                            .trace_err()
 521                        {
 522                            tracing::info!(
 523                                room_id = room_id.0,
 524                                new_participant_count = refreshed_room.room.participants.len(),
 525                                "refreshed room"
 526                            );
 527                            room_updated(&refreshed_room.room, &peer);
 528                            if let Some(channel) = refreshed_room.channel.as_ref() {
 529                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 530                            }
 531                            contacts_to_update
 532                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 533                            contacts_to_update
 534                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 535                            canceled_calls_to_user_ids =
 536                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 537                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 538                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 539                        }
 540
 541                        {
 542                            let pool = pool.lock();
 543                            for canceled_user_id in canceled_calls_to_user_ids {
 544                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 545                                    peer.send(
 546                                        connection_id,
 547                                        proto::CallCanceled {
 548                                            room_id: room_id.to_proto(),
 549                                        },
 550                                    )
 551                                    .trace_err();
 552                                }
 553                            }
 554                        }
 555
 556                        for user_id in contacts_to_update {
 557                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 558                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 559                            if let Some((busy, contacts)) = busy.zip(contacts) {
 560                                let pool = pool.lock();
 561                                let updated_contact = contact_for_user(user_id, busy, &pool);
 562                                for contact in contacts {
 563                                    if let db::Contact::Accepted {
 564                                        user_id: contact_user_id,
 565                                        ..
 566                                    } = contact
 567                                    {
 568                                        for contact_conn_id in
 569                                            pool.user_connection_ids(contact_user_id)
 570                                        {
 571                                            peer.send(
 572                                                contact_conn_id,
 573                                                proto::UpdateContacts {
 574                                                    contacts: vec![updated_contact.clone()],
 575                                                    remove_contacts: Default::default(),
 576                                                    incoming_requests: Default::default(),
 577                                                    remove_incoming_requests: Default::default(),
 578                                                    outgoing_requests: Default::default(),
 579                                                    remove_outgoing_requests: Default::default(),
 580                                                },
 581                                            )
 582                                            .trace_err();
 583                                        }
 584                                    }
 585                                }
 586                            }
 587                        }
 588
 589                        if let Some(live_kit) = livekit_client.as_ref()
 590                            && delete_livekit_room
 591                        {
 592                            live_kit.delete_room(livekit_room).await.trace_err();
 593                        }
 594                    }
 595                }
 596
 597                app_state
 598                    .db
 599                    .delete_stale_channel_chat_participants(
 600                        &app_state.config.zed_environment,
 601                        server_id,
 602                    )
 603                    .await
 604                    .trace_err();
 605
 606                app_state
 607                    .db
 608                    .clear_old_worktree_entries(server_id)
 609                    .await
 610                    .trace_err();
 611
 612                app_state
 613                    .db
 614                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 615                    .await
 616                    .trace_err();
 617            }
 618            .instrument(span),
 619        );
 620        Ok(())
 621    }
 622
 623    pub fn teardown(&self) {
 624        self.peer.teardown();
 625        self.connection_pool.lock().reset();
 626        let _ = self.teardown.send(true);
 627    }
 628
 629    #[cfg(feature = "test-support")]
 630    pub fn reset(&self, id: ServerId) {
 631        self.teardown();
 632        *self.id.lock() = id;
 633        self.peer.reset(id.0 as u32);
 634        let _ = self.teardown.send(false);
 635    }
 636
 637    #[cfg(feature = "test-support")]
 638    pub fn id(&self) -> ServerId {
 639        *self.id.lock()
 640    }
 641
 642    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 643    where
 644        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 645        Fut: 'static + Send + Future<Output = Result<()>>,
 646        M: EnvelopedMessage,
 647    {
 648        let prev_handler = self.handlers.insert(
 649            TypeId::of::<M>(),
 650            Box::new(move |envelope, session, span| {
 651                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 652                let received_at = envelope.received_at;
 653                tracing::info!("message received");
 654                let start_time = Instant::now();
 655                let future = (handler)(
 656                    *envelope,
 657                    MessageContext {
 658                        session,
 659                        span: span.clone(),
 660                    },
 661                );
 662                async move {
 663                    let result = future.await;
 664                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 665                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 666                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 667                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 668                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 669                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 670                    match result {
 671                        Err(error) => {
 672                            tracing::error!(?error, "error handling message")
 673                        }
 674                        Ok(()) => tracing::info!("finished handling message"),
 675                    }
 676                }
 677                .boxed()
 678            }),
 679        );
 680        if prev_handler.is_some() {
 681            panic!("registered a handler for the same message twice");
 682        }
 683        self
 684    }
 685
 686    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 687    where
 688        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 689        Fut: 'static + Send + Future<Output = Result<()>>,
 690        M: EnvelopedMessage,
 691    {
 692        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 693        self
 694    }
 695
 696    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 697    where
 698        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 699        Fut: Send + Future<Output = Result<()>>,
 700        M: RequestMessage,
 701    {
 702        let handler = Arc::new(handler);
 703        self.add_handler(move |envelope, session| {
 704            let receipt = envelope.receipt();
 705            let handler = handler.clone();
 706            async move {
 707                let peer = session.peer.clone();
 708                let responded = Arc::new(AtomicBool::default());
 709                let response = Response {
 710                    peer: peer.clone(),
 711                    responded: responded.clone(),
 712                    receipt,
 713                };
 714                match (handler)(envelope.payload, response, session).await {
 715                    Ok(()) => {
 716                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 717                            Ok(())
 718                        } else {
 719                            Err(anyhow!("handler did not send a response"))?
 720                        }
 721                    }
 722                    Err(error) => {
 723                        let proto_err = match &error {
 724                            Error::Internal(err) => err.to_proto(),
 725                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 726                        };
 727                        peer.respond_with_error(receipt, proto_err)?;
 728                        Err(error)
 729                    }
 730                }
 731            }
 732        })
 733    }
 734
 735    pub fn handle_connection(
 736        self: &Arc<Self>,
 737        connection: Connection,
 738        address: String,
 739        principal: Principal,
 740        zed_version: ZedVersion,
 741        release_channel: Option<String>,
 742        user_agent: Option<String>,
 743        geoip_country_code: Option<String>,
 744        system_id: Option<String>,
 745        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 746        executor: Executor,
 747        connection_guard: Option<ConnectionGuard>,
 748    ) -> impl Future<Output = ()> + use<> {
 749        let this = self.clone();
 750        let span = info_span!("handle connection", %address,
 751            connection_id=field::Empty,
 752            user_id=field::Empty,
 753            login=field::Empty,
 754            user_agent=field::Empty,
 755            geoip_country_code=field::Empty,
 756            release_channel=field::Empty,
 757        );
 758        principal.update_span(&span);
 759        if let Some(user_agent) = user_agent {
 760            span.record("user_agent", user_agent);
 761        }
 762        if let Some(release_channel) = release_channel {
 763            span.record("release_channel", release_channel);
 764        }
 765
 766        if let Some(country_code) = geoip_country_code.as_ref() {
 767            span.record("geoip_country_code", country_code);
 768        }
 769
 770        let mut teardown = self.teardown.subscribe();
 771        async move {
 772            if *teardown.borrow() {
 773                tracing::error!("server is tearing down");
 774                return;
 775            }
 776
 777            let (connection_id, handle_io, mut incoming_rx) =
 778                this.peer.add_connection(connection, {
 779                    let executor = executor.clone();
 780                    move |duration| executor.sleep(duration)
 781                });
 782            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 783
 784            tracing::info!("connection opened");
 785
 786            let session = Session {
 787                principal: principal.clone(),
 788                connection_id,
 789                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 790                peer: this.peer.clone(),
 791                connection_pool: this.connection_pool.clone(),
 792                app_state: this.app_state.clone(),
 793                geoip_country_code,
 794                system_id,
 795                _executor: executor.clone(),
 796            };
 797
 798            if let Err(error) = this
 799                .send_initial_client_update(
 800                    connection_id,
 801                    zed_version,
 802                    send_connection_id,
 803                    &session,
 804                )
 805                .await
 806            {
 807                tracing::error!(?error, "failed to send initial client update");
 808                return;
 809            }
 810            drop(connection_guard);
 811
 812            let handle_io = handle_io.fuse();
 813            futures::pin_mut!(handle_io);
 814
 815            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 816            // This prevents deadlocks when e.g., client A performs a request to client B and
 817            // client B performs a request to client A. If both clients stop processing further
 818            // messages until their respective request completes, they won't have a chance to
 819            // respond to the other client's request and cause a deadlock.
 820            //
 821            // This arrangement ensures we will attempt to process earlier messages first, but fall
 822            // back to processing messages arrived later in the spirit of making progress.
 823            const MAX_CONCURRENT_HANDLERS: usize = 256;
 824            let mut foreground_message_handlers = FuturesUnordered::new();
 825            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 826            let get_concurrent_handlers = {
 827                let concurrent_handlers = concurrent_handlers.clone();
 828                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 829            };
 830            loop {
 831                let next_message = async {
 832                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 833                    let message = incoming_rx.next().await;
 834                    // Cache the concurrent_handlers here, so that we know what the
 835                    // queue looks like as each handler starts
 836                    (permit, message, get_concurrent_handlers())
 837                }
 838                .fuse();
 839                futures::pin_mut!(next_message);
 840                futures::select_biased! {
 841                    _ = teardown.changed().fuse() => return,
 842                    result = handle_io => {
 843                        if let Err(error) = result {
 844                            tracing::error!(?error, "error handling I/O");
 845                        }
 846                        break;
 847                    }
 848                    _ = foreground_message_handlers.next() => {}
 849                    next_message = next_message => {
 850                        let (permit, message, concurrent_handlers) = next_message;
 851                        if let Some(message) = message {
 852                            let type_name = message.payload_type_name();
 853                            // note: we copy all the fields from the parent span so we can query them in the logs.
 854                            // (https://github.com/tokio-rs/tracing/issues/2670).
 855                            let span = tracing::info_span!("receive message",
 856                                %connection_id,
 857                                %address,
 858                                type_name,
 859                                concurrent_handlers,
 860                                user_id=field::Empty,
 861                                login=field::Empty,
 862                                lsp_query_request=field::Empty,
 863                                release_channel=field::Empty,
 864                                { TOTAL_DURATION_MS }=field::Empty,
 865                                { PROCESSING_DURATION_MS }=field::Empty,
 866                                { QUEUE_DURATION_MS }=field::Empty,
 867                                { HOST_WAITING_MS }=field::Empty
 868                            );
 869                            principal.update_span(&span);
 870                            let span_enter = span.enter();
 871                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 872                                let is_background = message.is_background();
 873                                let handle_message = (handler)(message, session.clone(), span.clone());
 874                                drop(span_enter);
 875
 876                                let handle_message = async move {
 877                                    handle_message.await;
 878                                    drop(permit);
 879                                }.instrument(span);
 880                                if is_background {
 881                                    executor.spawn_detached(handle_message);
 882                                } else {
 883                                    foreground_message_handlers.push(handle_message);
 884                                }
 885                            } else {
 886                                tracing::error!("no message handler");
 887                            }
 888                        } else {
 889                            tracing::info!("connection closed");
 890                            break;
 891                        }
 892                    }
 893                }
 894            }
 895
 896            drop(foreground_message_handlers);
 897            let concurrent_handlers = get_concurrent_handlers();
 898            tracing::info!(concurrent_handlers, "signing out");
 899            if let Err(error) = connection_lost(session, teardown, executor).await {
 900                tracing::error!(?error, "error signing out");
 901            }
 902        }
 903        .instrument(span)
 904    }
 905
 906    async fn send_initial_client_update(
 907        &self,
 908        connection_id: ConnectionId,
 909        zed_version: ZedVersion,
 910        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 911        session: &Session,
 912    ) -> Result<()> {
 913        self.peer.send(
 914            connection_id,
 915            proto::Hello {
 916                peer_id: Some(connection_id.into()),
 917            },
 918        )?;
 919        tracing::info!("sent hello message");
 920        if let Some(send_connection_id) = send_connection_id.take() {
 921            let _ = send_connection_id.send(connection_id);
 922        }
 923
 924        match &session.principal {
 925            Principal::User(user) => {
 926                if !user.connected_once {
 927                    self.peer.send(connection_id, proto::ShowContacts {})?;
 928                    self.app_state
 929                        .db
 930                        .set_user_connected_once(user.id, true)
 931                        .await?;
 932                }
 933
 934                let contacts = self.app_state.db.get_contacts(user.id).await?;
 935
 936                {
 937                    let mut pool = self.connection_pool.lock();
 938                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 939                    self.peer.send(
 940                        connection_id,
 941                        build_initial_contacts_update(contacts, &pool),
 942                    )?;
 943                }
 944
 945                if should_auto_subscribe_to_channels(&zed_version) {
 946                    subscribe_user_to_channels(user.id, session).await?;
 947                }
 948
 949                if let Some(incoming_call) =
 950                    self.app_state.db.incoming_call_for_user(user.id).await?
 951                {
 952                    self.peer.send(connection_id, incoming_call)?;
 953                }
 954
 955                update_user_contacts(user.id, session).await?;
 956            }
 957        }
 958
 959        Ok(())
 960    }
 961}
 962
 963impl Deref for ConnectionPoolGuard<'_> {
 964    type Target = ConnectionPool;
 965
 966    fn deref(&self) -> &Self::Target {
 967        &self.guard
 968    }
 969}
 970
 971impl DerefMut for ConnectionPoolGuard<'_> {
 972    fn deref_mut(&mut self) -> &mut Self::Target {
 973        &mut self.guard
 974    }
 975}
 976
 977impl Drop for ConnectionPoolGuard<'_> {
 978    fn drop(&mut self) {
 979        #[cfg(feature = "test-support")]
 980        self.check_invariants();
 981    }
 982}
 983
 984fn broadcast<F>(
 985    sender_id: Option<ConnectionId>,
 986    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 987    mut f: F,
 988) where
 989    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 990{
 991    for receiver_id in receiver_ids {
 992        if Some(receiver_id) != sender_id
 993            && let Err(error) = f(receiver_id)
 994        {
 995            tracing::error!("failed to send to {:?} {}", receiver_id, error);
 996        }
 997    }
 998}
 999
1000pub struct ProtocolVersion(u32);
1001
1002impl Header for ProtocolVersion {
1003    fn name() -> &'static HeaderName {
1004        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1005        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1006    }
1007
1008    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1009    where
1010        Self: Sized,
1011        I: Iterator<Item = &'i axum::http::HeaderValue>,
1012    {
1013        let version = values
1014            .next()
1015            .ok_or_else(axum::headers::Error::invalid)?
1016            .to_str()
1017            .map_err(|_| axum::headers::Error::invalid())?
1018            .parse()
1019            .map_err(|_| axum::headers::Error::invalid())?;
1020        Ok(Self(version))
1021    }
1022
1023    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1024        values.extend([self.0.to_string().parse().unwrap()]);
1025    }
1026}
1027
1028pub struct AppVersionHeader(Version);
1029impl Header for AppVersionHeader {
1030    fn name() -> &'static HeaderName {
1031        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1033    }
1034
1035    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036    where
1037        Self: Sized,
1038        I: Iterator<Item = &'i axum::http::HeaderValue>,
1039    {
1040        let version = values
1041            .next()
1042            .ok_or_else(axum::headers::Error::invalid)?
1043            .to_str()
1044            .map_err(|_| axum::headers::Error::invalid())?
1045            .parse()
1046            .map_err(|_| axum::headers::Error::invalid())?;
1047        Ok(Self(version))
1048    }
1049
1050    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051        values.extend([self.0.to_string().parse().unwrap()]);
1052    }
1053}
1054
1055#[derive(Debug)]
1056pub struct ReleaseChannelHeader(String);
1057
1058impl Header for ReleaseChannelHeader {
1059    fn name() -> &'static HeaderName {
1060        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1061        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1062    }
1063
1064    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065    where
1066        Self: Sized,
1067        I: Iterator<Item = &'i axum::http::HeaderValue>,
1068    {
1069        Ok(Self(
1070            values
1071                .next()
1072                .ok_or_else(axum::headers::Error::invalid)?
1073                .to_str()
1074                .map_err(|_| axum::headers::Error::invalid())?
1075                .to_owned(),
1076        ))
1077    }
1078
1079    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080        values.extend([self.0.parse().unwrap()]);
1081    }
1082}
1083
1084pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1085    Router::new()
1086        .route("/rpc", get(handle_websocket_request))
1087        .layer(
1088            ServiceBuilder::new()
1089                .layer(Extension(server.app_state.clone()))
1090                .layer(middleware::from_fn(auth::validate_header)),
1091        )
1092        .route("/metrics", get(handle_metrics))
1093        .layer(Extension(server))
1094}
1095
1096pub async fn handle_websocket_request(
1097    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1098    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1099    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1100    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1101    Extension(server): Extension<Arc<Server>>,
1102    Extension(principal): Extension<Principal>,
1103    user_agent: Option<TypedHeader<UserAgent>>,
1104    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1105    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1106    ws: WebSocketUpgrade,
1107) -> axum::response::Response {
1108    if protocol_version != rpc::PROTOCOL_VERSION {
1109        return (
1110            StatusCode::UPGRADE_REQUIRED,
1111            "client must be upgraded".to_string(),
1112        )
1113            .into_response();
1114    }
1115
1116    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1117        return (
1118            StatusCode::UPGRADE_REQUIRED,
1119            "no version header found".to_string(),
1120        )
1121            .into_response();
1122    };
1123
1124    let release_channel = release_channel_header.map(|header| header.0.0);
1125
1126    if !version.can_collaborate() {
1127        return (
1128            StatusCode::UPGRADE_REQUIRED,
1129            "client must be upgraded".to_string(),
1130        )
1131            .into_response();
1132    }
1133
1134    let socket_address = socket_address.to_string();
1135
1136    // Acquire connection guard before WebSocket upgrade
1137    let connection_guard = match ConnectionGuard::try_acquire() {
1138        Ok(guard) => guard,
1139        Err(()) => {
1140            return (
1141                StatusCode::SERVICE_UNAVAILABLE,
1142                "Too many concurrent connections",
1143            )
1144                .into_response();
1145        }
1146    };
1147
1148    ws.on_upgrade(move |socket| {
1149        let socket = socket
1150            .map_ok(to_tungstenite_message)
1151            .err_into()
1152            .with(|message| async move { to_axum_message(message) });
1153        let connection = Connection::new(Box::pin(socket));
1154        async move {
1155            server
1156                .handle_connection(
1157                    connection,
1158                    socket_address,
1159                    principal,
1160                    version,
1161                    release_channel,
1162                    user_agent.map(|header| header.to_string()),
1163                    country_code_header.map(|header| header.to_string()),
1164                    system_id_header.map(|header| header.to_string()),
1165                    None,
1166                    Executor::Production,
1167                    Some(connection_guard),
1168                )
1169                .await;
1170        }
1171    })
1172}
1173
1174pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1175    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1176    let connections_metric = CONNECTIONS_METRIC
1177        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1178
1179    let connections = server
1180        .connection_pool
1181        .lock()
1182        .connections()
1183        .filter(|connection| !connection.admin)
1184        .count();
1185    connections_metric.set(connections as _);
1186
1187    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1188    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1189        register_int_gauge!(
1190            "shared_projects",
1191            "number of open projects with one or more guests"
1192        )
1193        .unwrap()
1194    });
1195
1196    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1197    shared_projects_metric.set(shared_projects as _);
1198
1199    let encoder = prometheus::TextEncoder::new();
1200    let metric_families = prometheus::gather();
1201    let encoded_metrics = encoder
1202        .encode_to_string(&metric_families)
1203        .map_err(|err| anyhow!("{err}"))?;
1204    Ok(encoded_metrics)
1205}
1206
1207#[instrument(err, skip(executor))]
1208async fn connection_lost(
1209    session: Session,
1210    mut teardown: watch::Receiver<bool>,
1211    executor: Executor,
1212) -> Result<()> {
1213    session.peer.disconnect(session.connection_id);
1214    session
1215        .connection_pool()
1216        .await
1217        .remove_connection(session.connection_id)?;
1218
1219    session
1220        .db()
1221        .await
1222        .connection_lost(session.connection_id)
1223        .await
1224        .trace_err();
1225
1226    futures::select_biased! {
1227        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1228
1229            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1230            leave_room_for_session(&session, session.connection_id).await.trace_err();
1231            leave_channel_buffers_for_session(&session)
1232                .await
1233                .trace_err();
1234
1235            if !session
1236                .connection_pool()
1237                .await
1238                .is_user_online(session.user_id())
1239            {
1240                let db = session.db().await;
1241                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1242                    room_updated(&room, &session.peer);
1243                }
1244            }
1245
1246            update_user_contacts(session.user_id(), &session).await?;
1247        },
1248        _ = teardown.changed().fuse() => {}
1249    }
1250
1251    Ok(())
1252}
1253
1254/// Acknowledges a ping from a client, used to keep the connection alive.
1255async fn ping(
1256    _: proto::Ping,
1257    response: Response<proto::Ping>,
1258    _session: MessageContext,
1259) -> Result<()> {
1260    response.send(proto::Ack {})?;
1261    Ok(())
1262}
1263
1264/// Creates a new room for calling (outside of channels)
1265async fn create_room(
1266    _request: proto::CreateRoom,
1267    response: Response<proto::CreateRoom>,
1268    session: MessageContext,
1269) -> Result<()> {
1270    let livekit_room = nanoid::nanoid!(30);
1271
1272    let live_kit_connection_info = util::maybe!(async {
1273        let live_kit = session.app_state.livekit_client.as_ref();
1274        let live_kit = live_kit?;
1275        let user_id = session.user_id().to_string();
1276
1277        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1278
1279        Some(proto::LiveKitConnectionInfo {
1280            server_url: live_kit.url().into(),
1281            token,
1282            can_publish: true,
1283        })
1284    })
1285    .await;
1286
1287    let room = session
1288        .db()
1289        .await
1290        .create_room(session.user_id(), session.connection_id, &livekit_room)
1291        .await?;
1292
1293    response.send(proto::CreateRoomResponse {
1294        room: Some(room.clone()),
1295        live_kit_connection_info,
1296    })?;
1297
1298    update_user_contacts(session.user_id(), &session).await?;
1299    Ok(())
1300}
1301
1302/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1303async fn join_room(
1304    request: proto::JoinRoom,
1305    response: Response<proto::JoinRoom>,
1306    session: MessageContext,
1307) -> Result<()> {
1308    let room_id = RoomId::from_proto(request.id);
1309
1310    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1311
1312    if let Some(channel_id) = channel_id {
1313        return join_channel_internal(channel_id, Box::new(response), session).await;
1314    }
1315
1316    let joined_room = {
1317        let room = session
1318            .db()
1319            .await
1320            .join_room(room_id, session.user_id(), session.connection_id)
1321            .await?;
1322        room_updated(&room.room, &session.peer);
1323        room.into_inner()
1324    };
1325
1326    for connection_id in session
1327        .connection_pool()
1328        .await
1329        .user_connection_ids(session.user_id())
1330    {
1331        session
1332            .peer
1333            .send(
1334                connection_id,
1335                proto::CallCanceled {
1336                    room_id: room_id.to_proto(),
1337                },
1338            )
1339            .trace_err();
1340    }
1341
1342    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1343    {
1344        live_kit
1345            .room_token(
1346                &joined_room.room.livekit_room,
1347                &session.user_id().to_string(),
1348            )
1349            .trace_err()
1350            .map(|token| proto::LiveKitConnectionInfo {
1351                server_url: live_kit.url().into(),
1352                token,
1353                can_publish: true,
1354            })
1355    } else {
1356        None
1357    };
1358
1359    response.send(proto::JoinRoomResponse {
1360        room: Some(joined_room.room),
1361        channel_id: None,
1362        live_kit_connection_info,
1363    })?;
1364
1365    update_user_contacts(session.user_id(), &session).await?;
1366    Ok(())
1367}
1368
1369/// Rejoin room is used to reconnect to a room after connection errors.
1370async fn rejoin_room(
1371    request: proto::RejoinRoom,
1372    response: Response<proto::RejoinRoom>,
1373    session: MessageContext,
1374) -> Result<()> {
1375    let room;
1376    let channel;
1377    {
1378        let mut rejoined_room = session
1379            .db()
1380            .await
1381            .rejoin_room(request, session.user_id(), session.connection_id)
1382            .await?;
1383
1384        response.send(proto::RejoinRoomResponse {
1385            room: Some(rejoined_room.room.clone()),
1386            reshared_projects: rejoined_room
1387                .reshared_projects
1388                .iter()
1389                .map(|project| proto::ResharedProject {
1390                    id: project.id.to_proto(),
1391                    collaborators: project
1392                        .collaborators
1393                        .iter()
1394                        .map(|collaborator| collaborator.to_proto())
1395                        .collect(),
1396                })
1397                .collect(),
1398            rejoined_projects: rejoined_room
1399                .rejoined_projects
1400                .iter()
1401                .map(|rejoined_project| rejoined_project.to_proto())
1402                .collect(),
1403        })?;
1404        room_updated(&rejoined_room.room, &session.peer);
1405
1406        for project in &rejoined_room.reshared_projects {
1407            for collaborator in &project.collaborators {
1408                session
1409                    .peer
1410                    .send(
1411                        collaborator.connection_id,
1412                        proto::UpdateProjectCollaborator {
1413                            project_id: project.id.to_proto(),
1414                            old_peer_id: Some(project.old_connection_id.into()),
1415                            new_peer_id: Some(session.connection_id.into()),
1416                        },
1417                    )
1418                    .trace_err();
1419            }
1420
1421            broadcast(
1422                Some(session.connection_id),
1423                project
1424                    .collaborators
1425                    .iter()
1426                    .map(|collaborator| collaborator.connection_id),
1427                |connection_id| {
1428                    session.peer.forward_send(
1429                        session.connection_id,
1430                        connection_id,
1431                        proto::UpdateProject {
1432                            project_id: project.id.to_proto(),
1433                            worktrees: project.worktrees.clone(),
1434                        },
1435                    )
1436                },
1437            );
1438        }
1439
1440        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1441
1442        let rejoined_room = rejoined_room.into_inner();
1443
1444        room = rejoined_room.room;
1445        channel = rejoined_room.channel;
1446    }
1447
1448    if let Some(channel) = channel {
1449        channel_updated(
1450            &channel,
1451            &room,
1452            &session.peer,
1453            &*session.connection_pool().await,
1454        );
1455    }
1456
1457    update_user_contacts(session.user_id(), &session).await?;
1458    Ok(())
1459}
1460
1461fn notify_rejoined_projects(
1462    rejoined_projects: &mut Vec<RejoinedProject>,
1463    session: &Session,
1464) -> Result<()> {
1465    for project in rejoined_projects.iter() {
1466        for collaborator in &project.collaborators {
1467            session
1468                .peer
1469                .send(
1470                    collaborator.connection_id,
1471                    proto::UpdateProjectCollaborator {
1472                        project_id: project.id.to_proto(),
1473                        old_peer_id: Some(project.old_connection_id.into()),
1474                        new_peer_id: Some(session.connection_id.into()),
1475                    },
1476                )
1477                .trace_err();
1478        }
1479    }
1480
1481    for project in rejoined_projects {
1482        for worktree in mem::take(&mut project.worktrees) {
1483            // Stream this worktree's entries.
1484            let message = proto::UpdateWorktree {
1485                project_id: project.id.to_proto(),
1486                worktree_id: worktree.id,
1487                abs_path: worktree.abs_path.clone(),
1488                root_name: worktree.root_name,
1489                updated_entries: worktree.updated_entries,
1490                removed_entries: worktree.removed_entries,
1491                scan_id: worktree.scan_id,
1492                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1493                updated_repositories: worktree.updated_repositories,
1494                removed_repositories: worktree.removed_repositories,
1495            };
1496            for update in proto::split_worktree_update(message) {
1497                session.peer.send(session.connection_id, update)?;
1498            }
1499
1500            // Stream this worktree's diagnostics.
1501            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1502            if let Some(summary) = worktree_diagnostics.next() {
1503                let message = proto::UpdateDiagnosticSummary {
1504                    project_id: project.id.to_proto(),
1505                    worktree_id: worktree.id,
1506                    summary: Some(summary),
1507                    more_summaries: worktree_diagnostics.collect(),
1508                };
1509                session.peer.send(session.connection_id, message)?;
1510            }
1511
1512            for settings_file in worktree.settings_files {
1513                session.peer.send(
1514                    session.connection_id,
1515                    proto::UpdateWorktreeSettings {
1516                        project_id: project.id.to_proto(),
1517                        worktree_id: worktree.id,
1518                        path: settings_file.path,
1519                        content: Some(settings_file.content),
1520                        kind: Some(settings_file.kind.to_proto().into()),
1521                        outside_worktree: Some(settings_file.outside_worktree),
1522                    },
1523                )?;
1524            }
1525        }
1526
1527        for repository in mem::take(&mut project.updated_repositories) {
1528            for update in split_repository_update(repository) {
1529                session.peer.send(session.connection_id, update)?;
1530            }
1531        }
1532
1533        for id in mem::take(&mut project.removed_repositories) {
1534            session.peer.send(
1535                session.connection_id,
1536                proto::RemoveRepository {
1537                    project_id: project.id.to_proto(),
1538                    id,
1539                },
1540            )?;
1541        }
1542    }
1543
1544    Ok(())
1545}
1546
1547/// leave room disconnects from the room.
1548async fn leave_room(
1549    _: proto::LeaveRoom,
1550    response: Response<proto::LeaveRoom>,
1551    session: MessageContext,
1552) -> Result<()> {
1553    leave_room_for_session(&session, session.connection_id).await?;
1554    response.send(proto::Ack {})?;
1555    Ok(())
1556}
1557
1558/// Updates the permissions of someone else in the room.
1559async fn set_room_participant_role(
1560    request: proto::SetRoomParticipantRole,
1561    response: Response<proto::SetRoomParticipantRole>,
1562    session: MessageContext,
1563) -> Result<()> {
1564    let user_id = UserId::from_proto(request.user_id);
1565    let role = ChannelRole::from(request.role());
1566
1567    let (livekit_room, can_publish) = {
1568        let room = session
1569            .db()
1570            .await
1571            .set_room_participant_role(
1572                session.user_id(),
1573                RoomId::from_proto(request.room_id),
1574                user_id,
1575                role,
1576            )
1577            .await?;
1578
1579        let livekit_room = room.livekit_room.clone();
1580        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1581        room_updated(&room, &session.peer);
1582        (livekit_room, can_publish)
1583    };
1584
1585    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1586        live_kit
1587            .update_participant(
1588                livekit_room.clone(),
1589                request.user_id.to_string(),
1590                livekit_api::proto::ParticipantPermission {
1591                    can_subscribe: true,
1592                    can_publish,
1593                    can_publish_data: can_publish,
1594                    hidden: false,
1595                    recorder: false,
1596                },
1597            )
1598            .await
1599            .trace_err();
1600    }
1601
1602    response.send(proto::Ack {})?;
1603    Ok(())
1604}
1605
1606/// Call someone else into the current room
1607async fn call(
1608    request: proto::Call,
1609    response: Response<proto::Call>,
1610    session: MessageContext,
1611) -> Result<()> {
1612    let room_id = RoomId::from_proto(request.room_id);
1613    let calling_user_id = session.user_id();
1614    let calling_connection_id = session.connection_id;
1615    let called_user_id = UserId::from_proto(request.called_user_id);
1616    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1617    if !session
1618        .db()
1619        .await
1620        .has_contact(calling_user_id, called_user_id)
1621        .await?
1622    {
1623        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1624    }
1625
1626    let incoming_call = {
1627        let (room, incoming_call) = &mut *session
1628            .db()
1629            .await
1630            .call(
1631                room_id,
1632                calling_user_id,
1633                calling_connection_id,
1634                called_user_id,
1635                initial_project_id,
1636            )
1637            .await?;
1638        room_updated(room, &session.peer);
1639        mem::take(incoming_call)
1640    };
1641    update_user_contacts(called_user_id, &session).await?;
1642
1643    let mut calls = session
1644        .connection_pool()
1645        .await
1646        .user_connection_ids(called_user_id)
1647        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1648        .collect::<FuturesUnordered<_>>();
1649
1650    while let Some(call_response) = calls.next().await {
1651        match call_response.as_ref() {
1652            Ok(_) => {
1653                response.send(proto::Ack {})?;
1654                return Ok(());
1655            }
1656            Err(_) => {
1657                call_response.trace_err();
1658            }
1659        }
1660    }
1661
1662    {
1663        let room = session
1664            .db()
1665            .await
1666            .call_failed(room_id, called_user_id)
1667            .await?;
1668        room_updated(&room, &session.peer);
1669    }
1670    update_user_contacts(called_user_id, &session).await?;
1671
1672    Err(anyhow!("failed to ring user"))?
1673}
1674
1675/// Cancel an outgoing call.
1676async fn cancel_call(
1677    request: proto::CancelCall,
1678    response: Response<proto::CancelCall>,
1679    session: MessageContext,
1680) -> Result<()> {
1681    let called_user_id = UserId::from_proto(request.called_user_id);
1682    let room_id = RoomId::from_proto(request.room_id);
1683    {
1684        let room = session
1685            .db()
1686            .await
1687            .cancel_call(room_id, session.connection_id, called_user_id)
1688            .await?;
1689        room_updated(&room, &session.peer);
1690    }
1691
1692    for connection_id in session
1693        .connection_pool()
1694        .await
1695        .user_connection_ids(called_user_id)
1696    {
1697        session
1698            .peer
1699            .send(
1700                connection_id,
1701                proto::CallCanceled {
1702                    room_id: room_id.to_proto(),
1703                },
1704            )
1705            .trace_err();
1706    }
1707    response.send(proto::Ack {})?;
1708
1709    update_user_contacts(called_user_id, &session).await?;
1710    Ok(())
1711}
1712
1713/// Decline an incoming call.
1714async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1715    let room_id = RoomId::from_proto(message.room_id);
1716    {
1717        let room = session
1718            .db()
1719            .await
1720            .decline_call(Some(room_id), session.user_id())
1721            .await?
1722            .context("declining call")?;
1723        room_updated(&room, &session.peer);
1724    }
1725
1726    for connection_id in session
1727        .connection_pool()
1728        .await
1729        .user_connection_ids(session.user_id())
1730    {
1731        session
1732            .peer
1733            .send(
1734                connection_id,
1735                proto::CallCanceled {
1736                    room_id: room_id.to_proto(),
1737                },
1738            )
1739            .trace_err();
1740    }
1741    update_user_contacts(session.user_id(), &session).await?;
1742    Ok(())
1743}
1744
1745/// Updates other participants in the room with your current location.
1746async fn update_participant_location(
1747    request: proto::UpdateParticipantLocation,
1748    response: Response<proto::UpdateParticipantLocation>,
1749    session: MessageContext,
1750) -> Result<()> {
1751    let room_id = RoomId::from_proto(request.room_id);
1752    let location = request.location.context("invalid location")?;
1753
1754    let db = session.db().await;
1755    let room = db
1756        .update_room_participant_location(room_id, session.connection_id, location)
1757        .await?;
1758
1759    room_updated(&room, &session.peer);
1760    response.send(proto::Ack {})?;
1761    Ok(())
1762}
1763
1764/// Share a project into the room.
1765async fn share_project(
1766    request: proto::ShareProject,
1767    response: Response<proto::ShareProject>,
1768    session: MessageContext,
1769) -> Result<()> {
1770    let (project_id, room) = &*session
1771        .db()
1772        .await
1773        .share_project(
1774            RoomId::from_proto(request.room_id),
1775            session.connection_id,
1776            &request.worktrees,
1777            request.is_ssh_project,
1778            request.windows_paths.unwrap_or(false),
1779        )
1780        .await?;
1781    response.send(proto::ShareProjectResponse {
1782        project_id: project_id.to_proto(),
1783    })?;
1784    room_updated(room, &session.peer);
1785
1786    Ok(())
1787}
1788
1789/// Unshare a project from the room.
1790async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1791    let project_id = ProjectId::from_proto(message.project_id);
1792    unshare_project_internal(project_id, session.connection_id, &session).await
1793}
1794
1795async fn unshare_project_internal(
1796    project_id: ProjectId,
1797    connection_id: ConnectionId,
1798    session: &Session,
1799) -> Result<()> {
1800    let delete = {
1801        let room_guard = session
1802            .db()
1803            .await
1804            .unshare_project(project_id, connection_id)
1805            .await?;
1806
1807        let (delete, room, guest_connection_ids) = &*room_guard;
1808
1809        let message = proto::UnshareProject {
1810            project_id: project_id.to_proto(),
1811        };
1812
1813        broadcast(
1814            Some(connection_id),
1815            guest_connection_ids.iter().copied(),
1816            |conn_id| session.peer.send(conn_id, message.clone()),
1817        );
1818        if let Some(room) = room {
1819            room_updated(room, &session.peer);
1820        }
1821
1822        *delete
1823    };
1824
1825    if delete {
1826        let db = session.db().await;
1827        db.delete_project(project_id).await?;
1828    }
1829
1830    Ok(())
1831}
1832
1833/// Join someone elses shared project.
1834async fn join_project(
1835    request: proto::JoinProject,
1836    response: Response<proto::JoinProject>,
1837    session: MessageContext,
1838) -> Result<()> {
1839    let project_id = ProjectId::from_proto(request.project_id);
1840
1841    tracing::info!(%project_id, "join project");
1842
1843    let db = session.db().await;
1844    let (project, replica_id) = &mut *db
1845        .join_project(
1846            project_id,
1847            session.connection_id,
1848            session.user_id(),
1849            request.committer_name.clone(),
1850            request.committer_email.clone(),
1851        )
1852        .await?;
1853    drop(db);
1854    tracing::info!(%project_id, "join remote project");
1855    let collaborators = project
1856        .collaborators
1857        .iter()
1858        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1859        .map(|collaborator| collaborator.to_proto())
1860        .collect::<Vec<_>>();
1861    let project_id = project.id;
1862    let guest_user_id = session.user_id();
1863
1864    let worktrees = project
1865        .worktrees
1866        .iter()
1867        .map(|(id, worktree)| proto::WorktreeMetadata {
1868            id: *id,
1869            root_name: worktree.root_name.clone(),
1870            visible: worktree.visible,
1871            abs_path: worktree.abs_path.clone(),
1872        })
1873        .collect::<Vec<_>>();
1874
1875    let add_project_collaborator = proto::AddProjectCollaborator {
1876        project_id: project_id.to_proto(),
1877        collaborator: Some(proto::Collaborator {
1878            peer_id: Some(session.connection_id.into()),
1879            replica_id: replica_id.0 as u32,
1880            user_id: guest_user_id.to_proto(),
1881            is_host: false,
1882            committer_name: request.committer_name.clone(),
1883            committer_email: request.committer_email.clone(),
1884        }),
1885    };
1886
1887    for collaborator in &collaborators {
1888        session
1889            .peer
1890            .send(
1891                collaborator.peer_id.unwrap().into(),
1892                add_project_collaborator.clone(),
1893            )
1894            .trace_err();
1895    }
1896
1897    // First, we send the metadata associated with each worktree.
1898    let (language_servers, language_server_capabilities) = project
1899        .language_servers
1900        .clone()
1901        .into_iter()
1902        .map(|server| (server.server, server.capabilities))
1903        .unzip();
1904    response.send(proto::JoinProjectResponse {
1905        project_id: project.id.0 as u64,
1906        worktrees,
1907        replica_id: replica_id.0 as u32,
1908        collaborators,
1909        language_servers,
1910        language_server_capabilities,
1911        role: project.role.into(),
1912        windows_paths: project.path_style == PathStyle::Windows,
1913    })?;
1914
1915    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1916        // Stream this worktree's entries.
1917        let message = proto::UpdateWorktree {
1918            project_id: project_id.to_proto(),
1919            worktree_id,
1920            abs_path: worktree.abs_path.clone(),
1921            root_name: worktree.root_name,
1922            updated_entries: worktree.entries,
1923            removed_entries: Default::default(),
1924            scan_id: worktree.scan_id,
1925            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1926            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1927            removed_repositories: Default::default(),
1928        };
1929        for update in proto::split_worktree_update(message) {
1930            session.peer.send(session.connection_id, update.clone())?;
1931        }
1932
1933        // Stream this worktree's diagnostics.
1934        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1935        if let Some(summary) = worktree_diagnostics.next() {
1936            let message = proto::UpdateDiagnosticSummary {
1937                project_id: project.id.to_proto(),
1938                worktree_id: worktree.id,
1939                summary: Some(summary),
1940                more_summaries: worktree_diagnostics.collect(),
1941            };
1942            session.peer.send(session.connection_id, message)?;
1943        }
1944
1945        for settings_file in worktree.settings_files {
1946            session.peer.send(
1947                session.connection_id,
1948                proto::UpdateWorktreeSettings {
1949                    project_id: project_id.to_proto(),
1950                    worktree_id: worktree.id,
1951                    path: settings_file.path,
1952                    content: Some(settings_file.content),
1953                    kind: Some(settings_file.kind.to_proto() as i32),
1954                    outside_worktree: Some(settings_file.outside_worktree),
1955                },
1956            )?;
1957        }
1958    }
1959
1960    for repository in mem::take(&mut project.repositories) {
1961        for update in split_repository_update(repository) {
1962            session.peer.send(session.connection_id, update)?;
1963        }
1964    }
1965
1966    for language_server in &project.language_servers {
1967        session.peer.send(
1968            session.connection_id,
1969            proto::UpdateLanguageServer {
1970                project_id: project_id.to_proto(),
1971                server_name: Some(language_server.server.name.clone()),
1972                language_server_id: language_server.server.id,
1973                variant: Some(
1974                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1975                        proto::LspDiskBasedDiagnosticsUpdated {},
1976                    ),
1977                ),
1978            },
1979        )?;
1980    }
1981
1982    Ok(())
1983}
1984
1985/// Leave someone elses shared project.
1986async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
1987    let sender_id = session.connection_id;
1988    let project_id = ProjectId::from_proto(request.project_id);
1989    let db = session.db().await;
1990
1991    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1992    tracing::info!(
1993        %project_id,
1994        "leave project"
1995    );
1996
1997    project_left(project, &session);
1998    if let Some(room) = room {
1999        room_updated(room, &session.peer);
2000    }
2001
2002    Ok(())
2003}
2004
2005/// Updates other participants with changes to the project
2006async fn update_project(
2007    request: proto::UpdateProject,
2008    response: Response<proto::UpdateProject>,
2009    session: MessageContext,
2010) -> Result<()> {
2011    let project_id = ProjectId::from_proto(request.project_id);
2012    let (room, guest_connection_ids) = &*session
2013        .db()
2014        .await
2015        .update_project(project_id, session.connection_id, &request.worktrees)
2016        .await?;
2017    broadcast(
2018        Some(session.connection_id),
2019        guest_connection_ids.iter().copied(),
2020        |connection_id| {
2021            session
2022                .peer
2023                .forward_send(session.connection_id, connection_id, request.clone())
2024        },
2025    );
2026    if let Some(room) = room {
2027        room_updated(room, &session.peer);
2028    }
2029    response.send(proto::Ack {})?;
2030
2031    Ok(())
2032}
2033
2034/// Updates other participants with changes to the worktree
2035async fn update_worktree(
2036    request: proto::UpdateWorktree,
2037    response: Response<proto::UpdateWorktree>,
2038    session: MessageContext,
2039) -> Result<()> {
2040    let guest_connection_ids = session
2041        .db()
2042        .await
2043        .update_worktree(&request, session.connection_id)
2044        .await?;
2045
2046    broadcast(
2047        Some(session.connection_id),
2048        guest_connection_ids.iter().copied(),
2049        |connection_id| {
2050            session
2051                .peer
2052                .forward_send(session.connection_id, connection_id, request.clone())
2053        },
2054    );
2055    response.send(proto::Ack {})?;
2056    Ok(())
2057}
2058
2059async fn update_repository(
2060    request: proto::UpdateRepository,
2061    response: Response<proto::UpdateRepository>,
2062    session: MessageContext,
2063) -> Result<()> {
2064    let guest_connection_ids = session
2065        .db()
2066        .await
2067        .update_repository(&request, session.connection_id)
2068        .await?;
2069
2070    broadcast(
2071        Some(session.connection_id),
2072        guest_connection_ids.iter().copied(),
2073        |connection_id| {
2074            session
2075                .peer
2076                .forward_send(session.connection_id, connection_id, request.clone())
2077        },
2078    );
2079    response.send(proto::Ack {})?;
2080    Ok(())
2081}
2082
2083async fn remove_repository(
2084    request: proto::RemoveRepository,
2085    response: Response<proto::RemoveRepository>,
2086    session: MessageContext,
2087) -> Result<()> {
2088    let guest_connection_ids = session
2089        .db()
2090        .await
2091        .remove_repository(&request, session.connection_id)
2092        .await?;
2093
2094    broadcast(
2095        Some(session.connection_id),
2096        guest_connection_ids.iter().copied(),
2097        |connection_id| {
2098            session
2099                .peer
2100                .forward_send(session.connection_id, connection_id, request.clone())
2101        },
2102    );
2103    response.send(proto::Ack {})?;
2104    Ok(())
2105}
2106
2107/// Updates other participants with changes to the diagnostics
2108async fn update_diagnostic_summary(
2109    message: proto::UpdateDiagnosticSummary,
2110    session: MessageContext,
2111) -> Result<()> {
2112    let guest_connection_ids = session
2113        .db()
2114        .await
2115        .update_diagnostic_summary(&message, session.connection_id)
2116        .await?;
2117
2118    broadcast(
2119        Some(session.connection_id),
2120        guest_connection_ids.iter().copied(),
2121        |connection_id| {
2122            session
2123                .peer
2124                .forward_send(session.connection_id, connection_id, message.clone())
2125        },
2126    );
2127
2128    Ok(())
2129}
2130
2131/// Updates other participants with changes to the worktree settings
2132async fn update_worktree_settings(
2133    message: proto::UpdateWorktreeSettings,
2134    session: MessageContext,
2135) -> Result<()> {
2136    let guest_connection_ids = session
2137        .db()
2138        .await
2139        .update_worktree_settings(&message, session.connection_id)
2140        .await?;
2141
2142    broadcast(
2143        Some(session.connection_id),
2144        guest_connection_ids.iter().copied(),
2145        |connection_id| {
2146            session
2147                .peer
2148                .forward_send(session.connection_id, connection_id, message.clone())
2149        },
2150    );
2151
2152    Ok(())
2153}
2154
2155/// Notify other participants that a language server has started.
2156async fn start_language_server(
2157    request: proto::StartLanguageServer,
2158    session: MessageContext,
2159) -> Result<()> {
2160    let guest_connection_ids = session
2161        .db()
2162        .await
2163        .start_language_server(&request, session.connection_id)
2164        .await?;
2165
2166    broadcast(
2167        Some(session.connection_id),
2168        guest_connection_ids.iter().copied(),
2169        |connection_id| {
2170            session
2171                .peer
2172                .forward_send(session.connection_id, connection_id, request.clone())
2173        },
2174    );
2175    Ok(())
2176}
2177
2178/// Notify other participants that a language server has changed.
2179async fn update_language_server(
2180    request: proto::UpdateLanguageServer,
2181    session: MessageContext,
2182) -> Result<()> {
2183    let project_id = ProjectId::from_proto(request.project_id);
2184    let db = session.db().await;
2185
2186    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2187        && let Some(capabilities) = update.capabilities.clone()
2188    {
2189        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2190            .await?;
2191    }
2192
2193    let project_connection_ids = db
2194        .project_connection_ids(project_id, session.connection_id, true)
2195        .await?;
2196    broadcast(
2197        Some(session.connection_id),
2198        project_connection_ids.iter().copied(),
2199        |connection_id| {
2200            session
2201                .peer
2202                .forward_send(session.connection_id, connection_id, request.clone())
2203        },
2204    );
2205    Ok(())
2206}
2207
2208/// forward a project request to the host. These requests should be read only
2209/// as guests are allowed to send them.
2210async fn forward_read_only_project_request<T>(
2211    request: T,
2212    response: Response<T>,
2213    session: MessageContext,
2214) -> Result<()>
2215where
2216    T: EntityMessage + RequestMessage,
2217{
2218    let project_id = ProjectId::from_proto(request.remote_entity_id());
2219    let host_connection_id = session
2220        .db()
2221        .await
2222        .host_for_read_only_project_request(project_id, session.connection_id)
2223        .await?;
2224    let payload = session.forward_request(host_connection_id, request).await?;
2225    response.send(payload)?;
2226    Ok(())
2227}
2228
2229/// forward a project request to the host. These requests are disallowed
2230/// for guests.
2231async fn forward_mutating_project_request<T>(
2232    request: T,
2233    response: Response<T>,
2234    session: MessageContext,
2235) -> Result<()>
2236where
2237    T: EntityMessage + RequestMessage,
2238{
2239    let project_id = ProjectId::from_proto(request.remote_entity_id());
2240
2241    let host_connection_id = session
2242        .db()
2243        .await
2244        .host_for_mutating_project_request(project_id, session.connection_id)
2245        .await?;
2246    let payload = session.forward_request(host_connection_id, request).await?;
2247    response.send(payload)?;
2248    Ok(())
2249}
2250
2251async fn lsp_query(
2252    request: proto::LspQuery,
2253    response: Response<proto::LspQuery>,
2254    session: MessageContext,
2255) -> Result<()> {
2256    let (name, should_write) = request.query_name_and_write_permissions();
2257    tracing::Span::current().record("lsp_query_request", name);
2258    tracing::info!("lsp_query message received");
2259    if should_write {
2260        forward_mutating_project_request(request, response, session).await
2261    } else {
2262        forward_read_only_project_request(request, response, session).await
2263    }
2264}
2265
2266/// Notify other participants that a new buffer has been created
2267async fn create_buffer_for_peer(
2268    request: proto::CreateBufferForPeer,
2269    session: MessageContext,
2270) -> Result<()> {
2271    session
2272        .db()
2273        .await
2274        .check_user_is_project_host(
2275            ProjectId::from_proto(request.project_id),
2276            session.connection_id,
2277        )
2278        .await?;
2279    let peer_id = request.peer_id.context("invalid peer id")?;
2280    session
2281        .peer
2282        .forward_send(session.connection_id, peer_id.into(), request)?;
2283    Ok(())
2284}
2285
2286/// Notify other participants that a new image has been created
2287async fn create_image_for_peer(
2288    request: proto::CreateImageForPeer,
2289    session: MessageContext,
2290) -> Result<()> {
2291    session
2292        .db()
2293        .await
2294        .check_user_is_project_host(
2295            ProjectId::from_proto(request.project_id),
2296            session.connection_id,
2297        )
2298        .await?;
2299    let peer_id = request.peer_id.context("invalid peer id")?;
2300    session
2301        .peer
2302        .forward_send(session.connection_id, peer_id.into(), request)?;
2303    Ok(())
2304}
2305
2306/// Notify other participants that a buffer has been updated. This is
2307/// allowed for guests as long as the update is limited to selections.
2308async fn update_buffer(
2309    request: proto::UpdateBuffer,
2310    response: Response<proto::UpdateBuffer>,
2311    session: MessageContext,
2312) -> Result<()> {
2313    let project_id = ProjectId::from_proto(request.project_id);
2314    let mut capability = Capability::ReadOnly;
2315
2316    for op in request.operations.iter() {
2317        match op.variant {
2318            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2319            Some(_) => capability = Capability::ReadWrite,
2320        }
2321    }
2322
2323    let host = {
2324        let guard = session
2325            .db()
2326            .await
2327            .connections_for_buffer_update(project_id, session.connection_id, capability)
2328            .await?;
2329
2330        let (host, guests) = &*guard;
2331
2332        broadcast(
2333            Some(session.connection_id),
2334            guests.clone(),
2335            |connection_id| {
2336                session
2337                    .peer
2338                    .forward_send(session.connection_id, connection_id, request.clone())
2339            },
2340        );
2341
2342        *host
2343    };
2344
2345    if host != session.connection_id {
2346        session.forward_request(host, request.clone()).await?;
2347    }
2348
2349    response.send(proto::Ack {})?;
2350    Ok(())
2351}
2352
2353async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2354    let project_id = ProjectId::from_proto(message.project_id);
2355
2356    let operation = message.operation.as_ref().context("invalid operation")?;
2357    let capability = match operation.variant.as_ref() {
2358        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2359            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2360                match buffer_op.variant {
2361                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2362                        Capability::ReadOnly
2363                    }
2364                    _ => Capability::ReadWrite,
2365                }
2366            } else {
2367                Capability::ReadWrite
2368            }
2369        }
2370        Some(_) => Capability::ReadWrite,
2371        None => Capability::ReadOnly,
2372    };
2373
2374    let guard = session
2375        .db()
2376        .await
2377        .connections_for_buffer_update(project_id, session.connection_id, capability)
2378        .await?;
2379
2380    let (host, guests) = &*guard;
2381
2382    broadcast(
2383        Some(session.connection_id),
2384        guests.iter().chain([host]).copied(),
2385        |connection_id| {
2386            session
2387                .peer
2388                .forward_send(session.connection_id, connection_id, message.clone())
2389        },
2390    );
2391
2392    Ok(())
2393}
2394
2395async fn forward_project_search_chunk(
2396    message: proto::FindSearchCandidatesChunk,
2397    response: Response<proto::FindSearchCandidatesChunk>,
2398    session: MessageContext,
2399) -> Result<()> {
2400    let peer_id = message.peer_id.context("missing peer_id")?;
2401    let payload = session
2402        .peer
2403        .forward_request(session.connection_id, peer_id.into(), message)
2404        .await?;
2405    response.send(payload)?;
2406    Ok(())
2407}
2408
2409/// Notify other participants that a project has been updated.
2410async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2411    request: T,
2412    session: MessageContext,
2413) -> Result<()> {
2414    let project_id = ProjectId::from_proto(request.remote_entity_id());
2415    let project_connection_ids = session
2416        .db()
2417        .await
2418        .project_connection_ids(project_id, session.connection_id, false)
2419        .await?;
2420
2421    broadcast(
2422        Some(session.connection_id),
2423        project_connection_ids.iter().copied(),
2424        |connection_id| {
2425            session
2426                .peer
2427                .forward_send(session.connection_id, connection_id, request.clone())
2428        },
2429    );
2430    Ok(())
2431}
2432
2433/// Start following another user in a call.
2434async fn follow(
2435    request: proto::Follow,
2436    response: Response<proto::Follow>,
2437    session: MessageContext,
2438) -> Result<()> {
2439    let room_id = RoomId::from_proto(request.room_id);
2440    let project_id = request.project_id.map(ProjectId::from_proto);
2441    let leader_id = request.leader_id.context("invalid leader id")?.into();
2442    let follower_id = session.connection_id;
2443
2444    session
2445        .db()
2446        .await
2447        .check_room_participants(room_id, leader_id, session.connection_id)
2448        .await?;
2449
2450    let response_payload = session.forward_request(leader_id, request).await?;
2451    response.send(response_payload)?;
2452
2453    if let Some(project_id) = project_id {
2454        let room = session
2455            .db()
2456            .await
2457            .follow(room_id, project_id, leader_id, follower_id)
2458            .await?;
2459        room_updated(&room, &session.peer);
2460    }
2461
2462    Ok(())
2463}
2464
2465/// Stop following another user in a call.
2466async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2467    let room_id = RoomId::from_proto(request.room_id);
2468    let project_id = request.project_id.map(ProjectId::from_proto);
2469    let leader_id = request.leader_id.context("invalid leader id")?.into();
2470    let follower_id = session.connection_id;
2471
2472    session
2473        .db()
2474        .await
2475        .check_room_participants(room_id, leader_id, session.connection_id)
2476        .await?;
2477
2478    session
2479        .peer
2480        .forward_send(session.connection_id, leader_id, request)?;
2481
2482    if let Some(project_id) = project_id {
2483        let room = session
2484            .db()
2485            .await
2486            .unfollow(room_id, project_id, leader_id, follower_id)
2487            .await?;
2488        room_updated(&room, &session.peer);
2489    }
2490
2491    Ok(())
2492}
2493
2494/// Notify everyone following you of your current location.
2495async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2496    let room_id = RoomId::from_proto(request.room_id);
2497    let database = session.db.lock().await;
2498
2499    let connection_ids = if let Some(project_id) = request.project_id {
2500        let project_id = ProjectId::from_proto(project_id);
2501        database
2502            .project_connection_ids(project_id, session.connection_id, true)
2503            .await?
2504    } else {
2505        database
2506            .room_connection_ids(room_id, session.connection_id)
2507            .await?
2508    };
2509
2510    // For now, don't send view update messages back to that view's current leader.
2511    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2512        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2513        _ => None,
2514    });
2515
2516    for connection_id in connection_ids.iter().cloned() {
2517        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2518            session
2519                .peer
2520                .forward_send(session.connection_id, connection_id, request.clone())?;
2521        }
2522    }
2523    Ok(())
2524}
2525
2526/// Get public data about users.
2527async fn get_users(
2528    request: proto::GetUsers,
2529    response: Response<proto::GetUsers>,
2530    session: MessageContext,
2531) -> Result<()> {
2532    let user_ids = request
2533        .user_ids
2534        .into_iter()
2535        .map(UserId::from_proto)
2536        .collect();
2537    let users = session
2538        .db()
2539        .await
2540        .get_users_by_ids(user_ids)
2541        .await?
2542        .into_iter()
2543        .map(|user| proto::User {
2544            id: user.id.to_proto(),
2545            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2546            github_login: user.github_login,
2547            name: user.name,
2548        })
2549        .collect();
2550    response.send(proto::UsersResponse { users })?;
2551    Ok(())
2552}
2553
2554/// Search for users (to invite) buy Github login
2555async fn fuzzy_search_users(
2556    request: proto::FuzzySearchUsers,
2557    response: Response<proto::FuzzySearchUsers>,
2558    session: MessageContext,
2559) -> Result<()> {
2560    let query = request.query;
2561    let users = match query.len() {
2562        0 => vec![],
2563        1 | 2 => session
2564            .db()
2565            .await
2566            .get_user_by_github_login(&query)
2567            .await?
2568            .into_iter()
2569            .collect(),
2570        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2571    };
2572    let users = users
2573        .into_iter()
2574        .filter(|user| user.id != session.user_id())
2575        .map(|user| proto::User {
2576            id: user.id.to_proto(),
2577            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2578            github_login: user.github_login,
2579            name: user.name,
2580        })
2581        .collect();
2582    response.send(proto::UsersResponse { users })?;
2583    Ok(())
2584}
2585
2586/// Send a contact request to another user.
2587async fn request_contact(
2588    request: proto::RequestContact,
2589    response: Response<proto::RequestContact>,
2590    session: MessageContext,
2591) -> Result<()> {
2592    let requester_id = session.user_id();
2593    let responder_id = UserId::from_proto(request.responder_id);
2594    if requester_id == responder_id {
2595        return Err(anyhow!("cannot add yourself as a contact"))?;
2596    }
2597
2598    let notifications = session
2599        .db()
2600        .await
2601        .send_contact_request(requester_id, responder_id)
2602        .await?;
2603
2604    // Update outgoing contact requests of requester
2605    let mut update = proto::UpdateContacts::default();
2606    update.outgoing_requests.push(responder_id.to_proto());
2607    for connection_id in session
2608        .connection_pool()
2609        .await
2610        .user_connection_ids(requester_id)
2611    {
2612        session.peer.send(connection_id, update.clone())?;
2613    }
2614
2615    // Update incoming contact requests of responder
2616    let mut update = proto::UpdateContacts::default();
2617    update
2618        .incoming_requests
2619        .push(proto::IncomingContactRequest {
2620            requester_id: requester_id.to_proto(),
2621        });
2622    let connection_pool = session.connection_pool().await;
2623    for connection_id in connection_pool.user_connection_ids(responder_id) {
2624        session.peer.send(connection_id, update.clone())?;
2625    }
2626
2627    send_notifications(&connection_pool, &session.peer, notifications);
2628
2629    response.send(proto::Ack {})?;
2630    Ok(())
2631}
2632
2633/// Accept or decline a contact request
2634async fn respond_to_contact_request(
2635    request: proto::RespondToContactRequest,
2636    response: Response<proto::RespondToContactRequest>,
2637    session: MessageContext,
2638) -> Result<()> {
2639    let responder_id = session.user_id();
2640    let requester_id = UserId::from_proto(request.requester_id);
2641    let db = session.db().await;
2642    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2643        db.dismiss_contact_notification(responder_id, requester_id)
2644            .await?;
2645    } else {
2646        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2647
2648        let notifications = db
2649            .respond_to_contact_request(responder_id, requester_id, accept)
2650            .await?;
2651        let requester_busy = db.is_user_busy(requester_id).await?;
2652        let responder_busy = db.is_user_busy(responder_id).await?;
2653
2654        let pool = session.connection_pool().await;
2655        // Update responder with new contact
2656        let mut update = proto::UpdateContacts::default();
2657        if accept {
2658            update
2659                .contacts
2660                .push(contact_for_user(requester_id, requester_busy, &pool));
2661        }
2662        update
2663            .remove_incoming_requests
2664            .push(requester_id.to_proto());
2665        for connection_id in pool.user_connection_ids(responder_id) {
2666            session.peer.send(connection_id, update.clone())?;
2667        }
2668
2669        // Update requester with new contact
2670        let mut update = proto::UpdateContacts::default();
2671        if accept {
2672            update
2673                .contacts
2674                .push(contact_for_user(responder_id, responder_busy, &pool));
2675        }
2676        update
2677            .remove_outgoing_requests
2678            .push(responder_id.to_proto());
2679
2680        for connection_id in pool.user_connection_ids(requester_id) {
2681            session.peer.send(connection_id, update.clone())?;
2682        }
2683
2684        send_notifications(&pool, &session.peer, notifications);
2685    }
2686
2687    response.send(proto::Ack {})?;
2688    Ok(())
2689}
2690
2691/// Remove a contact.
2692async fn remove_contact(
2693    request: proto::RemoveContact,
2694    response: Response<proto::RemoveContact>,
2695    session: MessageContext,
2696) -> Result<()> {
2697    let requester_id = session.user_id();
2698    let responder_id = UserId::from_proto(request.user_id);
2699    let db = session.db().await;
2700    let (contact_accepted, deleted_notification_id) =
2701        db.remove_contact(requester_id, responder_id).await?;
2702
2703    let pool = session.connection_pool().await;
2704    // Update outgoing contact requests of requester
2705    let mut update = proto::UpdateContacts::default();
2706    if contact_accepted {
2707        update.remove_contacts.push(responder_id.to_proto());
2708    } else {
2709        update
2710            .remove_outgoing_requests
2711            .push(responder_id.to_proto());
2712    }
2713    for connection_id in pool.user_connection_ids(requester_id) {
2714        session.peer.send(connection_id, update.clone())?;
2715    }
2716
2717    // Update incoming contact requests of responder
2718    let mut update = proto::UpdateContacts::default();
2719    if contact_accepted {
2720        update.remove_contacts.push(requester_id.to_proto());
2721    } else {
2722        update
2723            .remove_incoming_requests
2724            .push(requester_id.to_proto());
2725    }
2726    for connection_id in pool.user_connection_ids(responder_id) {
2727        session.peer.send(connection_id, update.clone())?;
2728        if let Some(notification_id) = deleted_notification_id {
2729            session.peer.send(
2730                connection_id,
2731                proto::DeleteNotification {
2732                    notification_id: notification_id.to_proto(),
2733                },
2734            )?;
2735        }
2736    }
2737
2738    response.send(proto::Ack {})?;
2739    Ok(())
2740}
2741
2742fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2743    version.0.minor < 139
2744}
2745
2746async fn subscribe_to_channels(
2747    _: proto::SubscribeToChannels,
2748    session: MessageContext,
2749) -> Result<()> {
2750    subscribe_user_to_channels(session.user_id(), &session).await?;
2751    Ok(())
2752}
2753
2754async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2755    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2756    let mut pool = session.connection_pool().await;
2757    for membership in &channels_for_user.channel_memberships {
2758        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2759    }
2760    session.peer.send(
2761        session.connection_id,
2762        build_update_user_channels(&channels_for_user),
2763    )?;
2764    session.peer.send(
2765        session.connection_id,
2766        build_channels_update(channels_for_user),
2767    )?;
2768    Ok(())
2769}
2770
2771/// Creates a new channel.
2772async fn create_channel(
2773    request: proto::CreateChannel,
2774    response: Response<proto::CreateChannel>,
2775    session: MessageContext,
2776) -> Result<()> {
2777    let db = session.db().await;
2778
2779    let parent_id = request.parent_id.map(ChannelId::from_proto);
2780    let (channel, membership) = db
2781        .create_channel(&request.name, parent_id, session.user_id())
2782        .await?;
2783
2784    let root_id = channel.root_id();
2785    let channel = Channel::from_model(channel);
2786
2787    response.send(proto::CreateChannelResponse {
2788        channel: Some(channel.to_proto()),
2789        parent_id: request.parent_id,
2790    })?;
2791
2792    let mut connection_pool = session.connection_pool().await;
2793    if let Some(membership) = membership {
2794        connection_pool.subscribe_to_channel(
2795            membership.user_id,
2796            membership.channel_id,
2797            membership.role,
2798        );
2799        let update = proto::UpdateUserChannels {
2800            channel_memberships: vec![proto::ChannelMembership {
2801                channel_id: membership.channel_id.to_proto(),
2802                role: membership.role.into(),
2803            }],
2804            ..Default::default()
2805        };
2806        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2807            session.peer.send(connection_id, update.clone())?;
2808        }
2809    }
2810
2811    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2812        if !role.can_see_channel(channel.visibility) {
2813            continue;
2814        }
2815
2816        let update = proto::UpdateChannels {
2817            channels: vec![channel.to_proto()],
2818            ..Default::default()
2819        };
2820        session.peer.send(connection_id, update.clone())?;
2821    }
2822
2823    Ok(())
2824}
2825
2826/// Delete a channel
2827async fn delete_channel(
2828    request: proto::DeleteChannel,
2829    response: Response<proto::DeleteChannel>,
2830    session: MessageContext,
2831) -> Result<()> {
2832    let db = session.db().await;
2833
2834    let channel_id = request.channel_id;
2835    let (root_channel, removed_channels) = db
2836        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2837        .await?;
2838    response.send(proto::Ack {})?;
2839
2840    // Notify members of removed channels
2841    let mut update = proto::UpdateChannels::default();
2842    update
2843        .delete_channels
2844        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2845
2846    let connection_pool = session.connection_pool().await;
2847    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2848        session.peer.send(connection_id, update.clone())?;
2849    }
2850
2851    Ok(())
2852}
2853
2854/// Invite someone to join a channel.
2855async fn invite_channel_member(
2856    request: proto::InviteChannelMember,
2857    response: Response<proto::InviteChannelMember>,
2858    session: MessageContext,
2859) -> Result<()> {
2860    let db = session.db().await;
2861    let channel_id = ChannelId::from_proto(request.channel_id);
2862    let invitee_id = UserId::from_proto(request.user_id);
2863    let InviteMemberResult {
2864        channel,
2865        notifications,
2866    } = db
2867        .invite_channel_member(
2868            channel_id,
2869            invitee_id,
2870            session.user_id(),
2871            request.role().into(),
2872        )
2873        .await?;
2874
2875    let update = proto::UpdateChannels {
2876        channel_invitations: vec![channel.to_proto()],
2877        ..Default::default()
2878    };
2879
2880    let connection_pool = session.connection_pool().await;
2881    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2882        session.peer.send(connection_id, update.clone())?;
2883    }
2884
2885    send_notifications(&connection_pool, &session.peer, notifications);
2886
2887    response.send(proto::Ack {})?;
2888    Ok(())
2889}
2890
2891/// remove someone from a channel
2892async fn remove_channel_member(
2893    request: proto::RemoveChannelMember,
2894    response: Response<proto::RemoveChannelMember>,
2895    session: MessageContext,
2896) -> Result<()> {
2897    let db = session.db().await;
2898    let channel_id = ChannelId::from_proto(request.channel_id);
2899    let member_id = UserId::from_proto(request.user_id);
2900
2901    let RemoveChannelMemberResult {
2902        membership_update,
2903        notification_id,
2904    } = db
2905        .remove_channel_member(channel_id, member_id, session.user_id())
2906        .await?;
2907
2908    let mut connection_pool = session.connection_pool().await;
2909    notify_membership_updated(
2910        &mut connection_pool,
2911        membership_update,
2912        member_id,
2913        &session.peer,
2914    );
2915    for connection_id in connection_pool.user_connection_ids(member_id) {
2916        if let Some(notification_id) = notification_id {
2917            session
2918                .peer
2919                .send(
2920                    connection_id,
2921                    proto::DeleteNotification {
2922                        notification_id: notification_id.to_proto(),
2923                    },
2924                )
2925                .trace_err();
2926        }
2927    }
2928
2929    response.send(proto::Ack {})?;
2930    Ok(())
2931}
2932
2933/// Toggle the channel between public and private.
2934/// Care is taken to maintain the invariant that public channels only descend from public channels,
2935/// (though members-only channels can appear at any point in the hierarchy).
2936async fn set_channel_visibility(
2937    request: proto::SetChannelVisibility,
2938    response: Response<proto::SetChannelVisibility>,
2939    session: MessageContext,
2940) -> Result<()> {
2941    let db = session.db().await;
2942    let channel_id = ChannelId::from_proto(request.channel_id);
2943    let visibility = request.visibility().into();
2944
2945    let channel_model = db
2946        .set_channel_visibility(channel_id, visibility, session.user_id())
2947        .await?;
2948    let root_id = channel_model.root_id();
2949    let channel = Channel::from_model(channel_model);
2950
2951    let mut connection_pool = session.connection_pool().await;
2952    for (user_id, role) in connection_pool
2953        .channel_user_ids(root_id)
2954        .collect::<Vec<_>>()
2955        .into_iter()
2956    {
2957        let update = if role.can_see_channel(channel.visibility) {
2958            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2959            proto::UpdateChannels {
2960                channels: vec![channel.to_proto()],
2961                ..Default::default()
2962            }
2963        } else {
2964            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2965            proto::UpdateChannels {
2966                delete_channels: vec![channel.id.to_proto()],
2967                ..Default::default()
2968            }
2969        };
2970
2971        for connection_id in connection_pool.user_connection_ids(user_id) {
2972            session.peer.send(connection_id, update.clone())?;
2973        }
2974    }
2975
2976    response.send(proto::Ack {})?;
2977    Ok(())
2978}
2979
2980/// Alter the role for a user in the channel.
2981async fn set_channel_member_role(
2982    request: proto::SetChannelMemberRole,
2983    response: Response<proto::SetChannelMemberRole>,
2984    session: MessageContext,
2985) -> Result<()> {
2986    let db = session.db().await;
2987    let channel_id = ChannelId::from_proto(request.channel_id);
2988    let member_id = UserId::from_proto(request.user_id);
2989    let result = db
2990        .set_channel_member_role(
2991            channel_id,
2992            session.user_id(),
2993            member_id,
2994            request.role().into(),
2995        )
2996        .await?;
2997
2998    match result {
2999        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3000            let mut connection_pool = session.connection_pool().await;
3001            notify_membership_updated(
3002                &mut connection_pool,
3003                membership_update,
3004                member_id,
3005                &session.peer,
3006            )
3007        }
3008        db::SetMemberRoleResult::InviteUpdated(channel) => {
3009            let update = proto::UpdateChannels {
3010                channel_invitations: vec![channel.to_proto()],
3011                ..Default::default()
3012            };
3013
3014            for connection_id in session
3015                .connection_pool()
3016                .await
3017                .user_connection_ids(member_id)
3018            {
3019                session.peer.send(connection_id, update.clone())?;
3020            }
3021        }
3022    }
3023
3024    response.send(proto::Ack {})?;
3025    Ok(())
3026}
3027
3028/// Change the name of a channel
3029async fn rename_channel(
3030    request: proto::RenameChannel,
3031    response: Response<proto::RenameChannel>,
3032    session: MessageContext,
3033) -> Result<()> {
3034    let db = session.db().await;
3035    let channel_id = ChannelId::from_proto(request.channel_id);
3036    let channel_model = db
3037        .rename_channel(channel_id, session.user_id(), &request.name)
3038        .await?;
3039    let root_id = channel_model.root_id();
3040    let channel = Channel::from_model(channel_model);
3041
3042    response.send(proto::RenameChannelResponse {
3043        channel: Some(channel.to_proto()),
3044    })?;
3045
3046    let connection_pool = session.connection_pool().await;
3047    let update = proto::UpdateChannels {
3048        channels: vec![channel.to_proto()],
3049        ..Default::default()
3050    };
3051    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3052        if role.can_see_channel(channel.visibility) {
3053            session.peer.send(connection_id, update.clone())?;
3054        }
3055    }
3056
3057    Ok(())
3058}
3059
3060/// Move a channel to a new parent.
3061async fn move_channel(
3062    request: proto::MoveChannel,
3063    response: Response<proto::MoveChannel>,
3064    session: MessageContext,
3065) -> Result<()> {
3066    let channel_id = ChannelId::from_proto(request.channel_id);
3067    let to = ChannelId::from_proto(request.to);
3068
3069    let (root_id, channels) = session
3070        .db()
3071        .await
3072        .move_channel(channel_id, to, session.user_id())
3073        .await?;
3074
3075    let connection_pool = session.connection_pool().await;
3076    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3077        let channels = channels
3078            .iter()
3079            .filter_map(|channel| {
3080                if role.can_see_channel(channel.visibility) {
3081                    Some(channel.to_proto())
3082                } else {
3083                    None
3084                }
3085            })
3086            .collect::<Vec<_>>();
3087        if channels.is_empty() {
3088            continue;
3089        }
3090
3091        let update = proto::UpdateChannels {
3092            channels,
3093            ..Default::default()
3094        };
3095
3096        session.peer.send(connection_id, update.clone())?;
3097    }
3098
3099    response.send(Ack {})?;
3100    Ok(())
3101}
3102
3103async fn reorder_channel(
3104    request: proto::ReorderChannel,
3105    response: Response<proto::ReorderChannel>,
3106    session: MessageContext,
3107) -> Result<()> {
3108    let channel_id = ChannelId::from_proto(request.channel_id);
3109    let direction = request.direction();
3110
3111    let updated_channels = session
3112        .db()
3113        .await
3114        .reorder_channel(channel_id, direction, session.user_id())
3115        .await?;
3116
3117    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3118        let connection_pool = session.connection_pool().await;
3119        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3120            let channels = updated_channels
3121                .iter()
3122                .filter_map(|channel| {
3123                    if role.can_see_channel(channel.visibility) {
3124                        Some(channel.to_proto())
3125                    } else {
3126                        None
3127                    }
3128                })
3129                .collect::<Vec<_>>();
3130
3131            if channels.is_empty() {
3132                continue;
3133            }
3134
3135            let update = proto::UpdateChannels {
3136                channels,
3137                ..Default::default()
3138            };
3139
3140            session.peer.send(connection_id, update.clone())?;
3141        }
3142    }
3143
3144    response.send(Ack {})?;
3145    Ok(())
3146}
3147
3148/// Get the list of channel members
3149async fn get_channel_members(
3150    request: proto::GetChannelMembers,
3151    response: Response<proto::GetChannelMembers>,
3152    session: MessageContext,
3153) -> Result<()> {
3154    let db = session.db().await;
3155    let channel_id = ChannelId::from_proto(request.channel_id);
3156    let limit = if request.limit == 0 {
3157        u16::MAX as u64
3158    } else {
3159        request.limit
3160    };
3161    let (members, users) = db
3162        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3163        .await?;
3164    response.send(proto::GetChannelMembersResponse { members, users })?;
3165    Ok(())
3166}
3167
3168/// Accept or decline a channel invitation.
3169async fn respond_to_channel_invite(
3170    request: proto::RespondToChannelInvite,
3171    response: Response<proto::RespondToChannelInvite>,
3172    session: MessageContext,
3173) -> Result<()> {
3174    let db = session.db().await;
3175    let channel_id = ChannelId::from_proto(request.channel_id);
3176    let RespondToChannelInvite {
3177        membership_update,
3178        notifications,
3179    } = db
3180        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3181        .await?;
3182
3183    let mut connection_pool = session.connection_pool().await;
3184    if let Some(membership_update) = membership_update {
3185        notify_membership_updated(
3186            &mut connection_pool,
3187            membership_update,
3188            session.user_id(),
3189            &session.peer,
3190        );
3191    } else {
3192        let update = proto::UpdateChannels {
3193            remove_channel_invitations: vec![channel_id.to_proto()],
3194            ..Default::default()
3195        };
3196
3197        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3198            session.peer.send(connection_id, update.clone())?;
3199        }
3200    };
3201
3202    send_notifications(&connection_pool, &session.peer, notifications);
3203
3204    response.send(proto::Ack {})?;
3205
3206    Ok(())
3207}
3208
3209/// Join the channels' room
3210async fn join_channel(
3211    request: proto::JoinChannel,
3212    response: Response<proto::JoinChannel>,
3213    session: MessageContext,
3214) -> Result<()> {
3215    let channel_id = ChannelId::from_proto(request.channel_id);
3216    join_channel_internal(channel_id, Box::new(response), session).await
3217}
3218
3219trait JoinChannelInternalResponse {
3220    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3221}
3222impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3223    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3224        Response::<proto::JoinChannel>::send(self, result)
3225    }
3226}
3227impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3228    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3229        Response::<proto::JoinRoom>::send(self, result)
3230    }
3231}
3232
/// Shared implementation of joining a channel's room. It is called by `join_channel` and
/// works with any response handle implementing `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, mint credentials: a guest token without publish
///    rights for `ChannelRole::Guest`, or a room token with publish rights otherwise.
/// 4. Reply to the request, then propagate any membership change and call `room_updated`.
/// 5. After the database and connection-pool guards are released, call `channel_updated`
///    and `update_user_contacts`.
///
/// # Errors
/// Database failures are propagated with `?`. LiveKit token failures are only passed
/// through `trace_err` and leave `live_kit_connection_info` as `None`. An error is also
/// returned if the joined room has no channel; note that this happens after the reply has
/// already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // This scope owns the database and connection-pool guards. Both are dropped at its end,
    // before the connection pool is locked again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database lock while leaving the stale room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` locks the database itself,
            // which would deadlock if the guard were still held; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation fails
        // (`trace_err()?` turns the error into `None` inside this closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a token without publish rights; every other role can publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining may have created or changed the user's channel membership.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Fails if the database did not attach a channel to the joined room.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3326
3327/// Start editing the channel notes
3328async fn join_channel_buffer(
3329    request: proto::JoinChannelBuffer,
3330    response: Response<proto::JoinChannelBuffer>,
3331    session: MessageContext,
3332) -> Result<()> {
3333    let db = session.db().await;
3334    let channel_id = ChannelId::from_proto(request.channel_id);
3335
3336    let open_response = db
3337        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3338        .await?;
3339
3340    let collaborators = open_response.collaborators.clone();
3341    response.send(open_response)?;
3342
3343    let update = UpdateChannelBufferCollaborators {
3344        channel_id: channel_id.to_proto(),
3345        collaborators: collaborators.clone(),
3346    };
3347    channel_buffer_updated(
3348        session.connection_id,
3349        collaborators
3350            .iter()
3351            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3352        &update,
3353        &session.peer,
3354    );
3355
3356    Ok(())
3357}
3358
3359/// Edit the channel notes
3360async fn update_channel_buffer(
3361    request: proto::UpdateChannelBuffer,
3362    session: MessageContext,
3363) -> Result<()> {
3364    let db = session.db().await;
3365    let channel_id = ChannelId::from_proto(request.channel_id);
3366
3367    let (collaborators, epoch, version) = db
3368        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3369        .await?;
3370
3371    channel_buffer_updated(
3372        session.connection_id,
3373        collaborators.clone(),
3374        &proto::UpdateChannelBuffer {
3375            channel_id: channel_id.to_proto(),
3376            operations: request.operations,
3377        },
3378        &session.peer,
3379    );
3380
3381    let pool = &*session.connection_pool().await;
3382
3383    let non_collaborators =
3384        pool.channel_connection_ids(channel_id)
3385            .filter_map(|(connection_id, _)| {
3386                if collaborators.contains(&connection_id) {
3387                    None
3388                } else {
3389                    Some(connection_id)
3390                }
3391            });
3392
3393    broadcast(None, non_collaborators, |peer_id| {
3394        session.peer.send(
3395            peer_id,
3396            proto::UpdateChannels {
3397                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3398                    channel_id: channel_id.to_proto(),
3399                    epoch: epoch as u64,
3400                    version: version.clone(),
3401                }],
3402                ..Default::default()
3403            },
3404        )
3405    });
3406
3407    Ok(())
3408}
3409
3410/// Rejoin the channel notes after a connection blip
3411async fn rejoin_channel_buffers(
3412    request: proto::RejoinChannelBuffers,
3413    response: Response<proto::RejoinChannelBuffers>,
3414    session: MessageContext,
3415) -> Result<()> {
3416    let db = session.db().await;
3417    let buffers = db
3418        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3419        .await?;
3420
3421    for rejoined_buffer in &buffers {
3422        let collaborators_to_notify = rejoined_buffer
3423            .buffer
3424            .collaborators
3425            .iter()
3426            .filter_map(|c| Some(c.peer_id?.into()));
3427        channel_buffer_updated(
3428            session.connection_id,
3429            collaborators_to_notify,
3430            &proto::UpdateChannelBufferCollaborators {
3431                channel_id: rejoined_buffer.buffer.channel_id,
3432                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3433            },
3434            &session.peer,
3435        );
3436    }
3437
3438    response.send(proto::RejoinChannelBuffersResponse {
3439        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3440    })?;
3441
3442    Ok(())
3443}
3444
3445/// Stop editing the channel notes
3446async fn leave_channel_buffer(
3447    request: proto::LeaveChannelBuffer,
3448    response: Response<proto::LeaveChannelBuffer>,
3449    session: MessageContext,
3450) -> Result<()> {
3451    let db = session.db().await;
3452    let channel_id = ChannelId::from_proto(request.channel_id);
3453
3454    let left_buffer = db
3455        .leave_channel_buffer(channel_id, session.connection_id)
3456        .await?;
3457
3458    response.send(Ack {})?;
3459
3460    channel_buffer_updated(
3461        session.connection_id,
3462        left_buffer.connections,
3463        &proto::UpdateChannelBufferCollaborators {
3464            channel_id: channel_id.to_proto(),
3465            collaborators: left_buffer.collaborators,
3466        },
3467        &session.peer,
3468    );
3469
3470    Ok(())
3471}
3472
3473fn channel_buffer_updated<T: EnvelopedMessage>(
3474    sender_id: ConnectionId,
3475    collaborators: impl IntoIterator<Item = ConnectionId>,
3476    message: &T,
3477    peer: &Peer,
3478) {
3479    broadcast(Some(sender_id), collaborators, |peer_id| {
3480        peer.send(peer_id, message.clone())
3481    });
3482}
3483
3484fn send_notifications(
3485    connection_pool: &ConnectionPool,
3486    peer: &Peer,
3487    notifications: db::NotificationBatch,
3488) {
3489    for (user_id, notification) in notifications {
3490        for connection_id in connection_pool.user_connection_ids(user_id) {
3491            if let Err(error) = peer.send(
3492                connection_id,
3493                proto::AddNotification {
3494                    notification: Some(notification.clone()),
3495                },
3496            ) {
3497                tracing::error!(
3498                    "failed to send notification to {:?} {}",
3499                    connection_id,
3500                    error
3501                );
3502            }
3503        }
3504    }
3505}
3506
3507/// Send a message to the channel
3508async fn send_channel_message(
3509    _request: proto::SendChannelMessage,
3510    _response: Response<proto::SendChannelMessage>,
3511    _session: MessageContext,
3512) -> Result<()> {
3513    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3514}
3515
3516/// Delete a channel message
3517async fn remove_channel_message(
3518    _request: proto::RemoveChannelMessage,
3519    _response: Response<proto::RemoveChannelMessage>,
3520    _session: MessageContext,
3521) -> Result<()> {
3522    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3523}
3524
3525async fn update_channel_message(
3526    _request: proto::UpdateChannelMessage,
3527    _response: Response<proto::UpdateChannelMessage>,
3528    _session: MessageContext,
3529) -> Result<()> {
3530    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3531}
3532
3533/// Mark a channel message as read
3534async fn acknowledge_channel_message(
3535    _request: proto::AckChannelMessage,
3536    _session: MessageContext,
3537) -> Result<()> {
3538    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3539}
3540
3541/// Mark a buffer version as synced
3542async fn acknowledge_buffer_version(
3543    request: proto::AckBufferOperation,
3544    session: MessageContext,
3545) -> Result<()> {
3546    let buffer_id = BufferId::from_proto(request.buffer_id);
3547    session
3548        .db()
3549        .await
3550        .observe_buffer_version(
3551            buffer_id,
3552            session.user_id(),
3553            request.epoch as i32,
3554            &request.version,
3555        )
3556        .await?;
3557    Ok(())
3558}
3559
3560/// Start receiving chat updates for a channel
3561async fn join_channel_chat(
3562    _request: proto::JoinChannelChat,
3563    _response: Response<proto::JoinChannelChat>,
3564    _session: MessageContext,
3565) -> Result<()> {
3566    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3567}
3568
3569/// Stop receiving chat updates for a channel
3570async fn leave_channel_chat(
3571    _request: proto::LeaveChannelChat,
3572    _session: MessageContext,
3573) -> Result<()> {
3574    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3575}
3576
3577/// Retrieve the chat history for a channel
3578async fn get_channel_messages(
3579    _request: proto::GetChannelMessages,
3580    _response: Response<proto::GetChannelMessages>,
3581    _session: MessageContext,
3582) -> Result<()> {
3583    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3584}
3585
3586/// Retrieve specific chat messages
3587async fn get_channel_messages_by_id(
3588    _request: proto::GetChannelMessagesById,
3589    _response: Response<proto::GetChannelMessagesById>,
3590    _session: MessageContext,
3591) -> Result<()> {
3592    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3593}
3594
3595/// Retrieve the current users notifications
3596async fn get_notifications(
3597    request: proto::GetNotifications,
3598    response: Response<proto::GetNotifications>,
3599    session: MessageContext,
3600) -> Result<()> {
3601    let notifications = session
3602        .db()
3603        .await
3604        .get_notifications(
3605            session.user_id(),
3606            NOTIFICATION_COUNT_PER_PAGE,
3607            request.before_id.map(db::NotificationId::from_proto),
3608        )
3609        .await?;
3610    response.send(proto::GetNotificationsResponse {
3611        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3612        notifications,
3613    })?;
3614    Ok(())
3615}
3616
3617/// Mark notifications as read
3618async fn mark_notification_as_read(
3619    request: proto::MarkNotificationRead,
3620    response: Response<proto::MarkNotificationRead>,
3621    session: MessageContext,
3622) -> Result<()> {
3623    let database = &session.db().await;
3624    let notifications = database
3625        .mark_notification_as_read_by_id(
3626            session.user_id(),
3627            NotificationId::from_proto(request.notification_id),
3628        )
3629        .await?;
3630    send_notifications(
3631        &*session.connection_pool().await,
3632        &session.peer,
3633        notifications,
3634    );
3635    response.send(proto::Ack {})?;
3636    Ok(())
3637}
3638
3639fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3640    let message = match message {
3641        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3642        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3643        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3644        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3645        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3646            code: frame.code.into(),
3647            reason: frame.reason.as_str().to_owned().into(),
3648        })),
3649        // We should never receive a frame while reading the message, according
3650        // to the `tungstenite` maintainers:
3651        //
3652        // > It cannot occur when you read messages from the WebSocket, but it
3653        // > can be used when you want to send the raw frames (e.g. you want to
3654        // > send the frames to the WebSocket without composing the full message first).
3655        // >
3656        // > — https://github.com/snapview/tungstenite-rs/issues/268
3657        TungsteniteMessage::Frame(_) => {
3658            bail!("received an unexpected frame while reading the message")
3659        }
3660    };
3661
3662    Ok(message)
3663}
3664
3665fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3666    match message {
3667        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3668        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3669        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3670        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3671        AxumMessage::Close(frame) => {
3672            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3673                code: frame.code.into(),
3674                reason: frame.reason.as_ref().into(),
3675            }))
3676        }
3677    }
3678}
3679
3680fn notify_membership_updated(
3681    connection_pool: &mut ConnectionPool,
3682    result: MembershipUpdated,
3683    user_id: UserId,
3684    peer: &Peer,
3685) {
3686    for membership in &result.new_channels.channel_memberships {
3687        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3688    }
3689    for channel_id in &result.removed_channels {
3690        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3691    }
3692
3693    let user_channels_update = proto::UpdateUserChannels {
3694        channel_memberships: result
3695            .new_channels
3696            .channel_memberships
3697            .iter()
3698            .map(|cm| proto::ChannelMembership {
3699                channel_id: cm.channel_id.to_proto(),
3700                role: cm.role.into(),
3701            })
3702            .collect(),
3703        ..Default::default()
3704    };
3705
3706    let mut update = build_channels_update(result.new_channels);
3707    update.delete_channels = result
3708        .removed_channels
3709        .into_iter()
3710        .map(|id| id.to_proto())
3711        .collect();
3712    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3713
3714    for connection_id in connection_pool.user_connection_ids(user_id) {
3715        peer.send(connection_id, user_channels_update.clone())
3716            .trace_err();
3717        peer.send(connection_id, update.clone()).trace_err();
3718    }
3719}
3720
3721fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3722    proto::UpdateUserChannels {
3723        channel_memberships: channels
3724            .channel_memberships
3725            .iter()
3726            .map(|m| proto::ChannelMembership {
3727                channel_id: m.channel_id.to_proto(),
3728                role: m.role.into(),
3729            })
3730            .collect(),
3731        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3732    }
3733}
3734
3735fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3736    let mut update = proto::UpdateChannels::default();
3737
3738    for channel in channels.channels {
3739        update.channels.push(channel.to_proto());
3740    }
3741
3742    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3743
3744    for (channel_id, participants) in channels.channel_participants {
3745        update
3746            .channel_participants
3747            .push(proto::ChannelParticipants {
3748                channel_id: channel_id.to_proto(),
3749                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3750            });
3751    }
3752
3753    for channel in channels.invited_channels {
3754        update.channel_invitations.push(channel.to_proto());
3755    }
3756
3757    update
3758}
3759
3760fn build_initial_contacts_update(
3761    contacts: Vec<db::Contact>,
3762    pool: &ConnectionPool,
3763) -> proto::UpdateContacts {
3764    let mut update = proto::UpdateContacts::default();
3765
3766    for contact in contacts {
3767        match contact {
3768            db::Contact::Accepted { user_id, busy } => {
3769                update.contacts.push(contact_for_user(user_id, busy, pool));
3770            }
3771            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3772            db::Contact::Incoming { user_id } => {
3773                update
3774                    .incoming_requests
3775                    .push(proto::IncomingContactRequest {
3776                        requester_id: user_id.to_proto(),
3777                    })
3778            }
3779        }
3780    }
3781
3782    update
3783}
3784
3785fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3786    proto::Contact {
3787        user_id: user_id.to_proto(),
3788        online: pool.is_user_online(user_id),
3789        busy,
3790    }
3791}
3792
3793fn room_updated(room: &proto::Room, peer: &Peer) {
3794    broadcast(
3795        None,
3796        room.participants
3797            .iter()
3798            .filter_map(|participant| Some(participant.peer_id?.into())),
3799        |peer_id| {
3800            peer.send(
3801                peer_id,
3802                proto::RoomUpdated {
3803                    room: Some(room.clone()),
3804                },
3805            )
3806        },
3807    );
3808}
3809
3810fn channel_updated(
3811    channel: &db::channel::Model,
3812    room: &proto::Room,
3813    peer: &Peer,
3814    pool: &ConnectionPool,
3815) {
3816    let participants = room
3817        .participants
3818        .iter()
3819        .map(|p| p.user_id)
3820        .collect::<Vec<_>>();
3821
3822    broadcast(
3823        None,
3824        pool.channel_connection_ids(channel.root_id())
3825            .filter_map(|(channel_id, role)| {
3826                role.can_see_channel(channel.visibility)
3827                    .then_some(channel_id)
3828            }),
3829        |peer_id| {
3830            peer.send(
3831                peer_id,
3832                proto::UpdateChannels {
3833                    channel_participants: vec![proto::ChannelParticipants {
3834                        channel_id: channel.id.to_proto(),
3835                        participant_user_ids: participants.clone(),
3836                    }],
3837                    ..Default::default()
3838                },
3839            )
3840        },
3841    );
3842}
3843
3844async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3845    let db = session.db().await;
3846
3847    let contacts = db.get_contacts(user_id).await?;
3848    let busy = db.is_user_busy(user_id).await?;
3849
3850    let pool = session.connection_pool().await;
3851    let updated_contact = contact_for_user(user_id, busy, &pool);
3852    for contact in contacts {
3853        if let db::Contact::Accepted {
3854            user_id: contact_user_id,
3855            ..
3856        } = contact
3857        {
3858            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3859                session
3860                    .peer
3861                    .send(
3862                        contact_conn_id,
3863                        proto::UpdateContacts {
3864                            contacts: vec![updated_contact.clone()],
3865                            remove_contacts: Default::default(),
3866                            incoming_requests: Default::default(),
3867                            remove_incoming_requests: Default::default(),
3868                            outgoing_requests: Default::default(),
3869                            remove_outgoing_requests: Default::default(),
3870                        },
3871                    )
3872                    .trace_err();
3873            }
3874        }
3875    }
3876    Ok(())
3877}
3878
/// Removes `connection_id` from its room (if any) and sends the resulting updates.
///
/// If `leave_room` reports no room for the connection, this returns `Ok(())`
/// immediately. Otherwise it:
/// - calls `project_left` for every project the connection left,
/// - broadcasts the updated room to its participants and, when the room has a
///   channel, the updated participant list to that channel's members,
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`,
/// - refreshes contact status for the session user and those users,
/// - removes the session user from the LiveKit room, and deletes the LiveKit
///   room when `left_room.deleted` is set. LiveKit errors are logged, not returned.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned only in the `Some` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The `session.db()` guard is a temporary in this `if let` scrutinee, so it
    // stays held at least for the whole `then` branch.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` *before* `room` itself
        // is taken, so `room` (and the `room_updated` broadcast below) carries the
        // field's default value. NOTE(review): confirm this is intentional.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Note: this `connection_id` shadows the function parameter.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3952
3953async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3954    let left_channel_buffers = session
3955        .db()
3956        .await
3957        .leave_channel_buffers(session.connection_id)
3958        .await?;
3959
3960    for left_buffer in left_channel_buffers {
3961        channel_buffer_updated(
3962            session.connection_id,
3963            left_buffer.connections,
3964            &proto::UpdateChannelBufferCollaborators {
3965                channel_id: left_buffer.channel_id.to_proto(),
3966                collaborators: left_buffer.collaborators,
3967            },
3968            &session.peer,
3969        );
3970    }
3971
3972    Ok(())
3973}
3974
3975fn project_left(project: &db::LeftProject, session: &Session) {
3976    for connection_id in &project.connection_ids {
3977        if project.should_unshare {
3978            session
3979                .peer
3980                .send(
3981                    *connection_id,
3982                    proto::UnshareProject {
3983                        project_id: project.id.to_proto(),
3984                    },
3985                )
3986                .trace_err();
3987        } else {
3988            session
3989                .peer
3990                .send(
3991                    *connection_id,
3992                    proto::RemoveProjectCollaborator {
3993                        project_id: project.id.to_proto(),
3994                        peer_id: Some(session.connection_id.into()),
3995                    },
3996                )
3997                .trace_err();
3998        }
3999    }
4000}
4001
4002async fn share_agent_thread(
4003    request: proto::ShareAgentThread,
4004    response: Response<proto::ShareAgentThread>,
4005    session: MessageContext,
4006) -> Result<()> {
4007    let user_id = session.user_id();
4008
4009    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4010        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4011
4012    session
4013        .db()
4014        .await
4015        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4016        .await?;
4017
4018    response.send(proto::Ack {})?;
4019
4020    Ok(())
4021}
4022
4023async fn get_shared_agent_thread(
4024    request: proto::GetSharedAgentThread,
4025    response: Response<proto::GetSharedAgentThread>,
4026    session: MessageContext,
4027) -> Result<()> {
4028    let share_id = SharedThreadId::from_proto(request.session_id)
4029        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4030
4031    let result = session.db().await.get_shared_thread(share_id).await?;
4032
4033    match result {
4034        Some((thread, username)) => {
4035            response.send(proto::GetSharedAgentThreadResponse {
4036                title: thread.title,
4037                thread_data: thread.data,
4038                sharer_username: username,
4039                created_at: thread.created_at.and_utc().to_rfc3339(),
4040            })?;
4041        }
4042        None => {
4043            return Err(anyhow!("Shared thread not found").into());
4044        }
4045    }
4046
4047    Ok(())
4048}
4049
4050pub trait ResultExt {
4051    type Ok;
4052
4053    fn trace_err(self) -> Option<Self::Ok>;
4054}
4055
4056impl<T, E> ResultExt for Result<T, E>
4057where
4058    E: std::fmt::Debug,
4059{
4060    type Ok = T;
4061
4062    #[track_caller]
4063    fn trace_err(self) -> Option<T> {
4064        match self {
4065            Ok(value) => Some(value),
4066            Err(error) => {
4067                tracing::error!("{:?}", error);
4068                None
4069            }
4070        }
4071    }
4072}