rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use tokio::sync::{Semaphore, watch};
  69use tower::ServiceBuilder;
  70use tracing::{
  71    Instrument,
  72    field::{self},
  73    info_span, instrument,
  74};
  75
/// Presumably the grace period a disconnected client has to reconnect before its
/// state is released — not used in this part of the file; confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up old resources.
/// Delay between `Server::start` and the background cleanup of resources left behind by
/// stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size for notification queries (inferred from the name — confirm).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneous connections, enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Count of connection slots currently reserved through `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of tracing span fields that record per-message timings, in milliseconds.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler: receives a boxed envelope, the sender's `Session`, and
/// the message's tracing span, and returns the boxed future that handles the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  93
  94pub struct ConnectionGuard;
  95
  96impl ConnectionGuard {
  97    pub fn try_acquire() -> Result<Self, ()> {
  98        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  99        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 100            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 101            tracing::error!(
 102                "too many concurrent connections: {}",
 103                current_connections + 1
 104            );
 105            return Err(());
 106        }
 107        Ok(ConnectionGuard)
 108    }
 109}
 110
 111impl Drop for ConnectionGuard {
 112    fn drop(&mut self) {
 113        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114    }
 115}
 116
 117struct Response<R> {
 118    peer: Arc<Peer>,
 119    receipt: Receipt<R>,
 120    responded: Arc<AtomicBool>,
 121}
 122
 123impl<R: RequestMessage> Response<R> {
 124    fn send(self, payload: R::Response) -> Result<()> {
 125        self.responded.store(true, SeqCst);
 126        self.peer.respond(self.receipt, payload)?;
 127        Ok(())
 128    }
 129}
 130
 131#[derive(Clone, Debug)]
 132pub enum Principal {
 133    User(User),
 134    Impersonated { user: User, admin: User },
 135}
 136
 137impl Principal {
 138    fn update_span(&self, span: &tracing::Span) {
 139        match &self {
 140            Principal::User(user) => {
 141                span.record("user_id", user.id.0);
 142                span.record("login", &user.github_login);
 143            }
 144            Principal::Impersonated { user, admin } => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147                span.record("impersonator", &admin.github_login);
 148            }
 149        }
 150    }
 151}
 152
 153#[derive(Clone)]
 154struct MessageContext {
 155    session: Session,
 156    span: tracing::Span,
 157}
 158
 159impl Deref for MessageContext {
 160    type Target = Session;
 161
 162    fn deref(&self) -> &Self::Target {
 163        &self.session
 164    }
 165}
 166
 167impl MessageContext {
 168    pub fn forward_request<T: RequestMessage>(
 169        &self,
 170        receiver_id: ConnectionId,
 171        request: T,
 172    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 173        let request_start_time = Instant::now();
 174        let span = self.span.clone();
 175        tracing::info!("start forwarding request");
 176        self.peer
 177            .forward_request(self.connection_id, receiver_id, request)
 178            .inspect(move |_| {
 179                span.record(
 180                    HOST_WAITING_MS,
 181                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 182                );
 183            })
 184            .inspect_err(|_| tracing::error!("error forwarding request"))
 185            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 186    }
 187}
 188
/// Per-connection state passed to every message handler.
#[derive(Clone)]
struct Session {
    /// The user this session acts as, possibly impersonated by an admin.
    principal: Principal,
    /// The RPC connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; locked through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages to connections.
    peer: Arc<Peer>,
    /// Shared connection pool; locked through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read in this part of the file; `allow(unused)` silences the
    // dead-code warning — confirm whether the field is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System identifier reported by the client, if any (presumably taken from the
    /// `SystemIdHeader` request header — confirm).
    // NOTE(review): also unread here, hence `allow(unused)`.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
 204
 205impl Session {
 206    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 207        #[cfg(test)]
 208        tokio::task::yield_now().await;
 209        let guard = self.db.lock().await;
 210        #[cfg(test)]
 211        tokio::task::yield_now().await;
 212        guard
 213    }
 214
 215    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 216        #[cfg(test)]
 217        tokio::task::yield_now().await;
 218        let guard = self.connection_pool.lock();
 219        ConnectionPoolGuard {
 220            guard,
 221            _not_send: PhantomData,
 222        }
 223    }
 224
 225    #[expect(dead_code)]
 226    fn is_staff(&self) -> bool {
 227        match &self.principal {
 228            Principal::User(user) => user.admin,
 229            Principal::Impersonated { .. } => true,
 230        }
 231    }
 232
 233    fn user_id(&self) -> UserId {
 234        match &self.principal {
 235            Principal::User(user) => user.id,
 236            Principal::Impersonated { user, .. } => user.id,
 237        }
 238    }
 239}
 240
 241impl Debug for Session {
 242    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 243        let mut result = f.debug_struct("Session");
 244        match &self.principal {
 245            Principal::User(user) => {
 246                result.field("user", &user.github_login);
 247            }
 248            Principal::Impersonated { user, admin } => {
 249                result.field("user", &user.github_login);
 250                result.field("impersonator", &admin.github_login);
 251            }
 252        }
 253        result.field("connection_id", &self.connection_id).finish()
 254    }
 255}
 256
 257struct DbHandle(Arc<Database>);
 258
 259impl Deref for DbHandle {
 260    type Target = Database;
 261
 262    fn deref(&self) -> &Self::Target {
 263        self.0.as_ref()
 264    }
 265}
 266
/// The RPC server: owns the peer, the connection pool, and the table of message
/// handlers keyed by message type.
pub struct Server {
    /// This server's id; overwritten by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to send messages to connections.
    peer: Arc<Peer>,
    /// Pool of connections (presumably the same pool handed to each `Session` — confirm).
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, LiveKit client).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sends `true`, the test-only `reset` sends `false`.
    teardown: watch::Sender<bool>,
}
 275
/// Guard over the locked connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held
/// across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 280
/// Serializable view of the server's peer and connection pool.
///
/// The pool is held through its guard (serialized via `serialize_deref`), so the pool
/// stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 287
 288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 289where
 290    S: Serializer,
 291    T: Deref<Target = U>,
 292    U: Serialize,
 293{
 294    Serialize::serialize(value.deref(), serializer)
 295}
 296
 297impl Server {
 298    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 299        let mut server = Self {
 300            id: parking_lot::Mutex::new(id),
 301            peer: Peer::new(id.0 as u32),
 302            app_state,
 303            connection_pool: Default::default(),
 304            handlers: Default::default(),
 305            teardown: watch::channel(false).0,
 306        };
 307
 308        server
 309            .add_request_handler(ping)
 310            .add_request_handler(create_room)
 311            .add_request_handler(join_room)
 312            .add_request_handler(rejoin_room)
 313            .add_request_handler(leave_room)
 314            .add_request_handler(set_room_participant_role)
 315            .add_request_handler(call)
 316            .add_request_handler(cancel_call)
 317            .add_message_handler(decline_call)
 318            .add_request_handler(update_participant_location)
 319            .add_request_handler(share_project)
 320            .add_message_handler(unshare_project)
 321            .add_request_handler(join_project)
 322            .add_message_handler(leave_project)
 323            .add_request_handler(update_project)
 324            .add_request_handler(update_worktree)
 325            .add_request_handler(update_repository)
 326            .add_request_handler(remove_repository)
 327            .add_message_handler(start_language_server)
 328            .add_message_handler(update_language_server)
 329            .add_message_handler(update_diagnostic_summary)
 330            .add_message_handler(update_worktree_settings)
 331            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 332            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 333            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 334            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 335            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 336            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 337            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 338            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 340            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 341            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 346            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 347            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 348            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 349            .add_request_handler(
 350                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 351            )
 352            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 355            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 356            .add_request_handler(
 357                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 358            )
 359            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 360            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 361            .add_request_handler(
 362                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 363            )
 364            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 365            .add_request_handler(
 366                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 367            )
 368            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 369            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 370            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 371            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 372            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 373            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 374            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 375            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 376            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 377            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 378            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 379            .add_request_handler(
 380                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 381            )
 382            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 383            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 384            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 385            .add_request_handler(lsp_query)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 387            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 388            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 390            .add_message_handler(create_buffer_for_peer)
 391            .add_message_handler(create_image_for_peer)
 392            .add_request_handler(update_buffer)
 393            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 395            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 396            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 397            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 398            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 399            .add_message_handler(
 400                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 401            )
 402            .add_request_handler(get_users)
 403            .add_request_handler(fuzzy_search_users)
 404            .add_request_handler(request_contact)
 405            .add_request_handler(remove_contact)
 406            .add_request_handler(respond_to_contact_request)
 407            .add_message_handler(subscribe_to_channels)
 408            .add_request_handler(create_channel)
 409            .add_request_handler(delete_channel)
 410            .add_request_handler(invite_channel_member)
 411            .add_request_handler(remove_channel_member)
 412            .add_request_handler(set_channel_member_role)
 413            .add_request_handler(set_channel_visibility)
 414            .add_request_handler(rename_channel)
 415            .add_request_handler(join_channel_buffer)
 416            .add_request_handler(leave_channel_buffer)
 417            .add_message_handler(update_channel_buffer)
 418            .add_request_handler(rejoin_channel_buffers)
 419            .add_request_handler(get_channel_members)
 420            .add_request_handler(respond_to_channel_invite)
 421            .add_request_handler(join_channel)
 422            .add_request_handler(join_channel_chat)
 423            .add_message_handler(leave_channel_chat)
 424            .add_request_handler(send_channel_message)
 425            .add_request_handler(remove_channel_message)
 426            .add_request_handler(update_channel_message)
 427            .add_request_handler(get_channel_messages)
 428            .add_request_handler(get_channel_messages_by_id)
 429            .add_request_handler(get_notifications)
 430            .add_request_handler(mark_notification_as_read)
 431            .add_request_handler(move_channel)
 432            .add_request_handler(reorder_channel)
 433            .add_request_handler(follow)
 434            .add_message_handler(unfollow)
 435            .add_message_handler(update_followers)
 436            .add_message_handler(acknowledge_channel_message)
 437            .add_message_handler(acknowledge_buffer_version)
 438            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 439            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 440            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 441            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 442            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 443            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 444            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 445            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 446            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 447            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 449            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 450            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 451            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 452            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 453            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 454            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 455            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 456            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 457            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 458            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 459            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 460            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 461            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 463            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 465            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 466            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 467            .add_message_handler(update_context)
 468            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 469            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 470            .add_request_handler(share_agent_thread)
 471            .add_request_handler(get_shared_agent_thread)
 472            .add_request_handler(forward_project_search_chunk);
 473
 474        Arc::new(server)
 475    }
 476
    /// Schedules cleanup of resources left behind by stale servers, then returns.
    ///
    /// The cleanup runs on a detached task. It first waits `CLEANUP_TIMEOUT`, then does
    /// the following, scoped to this server's id and `zed_environment`:
    /// - deletes stale channel chat participants;
    /// - for each stale channel buffer, clears stale collaborators and sends the refreshed
    ///   collaborator list to the connections returned with the buffer;
    /// - for each stale room:
    ///   - clears stale participants and calls `room_updated` / `channel_updated`;
    ///   - sends `CallCanceled` to every connection of users whose calls were canceled;
    ///   - pushes an updated contact entry to the accepted contacts of affected users;
    ///   - deletes the LiveKit room (when a LiveKit client is configured) once the room
    ///     has no participants;
    /// - clears old worktree entries and deletes stale servers.
    ///
    /// Every database step goes through `trace_err`. A failing step yields `None`, and the
    /// work that depends on it is skipped, while the remaining steps still run.
    ///
    /// Always returns `Ok(())`; the background task's outcome is not reported to the caller.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to the buffer's remaining connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false if clearing the room fails, in which case
                        // nothing is sent and no LiveKit room is deleted below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Stale participants and users whose calls were canceled are
                            // collected so their contacts can be sent an updated entry below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each user whose call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to every connection
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): repeats the call made right after the timeout; it looks
                // redundant — confirm whether this second pass is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 647
 648    pub fn teardown(&self) {
 649        self.peer.teardown();
 650        self.connection_pool.lock().reset();
 651        let _ = self.teardown.send(true);
 652    }
 653
 654    #[cfg(test)]
 655    pub fn reset(&self, id: ServerId) {
 656        self.teardown();
 657        *self.id.lock() = id;
 658        self.peer.reset(id.0 as u32);
 659        let _ = self.teardown.send(false);
 660    }
 661
 662    #[cfg(test)]
 663    pub fn id(&self) -> ServerId {
 664        *self.id.lock()
 665    }
 666
 667    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 668    where
 669        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 670        Fut: 'static + Send + Future<Output = Result<()>>,
 671        M: EnvelopedMessage,
 672    {
 673        let prev_handler = self.handlers.insert(
 674            TypeId::of::<M>(),
 675            Box::new(move |envelope, session, span| {
 676                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 677                let received_at = envelope.received_at;
 678                tracing::info!("message received");
 679                let start_time = Instant::now();
 680                let future = (handler)(
 681                    *envelope,
 682                    MessageContext {
 683                        session,
 684                        span: span.clone(),
 685                    },
 686                );
 687                async move {
 688                    let result = future.await;
 689                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 690                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 691                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 692                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 693                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 694                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 695                    match result {
 696                        Err(error) => {
 697                            tracing::error!(?error, "error handling message")
 698                        }
 699                        Ok(()) => tracing::info!("finished handling message"),
 700                    }
 701                }
 702                .boxed()
 703            }),
 704        );
 705        if prev_handler.is_some() {
 706            panic!("registered a handler for the same message twice");
 707        }
 708        self
 709    }
 710
 711    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 712    where
 713        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 714        Fut: 'static + Send + Future<Output = Result<()>>,
 715        M: EnvelopedMessage,
 716    {
 717        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 718        self
 719    }
 720
 721    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 722    where
 723        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 724        Fut: Send + Future<Output = Result<()>>,
 725        M: RequestMessage,
 726    {
 727        let handler = Arc::new(handler);
 728        self.add_handler(move |envelope, session| {
 729            let receipt = envelope.receipt();
 730            let handler = handler.clone();
 731            async move {
 732                let peer = session.peer.clone();
 733                let responded = Arc::new(AtomicBool::default());
 734                let response = Response {
 735                    peer: peer.clone(),
 736                    responded: responded.clone(),
 737                    receipt,
 738                };
 739                match (handler)(envelope.payload, response, session).await {
 740                    Ok(()) => {
 741                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 742                            Ok(())
 743                        } else {
 744                            Err(anyhow!("handler did not send a response"))?
 745                        }
 746                    }
 747                    Err(error) => {
 748                        let proto_err = match &error {
 749                            Error::Internal(err) => err.to_proto(),
 750                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 751                        };
 752                        peer.respond_with_error(receipt, proto_err)?;
 753                        Err(error)
 754                    }
 755                }
 756            }
 757        })
 758    }
 759
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Before the returned future is polled, this builds the "handle connection" span,
    /// records the principal, user agent, release channel and GeoIP country code on it, and
    /// subscribes to the server's teardown flag.
    ///
    /// The returned future (instrumented with that span) then:
    /// 1. returns immediately if the teardown flag is already set;
    /// 2. registers `connection` with the peer and builds the connection's `Session`;
    /// 3. sends the initial client update, returning early if that fails, and then drops
    ///    `connection_guard`;
    /// 4. dispatches incoming messages to the registered handlers until the connection
    ///    closes, its I/O future completes, or the server tears down;
    /// 5. unless it stopped because of teardown, calls `connection_lost` to sign out.
    ///
    /// `send_connection_id`, if provided, is forwarded to `send_initial_client_update`,
    /// which reports the new connection id through it.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in below (or later via `Span::current()`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly so the future can check the current flag and await changes.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            // This future is instrumented with `span`, so the current span is that span.
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this path returns without calling `connection_lost`, so the
                // connection registered via `this.peer.add_connection` above is not
                // disconnected here. If the failure happened after `pool.add_connection`
                // inside `send_initial_client_update`, the pool entry is not removed either.
                // Confirm whether cleanup is needed.
                return;
            }
            // Release the guard the caller acquired before the upgrade, now that the
            // handshake has completed.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // Each message is read only after acquiring one of `MAX_CONCURRENT_HANDLERS`
            // permits. The permit is released when that message's handler future completes,
            // whether it runs in the foreground or detached, so at most that many handlers
            // are in flight for this connection.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order:
                // 1. teardown, which returns without signing out;
                // 2. I/O completion, which breaks out to sign out;
                // 3. progress on foreground handlers;
                // 4. the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit moves into the handler future, so the slot frees
                                // up only once that handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 932
 933    async fn send_initial_client_update(
 934        &self,
 935        connection_id: ConnectionId,
 936        zed_version: ZedVersion,
 937        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 938        session: &Session,
 939    ) -> Result<()> {
 940        self.peer.send(
 941            connection_id,
 942            proto::Hello {
 943                peer_id: Some(connection_id.into()),
 944            },
 945        )?;
 946        tracing::info!("sent hello message");
 947        if let Some(send_connection_id) = send_connection_id.take() {
 948            let _ = send_connection_id.send(connection_id);
 949        }
 950
 951        match &session.principal {
 952            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 953                if !user.connected_once {
 954                    self.peer.send(connection_id, proto::ShowContacts {})?;
 955                    self.app_state
 956                        .db
 957                        .set_user_connected_once(user.id, true)
 958                        .await?;
 959                }
 960
 961                let contacts = self.app_state.db.get_contacts(user.id).await?;
 962
 963                {
 964                    let mut pool = self.connection_pool.lock();
 965                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 966                    self.peer.send(
 967                        connection_id,
 968                        build_initial_contacts_update(contacts, &pool),
 969                    )?;
 970                }
 971
 972                if should_auto_subscribe_to_channels(&zed_version) {
 973                    subscribe_user_to_channels(user.id, session).await?;
 974                }
 975
 976                if let Some(incoming_call) =
 977                    self.app_state.db.incoming_call_for_user(user.id).await?
 978                {
 979                    self.peer.send(connection_id, incoming_call)?;
 980                }
 981
 982                update_user_contacts(user.id, session).await?;
 983            }
 984        }
 985
 986        Ok(())
 987    }
 988
 989    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
 990        ServerSnapshot {
 991            connection_pool: ConnectionPoolGuard {
 992                guard: self.connection_pool.lock(),
 993                _not_send: PhantomData,
 994            },
 995            peer: &self.peer,
 996        }
 997    }
 998}
 999
1000impl Deref for ConnectionPoolGuard<'_> {
1001    type Target = ConnectionPool;
1002
1003    fn deref(&self) -> &Self::Target {
1004        &self.guard
1005    }
1006}
1007
1008impl DerefMut for ConnectionPoolGuard<'_> {
1009    fn deref_mut(&mut self) -> &mut Self::Target {
1010        &mut self.guard
1011    }
1012}
1013
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, calls `self.check_invariants()` whenever the guard is released.
    /// The call is compiled out in non-test builds.
    /// NOTE(review): presumably this resolves to `ConnectionPool::check_invariants` through
    /// `Deref`; confirm.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1020
1021fn broadcast<F>(
1022    sender_id: Option<ConnectionId>,
1023    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1024    mut f: F,
1025) where
1026    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1027{
1028    for receiver_id in receiver_ids {
1029        if Some(receiver_id) != sender_id
1030            && let Err(error) = f(receiver_id)
1031        {
1032            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1033        }
1034    }
1035}
1036
1037pub struct ProtocolVersion(u32);
1038
1039impl Header for ProtocolVersion {
1040    fn name() -> &'static HeaderName {
1041        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1042        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1043    }
1044
1045    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1046    where
1047        Self: Sized,
1048        I: Iterator<Item = &'i axum::http::HeaderValue>,
1049    {
1050        let version = values
1051            .next()
1052            .ok_or_else(axum::headers::Error::invalid)?
1053            .to_str()
1054            .map_err(|_| axum::headers::Error::invalid())?
1055            .parse()
1056            .map_err(|_| axum::headers::Error::invalid())?;
1057        Ok(Self(version))
1058    }
1059
1060    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1061        values.extend([self.0.to_string().parse().unwrap()]);
1062    }
1063}
1064
1065pub struct AppVersionHeader(Version);
1066impl Header for AppVersionHeader {
1067    fn name() -> &'static HeaderName {
1068        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1069        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1070    }
1071
1072    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1073    where
1074        Self: Sized,
1075        I: Iterator<Item = &'i axum::http::HeaderValue>,
1076    {
1077        let version = values
1078            .next()
1079            .ok_or_else(axum::headers::Error::invalid)?
1080            .to_str()
1081            .map_err(|_| axum::headers::Error::invalid())?
1082            .parse()
1083            .map_err(|_| axum::headers::Error::invalid())?;
1084        Ok(Self(version))
1085    }
1086
1087    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1088        values.extend([self.0.to_string().parse().unwrap()]);
1089    }
1090}
1091
1092#[derive(Debug)]
1093pub struct ReleaseChannelHeader(String);
1094
1095impl Header for ReleaseChannelHeader {
1096    fn name() -> &'static HeaderName {
1097        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1098        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1099    }
1100
1101    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1102    where
1103        Self: Sized,
1104        I: Iterator<Item = &'i axum::http::HeaderValue>,
1105    {
1106        Ok(Self(
1107            values
1108                .next()
1109                .ok_or_else(axum::headers::Error::invalid)?
1110                .to_str()
1111                .map_err(|_| axum::headers::Error::invalid())?
1112                .to_owned(),
1113        ))
1114    }
1115
1116    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1117        values.extend([self.0.parse().unwrap()]);
1118    }
1119}
1120
1121pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1122    Router::new()
1123        .route("/rpc", get(handle_websocket_request))
1124        .layer(
1125            ServiceBuilder::new()
1126                .layer(Extension(server.app_state.clone()))
1127                .layer(middleware::from_fn(auth::validate_header)),
1128        )
1129        .route("/metrics", get(handle_metrics))
1130        .layer(Extension(server))
1131}
1132
1133pub async fn handle_websocket_request(
1134    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1135    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1136    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1137    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1138    Extension(server): Extension<Arc<Server>>,
1139    Extension(principal): Extension<Principal>,
1140    user_agent: Option<TypedHeader<UserAgent>>,
1141    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1142    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1143    ws: WebSocketUpgrade,
1144) -> axum::response::Response {
1145    if protocol_version != rpc::PROTOCOL_VERSION {
1146        return (
1147            StatusCode::UPGRADE_REQUIRED,
1148            "client must be upgraded".to_string(),
1149        )
1150            .into_response();
1151    }
1152
1153    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1154        return (
1155            StatusCode::UPGRADE_REQUIRED,
1156            "no version header found".to_string(),
1157        )
1158            .into_response();
1159    };
1160
1161    let release_channel = release_channel_header.map(|header| header.0.0);
1162
1163    if !version.can_collaborate() {
1164        return (
1165            StatusCode::UPGRADE_REQUIRED,
1166            "client must be upgraded".to_string(),
1167        )
1168            .into_response();
1169    }
1170
1171    let socket_address = socket_address.to_string();
1172
1173    // Acquire connection guard before WebSocket upgrade
1174    let connection_guard = match ConnectionGuard::try_acquire() {
1175        Ok(guard) => guard,
1176        Err(()) => {
1177            return (
1178                StatusCode::SERVICE_UNAVAILABLE,
1179                "Too many concurrent connections",
1180            )
1181                .into_response();
1182        }
1183    };
1184
1185    ws.on_upgrade(move |socket| {
1186        let socket = socket
1187            .map_ok(to_tungstenite_message)
1188            .err_into()
1189            .with(|message| async move { to_axum_message(message) });
1190        let connection = Connection::new(Box::pin(socket));
1191        async move {
1192            server
1193                .handle_connection(
1194                    connection,
1195                    socket_address,
1196                    principal,
1197                    version,
1198                    release_channel,
1199                    user_agent.map(|header| header.to_string()),
1200                    country_code_header.map(|header| header.to_string()),
1201                    system_id_header.map(|header| header.to_string()),
1202                    None,
1203                    Executor::Production,
1204                    Some(connection_guard),
1205                )
1206                .await;
1207        }
1208    })
1209}
1210
1211pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1212    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1213    let connections_metric = CONNECTIONS_METRIC
1214        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1215
1216    let connections = server
1217        .connection_pool
1218        .lock()
1219        .connections()
1220        .filter(|connection| !connection.admin)
1221        .count();
1222    connections_metric.set(connections as _);
1223
1224    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1225    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1226        register_int_gauge!(
1227            "shared_projects",
1228            "number of open projects with one or more guests"
1229        )
1230        .unwrap()
1231    });
1232
1233    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1234    shared_projects_metric.set(shared_projects as _);
1235
1236    let encoder = prometheus::TextEncoder::new();
1237    let metric_families = prometheus::gather();
1238    let encoded_metrics = encoder
1239        .encode_to_string(&metric_families)
1240        .map_err(|err| anyhow!("{err}"))?;
1241    Ok(encoded_metrics)
1242}
1243
1244#[instrument(err, skip(executor))]
1245async fn connection_lost(
1246    session: Session,
1247    mut teardown: watch::Receiver<bool>,
1248    executor: Executor,
1249) -> Result<()> {
1250    session.peer.disconnect(session.connection_id);
1251    session
1252        .connection_pool()
1253        .await
1254        .remove_connection(session.connection_id)?;
1255
1256    session
1257        .db()
1258        .await
1259        .connection_lost(session.connection_id)
1260        .await
1261        .trace_err();
1262
1263    futures::select_biased! {
1264        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1265
1266            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1267            leave_room_for_session(&session, session.connection_id).await.trace_err();
1268            leave_channel_buffers_for_session(&session)
1269                .await
1270                .trace_err();
1271
1272            if !session
1273                .connection_pool()
1274                .await
1275                .is_user_online(session.user_id())
1276            {
1277                let db = session.db().await;
1278                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1279                    room_updated(&room, &session.peer);
1280                }
1281            }
1282
1283            update_user_contacts(session.user_id(), &session).await?;
1284        },
1285        _ = teardown.changed().fuse() => {}
1286    }
1287
1288    Ok(())
1289}
1290
1291/// Acknowledges a ping from a client, used to keep the connection alive.
1292async fn ping(
1293    _: proto::Ping,
1294    response: Response<proto::Ping>,
1295    _session: MessageContext,
1296) -> Result<()> {
1297    response.send(proto::Ack {})?;
1298    Ok(())
1299}
1300
1301/// Creates a new room for calling (outside of channels)
1302async fn create_room(
1303    _request: proto::CreateRoom,
1304    response: Response<proto::CreateRoom>,
1305    session: MessageContext,
1306) -> Result<()> {
1307    let livekit_room = nanoid::nanoid!(30);
1308
1309    let live_kit_connection_info = util::maybe!(async {
1310        let live_kit = session.app_state.livekit_client.as_ref();
1311        let live_kit = live_kit?;
1312        let user_id = session.user_id().to_string();
1313
1314        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1315
1316        Some(proto::LiveKitConnectionInfo {
1317            server_url: live_kit.url().into(),
1318            token,
1319            can_publish: true,
1320        })
1321    })
1322    .await;
1323
1324    let room = session
1325        .db()
1326        .await
1327        .create_room(session.user_id(), session.connection_id, &livekit_room)
1328        .await?;
1329
1330    response.send(proto::CreateRoomResponse {
1331        room: Some(room.clone()),
1332        live_kit_connection_info,
1333    })?;
1334
1335    update_user_contacts(session.user_id(), &session).await?;
1336    Ok(())
1337}
1338
1339/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1340async fn join_room(
1341    request: proto::JoinRoom,
1342    response: Response<proto::JoinRoom>,
1343    session: MessageContext,
1344) -> Result<()> {
1345    let room_id = RoomId::from_proto(request.id);
1346
1347    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1348
1349    if let Some(channel_id) = channel_id {
1350        return join_channel_internal(channel_id, Box::new(response), session).await;
1351    }
1352
1353    let joined_room = {
1354        let room = session
1355            .db()
1356            .await
1357            .join_room(room_id, session.user_id(), session.connection_id)
1358            .await?;
1359        room_updated(&room.room, &session.peer);
1360        room.into_inner()
1361    };
1362
1363    for connection_id in session
1364        .connection_pool()
1365        .await
1366        .user_connection_ids(session.user_id())
1367    {
1368        session
1369            .peer
1370            .send(
1371                connection_id,
1372                proto::CallCanceled {
1373                    room_id: room_id.to_proto(),
1374                },
1375            )
1376            .trace_err();
1377    }
1378
1379    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1380    {
1381        live_kit
1382            .room_token(
1383                &joined_room.room.livekit_room,
1384                &session.user_id().to_string(),
1385            )
1386            .trace_err()
1387            .map(|token| proto::LiveKitConnectionInfo {
1388                server_url: live_kit.url().into(),
1389                token,
1390                can_publish: true,
1391            })
1392    } else {
1393        None
1394    };
1395
1396    response.send(proto::JoinRoomResponse {
1397        room: Some(joined_room.room),
1398        channel_id: None,
1399        live_kit_connection_info,
1400    })?;
1401
1402    update_user_contacts(session.user_id(), &session).await?;
1403    Ok(())
1404}
1405
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Calls the database's `rejoin_room` with the request.
/// 2. Responds with the room, the re-shared projects (each with its collaborators) and the
///    rejoined projects.
/// 3. Broadcasts the room update.
/// 4. For each re-shared project, sends every collaborator an `UpdateProjectCollaborator`.
///    It maps the project's old connection id to the caller's current connection id.
/// 5. For each re-shared project, forwards an `UpdateProject` with the project's worktrees
///    to every collaborator other than the caller.
/// 6. Hands the rejoined projects to `notify_rejoined_projects`.
/// 7. If the room belongs to a channel, broadcasts a channel update.
/// 8. Refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // The database result lives only inside this scope. It is converted with
    // `into_inner()` at the end, and its `room`/`channel` are moved out.
    // NOTE(review): presumably it is a transaction guard that must be released before the
    // connection pool is locked for `channel_updated` below; confirm.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // The room is cloned here because it is still needed for `room_updated` below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator the caller's peer id moved to the new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1497
1498fn notify_rejoined_projects(
1499    rejoined_projects: &mut Vec<RejoinedProject>,
1500    session: &Session,
1501) -> Result<()> {
1502    for project in rejoined_projects.iter() {
1503        for collaborator in &project.collaborators {
1504            session
1505                .peer
1506                .send(
1507                    collaborator.connection_id,
1508                    proto::UpdateProjectCollaborator {
1509                        project_id: project.id.to_proto(),
1510                        old_peer_id: Some(project.old_connection_id.into()),
1511                        new_peer_id: Some(session.connection_id.into()),
1512                    },
1513                )
1514                .trace_err();
1515        }
1516    }
1517
1518    for project in rejoined_projects {
1519        for worktree in mem::take(&mut project.worktrees) {
1520            // Stream this worktree's entries.
1521            let message = proto::UpdateWorktree {
1522                project_id: project.id.to_proto(),
1523                worktree_id: worktree.id,
1524                abs_path: worktree.abs_path.clone(),
1525                root_name: worktree.root_name,
1526                updated_entries: worktree.updated_entries,
1527                removed_entries: worktree.removed_entries,
1528                scan_id: worktree.scan_id,
1529                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1530                updated_repositories: worktree.updated_repositories,
1531                removed_repositories: worktree.removed_repositories,
1532            };
1533            for update in proto::split_worktree_update(message) {
1534                session.peer.send(session.connection_id, update)?;
1535            }
1536
1537            // Stream this worktree's diagnostics.
1538            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1539            if let Some(summary) = worktree_diagnostics.next() {
1540                let message = proto::UpdateDiagnosticSummary {
1541                    project_id: project.id.to_proto(),
1542                    worktree_id: worktree.id,
1543                    summary: Some(summary),
1544                    more_summaries: worktree_diagnostics.collect(),
1545                };
1546                session.peer.send(session.connection_id, message)?;
1547            }
1548
1549            for settings_file in worktree.settings_files {
1550                session.peer.send(
1551                    session.connection_id,
1552                    proto::UpdateWorktreeSettings {
1553                        project_id: project.id.to_proto(),
1554                        worktree_id: worktree.id,
1555                        path: settings_file.path,
1556                        content: Some(settings_file.content),
1557                        kind: Some(settings_file.kind.to_proto().into()),
1558                        outside_worktree: Some(settings_file.outside_worktree),
1559                    },
1560                )?;
1561            }
1562        }
1563
1564        for repository in mem::take(&mut project.updated_repositories) {
1565            for update in split_repository_update(repository) {
1566                session.peer.send(session.connection_id, update)?;
1567            }
1568        }
1569
1570        for id in mem::take(&mut project.removed_repositories) {
1571            session.peer.send(
1572                session.connection_id,
1573                proto::RemoveRepository {
1574                    project_id: project.id.to_proto(),
1575                    id,
1576                },
1577            )?;
1578        }
1579    }
1580
1581    Ok(())
1582}
1583
1584/// leave room disconnects from the room.
1585async fn leave_room(
1586    _: proto::LeaveRoom,
1587    response: Response<proto::LeaveRoom>,
1588    session: MessageContext,
1589) -> Result<()> {
1590    leave_room_for_session(&session, session.connection_id).await?;
1591    response.send(proto::Ack {})?;
1592    Ok(())
1593}
1594
1595/// Updates the permissions of someone else in the room.
1596async fn set_room_participant_role(
1597    request: proto::SetRoomParticipantRole,
1598    response: Response<proto::SetRoomParticipantRole>,
1599    session: MessageContext,
1600) -> Result<()> {
1601    let user_id = UserId::from_proto(request.user_id);
1602    let role = ChannelRole::from(request.role());
1603
1604    let (livekit_room, can_publish) = {
1605        let room = session
1606            .db()
1607            .await
1608            .set_room_participant_role(
1609                session.user_id(),
1610                RoomId::from_proto(request.room_id),
1611                user_id,
1612                role,
1613            )
1614            .await?;
1615
1616        let livekit_room = room.livekit_room.clone();
1617        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1618        room_updated(&room, &session.peer);
1619        (livekit_room, can_publish)
1620    };
1621
1622    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1623        live_kit
1624            .update_participant(
1625                livekit_room.clone(),
1626                request.user_id.to_string(),
1627                livekit_api::proto::ParticipantPermission {
1628                    can_subscribe: true,
1629                    can_publish,
1630                    can_publish_data: can_publish,
1631                    hidden: false,
1632                    recorder: false,
1633                },
1634            )
1635            .await
1636            .trace_err();
1637    }
1638
1639    response.send(proto::Ack {})?;
1640    Ok(())
1641}
1642
1643/// Call someone else into the current room
1644async fn call(
1645    request: proto::Call,
1646    response: Response<proto::Call>,
1647    session: MessageContext,
1648) -> Result<()> {
1649    let room_id = RoomId::from_proto(request.room_id);
1650    let calling_user_id = session.user_id();
1651    let calling_connection_id = session.connection_id;
1652    let called_user_id = UserId::from_proto(request.called_user_id);
1653    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1654    if !session
1655        .db()
1656        .await
1657        .has_contact(calling_user_id, called_user_id)
1658        .await?
1659    {
1660        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1661    }
1662
1663    let incoming_call = {
1664        let (room, incoming_call) = &mut *session
1665            .db()
1666            .await
1667            .call(
1668                room_id,
1669                calling_user_id,
1670                calling_connection_id,
1671                called_user_id,
1672                initial_project_id,
1673            )
1674            .await?;
1675        room_updated(room, &session.peer);
1676        mem::take(incoming_call)
1677    };
1678    update_user_contacts(called_user_id, &session).await?;
1679
1680    let mut calls = session
1681        .connection_pool()
1682        .await
1683        .user_connection_ids(called_user_id)
1684        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1685        .collect::<FuturesUnordered<_>>();
1686
1687    while let Some(call_response) = calls.next().await {
1688        match call_response.as_ref() {
1689            Ok(_) => {
1690                response.send(proto::Ack {})?;
1691                return Ok(());
1692            }
1693            Err(_) => {
1694                call_response.trace_err();
1695            }
1696        }
1697    }
1698
1699    {
1700        let room = session
1701            .db()
1702            .await
1703            .call_failed(room_id, called_user_id)
1704            .await?;
1705        room_updated(&room, &session.peer);
1706    }
1707    update_user_contacts(called_user_id, &session).await?;
1708
1709    Err(anyhow!("failed to ring user"))?
1710}
1711
1712/// Cancel an outgoing call.
1713async fn cancel_call(
1714    request: proto::CancelCall,
1715    response: Response<proto::CancelCall>,
1716    session: MessageContext,
1717) -> Result<()> {
1718    let called_user_id = UserId::from_proto(request.called_user_id);
1719    let room_id = RoomId::from_proto(request.room_id);
1720    {
1721        let room = session
1722            .db()
1723            .await
1724            .cancel_call(room_id, session.connection_id, called_user_id)
1725            .await?;
1726        room_updated(&room, &session.peer);
1727    }
1728
1729    for connection_id in session
1730        .connection_pool()
1731        .await
1732        .user_connection_ids(called_user_id)
1733    {
1734        session
1735            .peer
1736            .send(
1737                connection_id,
1738                proto::CallCanceled {
1739                    room_id: room_id.to_proto(),
1740                },
1741            )
1742            .trace_err();
1743    }
1744    response.send(proto::Ack {})?;
1745
1746    update_user_contacts(called_user_id, &session).await?;
1747    Ok(())
1748}
1749
1750/// Decline an incoming call.
1751async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1752    let room_id = RoomId::from_proto(message.room_id);
1753    {
1754        let room = session
1755            .db()
1756            .await
1757            .decline_call(Some(room_id), session.user_id())
1758            .await?
1759            .context("declining call")?;
1760        room_updated(&room, &session.peer);
1761    }
1762
1763    for connection_id in session
1764        .connection_pool()
1765        .await
1766        .user_connection_ids(session.user_id())
1767    {
1768        session
1769            .peer
1770            .send(
1771                connection_id,
1772                proto::CallCanceled {
1773                    room_id: room_id.to_proto(),
1774                },
1775            )
1776            .trace_err();
1777    }
1778    update_user_contacts(session.user_id(), &session).await?;
1779    Ok(())
1780}
1781
1782/// Updates other participants in the room with your current location.
1783async fn update_participant_location(
1784    request: proto::UpdateParticipantLocation,
1785    response: Response<proto::UpdateParticipantLocation>,
1786    session: MessageContext,
1787) -> Result<()> {
1788    let room_id = RoomId::from_proto(request.room_id);
1789    let location = request.location.context("invalid location")?;
1790
1791    let db = session.db().await;
1792    let room = db
1793        .update_room_participant_location(room_id, session.connection_id, location)
1794        .await?;
1795
1796    room_updated(&room, &session.peer);
1797    response.send(proto::Ack {})?;
1798    Ok(())
1799}
1800
1801/// Share a project into the room.
1802async fn share_project(
1803    request: proto::ShareProject,
1804    response: Response<proto::ShareProject>,
1805    session: MessageContext,
1806) -> Result<()> {
1807    let (project_id, room) = &*session
1808        .db()
1809        .await
1810        .share_project(
1811            RoomId::from_proto(request.room_id),
1812            session.connection_id,
1813            &request.worktrees,
1814            request.is_ssh_project,
1815            request.windows_paths.unwrap_or(false),
1816        )
1817        .await?;
1818    response.send(proto::ShareProjectResponse {
1819        project_id: project_id.to_proto(),
1820    })?;
1821    room_updated(room, &session.peer);
1822
1823    Ok(())
1824}
1825
1826/// Unshare a project from the room.
1827async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1828    let project_id = ProjectId::from_proto(message.project_id);
1829    unshare_project_internal(project_id, session.connection_id, &session).await
1830}
1831
1832async fn unshare_project_internal(
1833    project_id: ProjectId,
1834    connection_id: ConnectionId,
1835    session: &Session,
1836) -> Result<()> {
1837    let delete = {
1838        let room_guard = session
1839            .db()
1840            .await
1841            .unshare_project(project_id, connection_id)
1842            .await?;
1843
1844        let (delete, room, guest_connection_ids) = &*room_guard;
1845
1846        let message = proto::UnshareProject {
1847            project_id: project_id.to_proto(),
1848        };
1849
1850        broadcast(
1851            Some(connection_id),
1852            guest_connection_ids.iter().copied(),
1853            |conn_id| session.peer.send(conn_id, message.clone()),
1854        );
1855        if let Some(room) = room {
1856            room_updated(room, &session.peer);
1857        }
1858
1859        *delete
1860    };
1861
1862    if delete {
1863        let db = session.db().await;
1864        db.delete_project(project_id).await?;
1865    }
1866
1867    Ok(())
1868}
1869
1870/// Join someone elses shared project.
1871async fn join_project(
1872    request: proto::JoinProject,
1873    response: Response<proto::JoinProject>,
1874    session: MessageContext,
1875) -> Result<()> {
1876    let project_id = ProjectId::from_proto(request.project_id);
1877
1878    tracing::info!(%project_id, "join project");
1879
1880    let db = session.db().await;
1881    let (project, replica_id) = &mut *db
1882        .join_project(
1883            project_id,
1884            session.connection_id,
1885            session.user_id(),
1886            request.committer_name.clone(),
1887            request.committer_email.clone(),
1888        )
1889        .await?;
1890    drop(db);
1891    tracing::info!(%project_id, "join remote project");
1892    let collaborators = project
1893        .collaborators
1894        .iter()
1895        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1896        .map(|collaborator| collaborator.to_proto())
1897        .collect::<Vec<_>>();
1898    let project_id = project.id;
1899    let guest_user_id = session.user_id();
1900
1901    let worktrees = project
1902        .worktrees
1903        .iter()
1904        .map(|(id, worktree)| proto::WorktreeMetadata {
1905            id: *id,
1906            root_name: worktree.root_name.clone(),
1907            visible: worktree.visible,
1908            abs_path: worktree.abs_path.clone(),
1909        })
1910        .collect::<Vec<_>>();
1911
1912    let add_project_collaborator = proto::AddProjectCollaborator {
1913        project_id: project_id.to_proto(),
1914        collaborator: Some(proto::Collaborator {
1915            peer_id: Some(session.connection_id.into()),
1916            replica_id: replica_id.0 as u32,
1917            user_id: guest_user_id.to_proto(),
1918            is_host: false,
1919            committer_name: request.committer_name.clone(),
1920            committer_email: request.committer_email.clone(),
1921        }),
1922    };
1923
1924    for collaborator in &collaborators {
1925        session
1926            .peer
1927            .send(
1928                collaborator.peer_id.unwrap().into(),
1929                add_project_collaborator.clone(),
1930            )
1931            .trace_err();
1932    }
1933
1934    // First, we send the metadata associated with each worktree.
1935    let (language_servers, language_server_capabilities) = project
1936        .language_servers
1937        .clone()
1938        .into_iter()
1939        .map(|server| (server.server, server.capabilities))
1940        .unzip();
1941    response.send(proto::JoinProjectResponse {
1942        project_id: project.id.0 as u64,
1943        worktrees,
1944        replica_id: replica_id.0 as u32,
1945        collaborators,
1946        language_servers,
1947        language_server_capabilities,
1948        role: project.role.into(),
1949        windows_paths: project.path_style == PathStyle::Windows,
1950    })?;
1951
1952    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1953        // Stream this worktree's entries.
1954        let message = proto::UpdateWorktree {
1955            project_id: project_id.to_proto(),
1956            worktree_id,
1957            abs_path: worktree.abs_path.clone(),
1958            root_name: worktree.root_name,
1959            updated_entries: worktree.entries,
1960            removed_entries: Default::default(),
1961            scan_id: worktree.scan_id,
1962            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1963            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1964            removed_repositories: Default::default(),
1965        };
1966        for update in proto::split_worktree_update(message) {
1967            session.peer.send(session.connection_id, update.clone())?;
1968        }
1969
1970        // Stream this worktree's diagnostics.
1971        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1972        if let Some(summary) = worktree_diagnostics.next() {
1973            let message = proto::UpdateDiagnosticSummary {
1974                project_id: project.id.to_proto(),
1975                worktree_id: worktree.id,
1976                summary: Some(summary),
1977                more_summaries: worktree_diagnostics.collect(),
1978            };
1979            session.peer.send(session.connection_id, message)?;
1980        }
1981
1982        for settings_file in worktree.settings_files {
1983            session.peer.send(
1984                session.connection_id,
1985                proto::UpdateWorktreeSettings {
1986                    project_id: project_id.to_proto(),
1987                    worktree_id: worktree.id,
1988                    path: settings_file.path,
1989                    content: Some(settings_file.content),
1990                    kind: Some(settings_file.kind.to_proto() as i32),
1991                    outside_worktree: Some(settings_file.outside_worktree),
1992                },
1993            )?;
1994        }
1995    }
1996
1997    for repository in mem::take(&mut project.repositories) {
1998        for update in split_repository_update(repository) {
1999            session.peer.send(session.connection_id, update)?;
2000        }
2001    }
2002
2003    for language_server in &project.language_servers {
2004        session.peer.send(
2005            session.connection_id,
2006            proto::UpdateLanguageServer {
2007                project_id: project_id.to_proto(),
2008                server_name: Some(language_server.server.name.clone()),
2009                language_server_id: language_server.server.id,
2010                variant: Some(
2011                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2012                        proto::LspDiskBasedDiagnosticsUpdated {},
2013                    ),
2014                ),
2015            },
2016        )?;
2017    }
2018
2019    Ok(())
2020}
2021
2022/// Leave someone elses shared project.
2023async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2024    let sender_id = session.connection_id;
2025    let project_id = ProjectId::from_proto(request.project_id);
2026    let db = session.db().await;
2027
2028    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2029    tracing::info!(
2030        %project_id,
2031        "leave project"
2032    );
2033
2034    project_left(project, &session);
2035    if let Some(room) = room {
2036        room_updated(room, &session.peer);
2037    }
2038
2039    Ok(())
2040}
2041
2042/// Updates other participants with changes to the project
2043async fn update_project(
2044    request: proto::UpdateProject,
2045    response: Response<proto::UpdateProject>,
2046    session: MessageContext,
2047) -> Result<()> {
2048    let project_id = ProjectId::from_proto(request.project_id);
2049    let (room, guest_connection_ids) = &*session
2050        .db()
2051        .await
2052        .update_project(project_id, session.connection_id, &request.worktrees)
2053        .await?;
2054    broadcast(
2055        Some(session.connection_id),
2056        guest_connection_ids.iter().copied(),
2057        |connection_id| {
2058            session
2059                .peer
2060                .forward_send(session.connection_id, connection_id, request.clone())
2061        },
2062    );
2063    if let Some(room) = room {
2064        room_updated(room, &session.peer);
2065    }
2066    response.send(proto::Ack {})?;
2067
2068    Ok(())
2069}
2070
2071/// Updates other participants with changes to the worktree
2072async fn update_worktree(
2073    request: proto::UpdateWorktree,
2074    response: Response<proto::UpdateWorktree>,
2075    session: MessageContext,
2076) -> Result<()> {
2077    let guest_connection_ids = session
2078        .db()
2079        .await
2080        .update_worktree(&request, session.connection_id)
2081        .await?;
2082
2083    broadcast(
2084        Some(session.connection_id),
2085        guest_connection_ids.iter().copied(),
2086        |connection_id| {
2087            session
2088                .peer
2089                .forward_send(session.connection_id, connection_id, request.clone())
2090        },
2091    );
2092    response.send(proto::Ack {})?;
2093    Ok(())
2094}
2095
2096async fn update_repository(
2097    request: proto::UpdateRepository,
2098    response: Response<proto::UpdateRepository>,
2099    session: MessageContext,
2100) -> Result<()> {
2101    let guest_connection_ids = session
2102        .db()
2103        .await
2104        .update_repository(&request, session.connection_id)
2105        .await?;
2106
2107    broadcast(
2108        Some(session.connection_id),
2109        guest_connection_ids.iter().copied(),
2110        |connection_id| {
2111            session
2112                .peer
2113                .forward_send(session.connection_id, connection_id, request.clone())
2114        },
2115    );
2116    response.send(proto::Ack {})?;
2117    Ok(())
2118}
2119
2120async fn remove_repository(
2121    request: proto::RemoveRepository,
2122    response: Response<proto::RemoveRepository>,
2123    session: MessageContext,
2124) -> Result<()> {
2125    let guest_connection_ids = session
2126        .db()
2127        .await
2128        .remove_repository(&request, session.connection_id)
2129        .await?;
2130
2131    broadcast(
2132        Some(session.connection_id),
2133        guest_connection_ids.iter().copied(),
2134        |connection_id| {
2135            session
2136                .peer
2137                .forward_send(session.connection_id, connection_id, request.clone())
2138        },
2139    );
2140    response.send(proto::Ack {})?;
2141    Ok(())
2142}
2143
2144/// Updates other participants with changes to the diagnostics
2145async fn update_diagnostic_summary(
2146    message: proto::UpdateDiagnosticSummary,
2147    session: MessageContext,
2148) -> Result<()> {
2149    let guest_connection_ids = session
2150        .db()
2151        .await
2152        .update_diagnostic_summary(&message, session.connection_id)
2153        .await?;
2154
2155    broadcast(
2156        Some(session.connection_id),
2157        guest_connection_ids.iter().copied(),
2158        |connection_id| {
2159            session
2160                .peer
2161                .forward_send(session.connection_id, connection_id, message.clone())
2162        },
2163    );
2164
2165    Ok(())
2166}
2167
2168/// Updates other participants with changes to the worktree settings
2169async fn update_worktree_settings(
2170    message: proto::UpdateWorktreeSettings,
2171    session: MessageContext,
2172) -> Result<()> {
2173    let guest_connection_ids = session
2174        .db()
2175        .await
2176        .update_worktree_settings(&message, session.connection_id)
2177        .await?;
2178
2179    broadcast(
2180        Some(session.connection_id),
2181        guest_connection_ids.iter().copied(),
2182        |connection_id| {
2183            session
2184                .peer
2185                .forward_send(session.connection_id, connection_id, message.clone())
2186        },
2187    );
2188
2189    Ok(())
2190}
2191
2192/// Notify other participants that a language server has started.
2193async fn start_language_server(
2194    request: proto::StartLanguageServer,
2195    session: MessageContext,
2196) -> Result<()> {
2197    let guest_connection_ids = session
2198        .db()
2199        .await
2200        .start_language_server(&request, session.connection_id)
2201        .await?;
2202
2203    broadcast(
2204        Some(session.connection_id),
2205        guest_connection_ids.iter().copied(),
2206        |connection_id| {
2207            session
2208                .peer
2209                .forward_send(session.connection_id, connection_id, request.clone())
2210        },
2211    );
2212    Ok(())
2213}
2214
2215/// Notify other participants that a language server has changed.
2216async fn update_language_server(
2217    request: proto::UpdateLanguageServer,
2218    session: MessageContext,
2219) -> Result<()> {
2220    let project_id = ProjectId::from_proto(request.project_id);
2221    let db = session.db().await;
2222
2223    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2224        && let Some(capabilities) = update.capabilities.clone()
2225    {
2226        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2227            .await?;
2228    }
2229
2230    let project_connection_ids = db
2231        .project_connection_ids(project_id, session.connection_id, true)
2232        .await?;
2233    broadcast(
2234        Some(session.connection_id),
2235        project_connection_ids.iter().copied(),
2236        |connection_id| {
2237            session
2238                .peer
2239                .forward_send(session.connection_id, connection_id, request.clone())
2240        },
2241    );
2242    Ok(())
2243}
2244
2245/// forward a project request to the host. These requests should be read only
2246/// as guests are allowed to send them.
2247async fn forward_read_only_project_request<T>(
2248    request: T,
2249    response: Response<T>,
2250    session: MessageContext,
2251) -> Result<()>
2252where
2253    T: EntityMessage + RequestMessage,
2254{
2255    let project_id = ProjectId::from_proto(request.remote_entity_id());
2256    let host_connection_id = session
2257        .db()
2258        .await
2259        .host_for_read_only_project_request(project_id, session.connection_id)
2260        .await?;
2261    let payload = session.forward_request(host_connection_id, request).await?;
2262    response.send(payload)?;
2263    Ok(())
2264}
2265
2266/// forward a project request to the host. These requests are disallowed
2267/// for guests.
2268async fn forward_mutating_project_request<T>(
2269    request: T,
2270    response: Response<T>,
2271    session: MessageContext,
2272) -> Result<()>
2273where
2274    T: EntityMessage + RequestMessage,
2275{
2276    let project_id = ProjectId::from_proto(request.remote_entity_id());
2277
2278    let host_connection_id = session
2279        .db()
2280        .await
2281        .host_for_mutating_project_request(project_id, session.connection_id)
2282        .await?;
2283    let payload = session.forward_request(host_connection_id, request).await?;
2284    response.send(payload)?;
2285    Ok(())
2286}
2287
2288async fn lsp_query(
2289    request: proto::LspQuery,
2290    response: Response<proto::LspQuery>,
2291    session: MessageContext,
2292) -> Result<()> {
2293    let (name, should_write) = request.query_name_and_write_permissions();
2294    tracing::Span::current().record("lsp_query_request", name);
2295    tracing::info!("lsp_query message received");
2296    if should_write {
2297        forward_mutating_project_request(request, response, session).await
2298    } else {
2299        forward_read_only_project_request(request, response, session).await
2300    }
2301}
2302
2303/// Notify other participants that a new buffer has been created
2304async fn create_buffer_for_peer(
2305    request: proto::CreateBufferForPeer,
2306    session: MessageContext,
2307) -> Result<()> {
2308    session
2309        .db()
2310        .await
2311        .check_user_is_project_host(
2312            ProjectId::from_proto(request.project_id),
2313            session.connection_id,
2314        )
2315        .await?;
2316    let peer_id = request.peer_id.context("invalid peer id")?;
2317    session
2318        .peer
2319        .forward_send(session.connection_id, peer_id.into(), request)?;
2320    Ok(())
2321}
2322
2323/// Notify other participants that a new image has been created
2324async fn create_image_for_peer(
2325    request: proto::CreateImageForPeer,
2326    session: MessageContext,
2327) -> Result<()> {
2328    session
2329        .db()
2330        .await
2331        .check_user_is_project_host(
2332            ProjectId::from_proto(request.project_id),
2333            session.connection_id,
2334        )
2335        .await?;
2336    let peer_id = request.peer_id.context("invalid peer id")?;
2337    session
2338        .peer
2339        .forward_send(session.connection_id, peer_id.into(), request)?;
2340    Ok(())
2341}
2342
2343/// Notify other participants that a buffer has been updated. This is
2344/// allowed for guests as long as the update is limited to selections.
2345async fn update_buffer(
2346    request: proto::UpdateBuffer,
2347    response: Response<proto::UpdateBuffer>,
2348    session: MessageContext,
2349) -> Result<()> {
2350    let project_id = ProjectId::from_proto(request.project_id);
2351    let mut capability = Capability::ReadOnly;
2352
2353    for op in request.operations.iter() {
2354        match op.variant {
2355            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2356            Some(_) => capability = Capability::ReadWrite,
2357        }
2358    }
2359
2360    let host = {
2361        let guard = session
2362            .db()
2363            .await
2364            .connections_for_buffer_update(project_id, session.connection_id, capability)
2365            .await?;
2366
2367        let (host, guests) = &*guard;
2368
2369        broadcast(
2370            Some(session.connection_id),
2371            guests.clone(),
2372            |connection_id| {
2373                session
2374                    .peer
2375                    .forward_send(session.connection_id, connection_id, request.clone())
2376            },
2377        );
2378
2379        *host
2380    };
2381
2382    if host != session.connection_id {
2383        session.forward_request(host, request.clone()).await?;
2384    }
2385
2386    response.send(proto::Ack {})?;
2387    Ok(())
2388}
2389
2390async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2391    let project_id = ProjectId::from_proto(message.project_id);
2392
2393    let operation = message.operation.as_ref().context("invalid operation")?;
2394    let capability = match operation.variant.as_ref() {
2395        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2396            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2397                match buffer_op.variant {
2398                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2399                        Capability::ReadOnly
2400                    }
2401                    _ => Capability::ReadWrite,
2402                }
2403            } else {
2404                Capability::ReadWrite
2405            }
2406        }
2407        Some(_) => Capability::ReadWrite,
2408        None => Capability::ReadOnly,
2409    };
2410
2411    let guard = session
2412        .db()
2413        .await
2414        .connections_for_buffer_update(project_id, session.connection_id, capability)
2415        .await?;
2416
2417    let (host, guests) = &*guard;
2418
2419    broadcast(
2420        Some(session.connection_id),
2421        guests.iter().chain([host]).copied(),
2422        |connection_id| {
2423            session
2424                .peer
2425                .forward_send(session.connection_id, connection_id, message.clone())
2426        },
2427    );
2428
2429    Ok(())
2430}
2431
2432async fn forward_project_search_chunk(
2433    message: proto::FindSearchCandidatesChunk,
2434    response: Response<proto::FindSearchCandidatesChunk>,
2435    session: MessageContext,
2436) -> Result<()> {
2437    let peer_id = message.peer_id.context("missing peer_id")?;
2438    let payload = session
2439        .peer
2440        .forward_request(session.connection_id, peer_id.into(), message)
2441        .await?;
2442    response.send(payload)?;
2443    Ok(())
2444}
2445
2446/// Notify other participants that a project has been updated.
2447async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2448    request: T,
2449    session: MessageContext,
2450) -> Result<()> {
2451    let project_id = ProjectId::from_proto(request.remote_entity_id());
2452    let project_connection_ids = session
2453        .db()
2454        .await
2455        .project_connection_ids(project_id, session.connection_id, false)
2456        .await?;
2457
2458    broadcast(
2459        Some(session.connection_id),
2460        project_connection_ids.iter().copied(),
2461        |connection_id| {
2462            session
2463                .peer
2464                .forward_send(session.connection_id, connection_id, request.clone())
2465        },
2466    );
2467    Ok(())
2468}
2469
2470/// Start following another user in a call.
2471async fn follow(
2472    request: proto::Follow,
2473    response: Response<proto::Follow>,
2474    session: MessageContext,
2475) -> Result<()> {
2476    let room_id = RoomId::from_proto(request.room_id);
2477    let project_id = request.project_id.map(ProjectId::from_proto);
2478    let leader_id = request.leader_id.context("invalid leader id")?.into();
2479    let follower_id = session.connection_id;
2480
2481    session
2482        .db()
2483        .await
2484        .check_room_participants(room_id, leader_id, session.connection_id)
2485        .await?;
2486
2487    let response_payload = session.forward_request(leader_id, request).await?;
2488    response.send(response_payload)?;
2489
2490    if let Some(project_id) = project_id {
2491        let room = session
2492            .db()
2493            .await
2494            .follow(room_id, project_id, leader_id, follower_id)
2495            .await?;
2496        room_updated(&room, &session.peer);
2497    }
2498
2499    Ok(())
2500}
2501
2502/// Stop following another user in a call.
2503async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2504    let room_id = RoomId::from_proto(request.room_id);
2505    let project_id = request.project_id.map(ProjectId::from_proto);
2506    let leader_id = request.leader_id.context("invalid leader id")?.into();
2507    let follower_id = session.connection_id;
2508
2509    session
2510        .db()
2511        .await
2512        .check_room_participants(room_id, leader_id, session.connection_id)
2513        .await?;
2514
2515    session
2516        .peer
2517        .forward_send(session.connection_id, leader_id, request)?;
2518
2519    if let Some(project_id) = project_id {
2520        let room = session
2521            .db()
2522            .await
2523            .unfollow(room_id, project_id, leader_id, follower_id)
2524            .await?;
2525        room_updated(&room, &session.peer);
2526    }
2527
2528    Ok(())
2529}
2530
2531/// Notify everyone following you of your current location.
2532async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2533    let room_id = RoomId::from_proto(request.room_id);
2534    let database = session.db.lock().await;
2535
2536    let connection_ids = if let Some(project_id) = request.project_id {
2537        let project_id = ProjectId::from_proto(project_id);
2538        database
2539            .project_connection_ids(project_id, session.connection_id, true)
2540            .await?
2541    } else {
2542        database
2543            .room_connection_ids(room_id, session.connection_id)
2544            .await?
2545    };
2546
2547    // For now, don't send view update messages back to that view's current leader.
2548    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2549        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2550        _ => None,
2551    });
2552
2553    for connection_id in connection_ids.iter().cloned() {
2554        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2555            session
2556                .peer
2557                .forward_send(session.connection_id, connection_id, request.clone())?;
2558        }
2559    }
2560    Ok(())
2561}
2562
2563/// Get public data about users.
2564async fn get_users(
2565    request: proto::GetUsers,
2566    response: Response<proto::GetUsers>,
2567    session: MessageContext,
2568) -> Result<()> {
2569    let user_ids = request
2570        .user_ids
2571        .into_iter()
2572        .map(UserId::from_proto)
2573        .collect();
2574    let users = session
2575        .db()
2576        .await
2577        .get_users_by_ids(user_ids)
2578        .await?
2579        .into_iter()
2580        .map(|user| proto::User {
2581            id: user.id.to_proto(),
2582            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2583            github_login: user.github_login,
2584            name: user.name,
2585        })
2586        .collect();
2587    response.send(proto::UsersResponse { users })?;
2588    Ok(())
2589}
2590
2591/// Search for users (to invite) buy Github login
2592async fn fuzzy_search_users(
2593    request: proto::FuzzySearchUsers,
2594    response: Response<proto::FuzzySearchUsers>,
2595    session: MessageContext,
2596) -> Result<()> {
2597    let query = request.query;
2598    let users = match query.len() {
2599        0 => vec![],
2600        1 | 2 => session
2601            .db()
2602            .await
2603            .get_user_by_github_login(&query)
2604            .await?
2605            .into_iter()
2606            .collect(),
2607        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2608    };
2609    let users = users
2610        .into_iter()
2611        .filter(|user| user.id != session.user_id())
2612        .map(|user| proto::User {
2613            id: user.id.to_proto(),
2614            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2615            github_login: user.github_login,
2616            name: user.name,
2617        })
2618        .collect();
2619    response.send(proto::UsersResponse { users })?;
2620    Ok(())
2621}
2622
2623/// Send a contact request to another user.
2624async fn request_contact(
2625    request: proto::RequestContact,
2626    response: Response<proto::RequestContact>,
2627    session: MessageContext,
2628) -> Result<()> {
2629    let requester_id = session.user_id();
2630    let responder_id = UserId::from_proto(request.responder_id);
2631    if requester_id == responder_id {
2632        return Err(anyhow!("cannot add yourself as a contact"))?;
2633    }
2634
2635    let notifications = session
2636        .db()
2637        .await
2638        .send_contact_request(requester_id, responder_id)
2639        .await?;
2640
2641    // Update outgoing contact requests of requester
2642    let mut update = proto::UpdateContacts::default();
2643    update.outgoing_requests.push(responder_id.to_proto());
2644    for connection_id in session
2645        .connection_pool()
2646        .await
2647        .user_connection_ids(requester_id)
2648    {
2649        session.peer.send(connection_id, update.clone())?;
2650    }
2651
2652    // Update incoming contact requests of responder
2653    let mut update = proto::UpdateContacts::default();
2654    update
2655        .incoming_requests
2656        .push(proto::IncomingContactRequest {
2657            requester_id: requester_id.to_proto(),
2658        });
2659    let connection_pool = session.connection_pool().await;
2660    for connection_id in connection_pool.user_connection_ids(responder_id) {
2661        session.peer.send(connection_id, update.clone())?;
2662    }
2663
2664    send_notifications(&connection_pool, &session.peer, notifications);
2665
2666    response.send(proto::Ack {})?;
2667    Ok(())
2668}
2669
2670/// Accept or decline a contact request
2671async fn respond_to_contact_request(
2672    request: proto::RespondToContactRequest,
2673    response: Response<proto::RespondToContactRequest>,
2674    session: MessageContext,
2675) -> Result<()> {
2676    let responder_id = session.user_id();
2677    let requester_id = UserId::from_proto(request.requester_id);
2678    let db = session.db().await;
2679    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2680        db.dismiss_contact_notification(responder_id, requester_id)
2681            .await?;
2682    } else {
2683        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2684
2685        let notifications = db
2686            .respond_to_contact_request(responder_id, requester_id, accept)
2687            .await?;
2688        let requester_busy = db.is_user_busy(requester_id).await?;
2689        let responder_busy = db.is_user_busy(responder_id).await?;
2690
2691        let pool = session.connection_pool().await;
2692        // Update responder with new contact
2693        let mut update = proto::UpdateContacts::default();
2694        if accept {
2695            update
2696                .contacts
2697                .push(contact_for_user(requester_id, requester_busy, &pool));
2698        }
2699        update
2700            .remove_incoming_requests
2701            .push(requester_id.to_proto());
2702        for connection_id in pool.user_connection_ids(responder_id) {
2703            session.peer.send(connection_id, update.clone())?;
2704        }
2705
2706        // Update requester with new contact
2707        let mut update = proto::UpdateContacts::default();
2708        if accept {
2709            update
2710                .contacts
2711                .push(contact_for_user(responder_id, responder_busy, &pool));
2712        }
2713        update
2714            .remove_outgoing_requests
2715            .push(responder_id.to_proto());
2716
2717        for connection_id in pool.user_connection_ids(requester_id) {
2718            session.peer.send(connection_id, update.clone())?;
2719        }
2720
2721        send_notifications(&pool, &session.peer, notifications);
2722    }
2723
2724    response.send(proto::Ack {})?;
2725    Ok(())
2726}
2727
2728/// Remove a contact.
2729async fn remove_contact(
2730    request: proto::RemoveContact,
2731    response: Response<proto::RemoveContact>,
2732    session: MessageContext,
2733) -> Result<()> {
2734    let requester_id = session.user_id();
2735    let responder_id = UserId::from_proto(request.user_id);
2736    let db = session.db().await;
2737    let (contact_accepted, deleted_notification_id) =
2738        db.remove_contact(requester_id, responder_id).await?;
2739
2740    let pool = session.connection_pool().await;
2741    // Update outgoing contact requests of requester
2742    let mut update = proto::UpdateContacts::default();
2743    if contact_accepted {
2744        update.remove_contacts.push(responder_id.to_proto());
2745    } else {
2746        update
2747            .remove_outgoing_requests
2748            .push(responder_id.to_proto());
2749    }
2750    for connection_id in pool.user_connection_ids(requester_id) {
2751        session.peer.send(connection_id, update.clone())?;
2752    }
2753
2754    // Update incoming contact requests of responder
2755    let mut update = proto::UpdateContacts::default();
2756    if contact_accepted {
2757        update.remove_contacts.push(requester_id.to_proto());
2758    } else {
2759        update
2760            .remove_incoming_requests
2761            .push(requester_id.to_proto());
2762    }
2763    for connection_id in pool.user_connection_ids(responder_id) {
2764        session.peer.send(connection_id, update.clone())?;
2765        if let Some(notification_id) = deleted_notification_id {
2766            session.peer.send(
2767                connection_id,
2768                proto::DeleteNotification {
2769                    notification_id: notification_id.to_proto(),
2770                },
2771            )?;
2772        }
2773    }
2774
2775    response.send(proto::Ack {})?;
2776    Ok(())
2777}
2778
2779fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2780    version.0.minor < 139
2781}
2782
2783async fn subscribe_to_channels(
2784    _: proto::SubscribeToChannels,
2785    session: MessageContext,
2786) -> Result<()> {
2787    subscribe_user_to_channels(session.user_id(), &session).await?;
2788    Ok(())
2789}
2790
2791async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2792    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2793    let mut pool = session.connection_pool().await;
2794    for membership in &channels_for_user.channel_memberships {
2795        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2796    }
2797    session.peer.send(
2798        session.connection_id,
2799        build_update_user_channels(&channels_for_user),
2800    )?;
2801    session.peer.send(
2802        session.connection_id,
2803        build_channels_update(channels_for_user),
2804    )?;
2805    Ok(())
2806}
2807
2808/// Creates a new channel.
2809async fn create_channel(
2810    request: proto::CreateChannel,
2811    response: Response<proto::CreateChannel>,
2812    session: MessageContext,
2813) -> Result<()> {
2814    let db = session.db().await;
2815
2816    let parent_id = request.parent_id.map(ChannelId::from_proto);
2817    let (channel, membership) = db
2818        .create_channel(&request.name, parent_id, session.user_id())
2819        .await?;
2820
2821    let root_id = channel.root_id();
2822    let channel = Channel::from_model(channel);
2823
2824    response.send(proto::CreateChannelResponse {
2825        channel: Some(channel.to_proto()),
2826        parent_id: request.parent_id,
2827    })?;
2828
2829    let mut connection_pool = session.connection_pool().await;
2830    if let Some(membership) = membership {
2831        connection_pool.subscribe_to_channel(
2832            membership.user_id,
2833            membership.channel_id,
2834            membership.role,
2835        );
2836        let update = proto::UpdateUserChannels {
2837            channel_memberships: vec![proto::ChannelMembership {
2838                channel_id: membership.channel_id.to_proto(),
2839                role: membership.role.into(),
2840            }],
2841            ..Default::default()
2842        };
2843        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2844            session.peer.send(connection_id, update.clone())?;
2845        }
2846    }
2847
2848    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2849        if !role.can_see_channel(channel.visibility) {
2850            continue;
2851        }
2852
2853        let update = proto::UpdateChannels {
2854            channels: vec![channel.to_proto()],
2855            ..Default::default()
2856        };
2857        session.peer.send(connection_id, update.clone())?;
2858    }
2859
2860    Ok(())
2861}
2862
2863/// Delete a channel
2864async fn delete_channel(
2865    request: proto::DeleteChannel,
2866    response: Response<proto::DeleteChannel>,
2867    session: MessageContext,
2868) -> Result<()> {
2869    let db = session.db().await;
2870
2871    let channel_id = request.channel_id;
2872    let (root_channel, removed_channels) = db
2873        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2874        .await?;
2875    response.send(proto::Ack {})?;
2876
2877    // Notify members of removed channels
2878    let mut update = proto::UpdateChannels::default();
2879    update
2880        .delete_channels
2881        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2882
2883    let connection_pool = session.connection_pool().await;
2884    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2885        session.peer.send(connection_id, update.clone())?;
2886    }
2887
2888    Ok(())
2889}
2890
2891/// Invite someone to join a channel.
2892async fn invite_channel_member(
2893    request: proto::InviteChannelMember,
2894    response: Response<proto::InviteChannelMember>,
2895    session: MessageContext,
2896) -> Result<()> {
2897    let db = session.db().await;
2898    let channel_id = ChannelId::from_proto(request.channel_id);
2899    let invitee_id = UserId::from_proto(request.user_id);
2900    let InviteMemberResult {
2901        channel,
2902        notifications,
2903    } = db
2904        .invite_channel_member(
2905            channel_id,
2906            invitee_id,
2907            session.user_id(),
2908            request.role().into(),
2909        )
2910        .await?;
2911
2912    let update = proto::UpdateChannels {
2913        channel_invitations: vec![channel.to_proto()],
2914        ..Default::default()
2915    };
2916
2917    let connection_pool = session.connection_pool().await;
2918    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2919        session.peer.send(connection_id, update.clone())?;
2920    }
2921
2922    send_notifications(&connection_pool, &session.peer, notifications);
2923
2924    response.send(proto::Ack {})?;
2925    Ok(())
2926}
2927
2928/// remove someone from a channel
2929async fn remove_channel_member(
2930    request: proto::RemoveChannelMember,
2931    response: Response<proto::RemoveChannelMember>,
2932    session: MessageContext,
2933) -> Result<()> {
2934    let db = session.db().await;
2935    let channel_id = ChannelId::from_proto(request.channel_id);
2936    let member_id = UserId::from_proto(request.user_id);
2937
2938    let RemoveChannelMemberResult {
2939        membership_update,
2940        notification_id,
2941    } = db
2942        .remove_channel_member(channel_id, member_id, session.user_id())
2943        .await?;
2944
2945    let mut connection_pool = session.connection_pool().await;
2946    notify_membership_updated(
2947        &mut connection_pool,
2948        membership_update,
2949        member_id,
2950        &session.peer,
2951    );
2952    for connection_id in connection_pool.user_connection_ids(member_id) {
2953        if let Some(notification_id) = notification_id {
2954            session
2955                .peer
2956                .send(
2957                    connection_id,
2958                    proto::DeleteNotification {
2959                        notification_id: notification_id.to_proto(),
2960                    },
2961                )
2962                .trace_err();
2963        }
2964    }
2965
2966    response.send(proto::Ack {})?;
2967    Ok(())
2968}
2969
2970/// Toggle the channel between public and private.
2971/// Care is taken to maintain the invariant that public channels only descend from public channels,
2972/// (though members-only channels can appear at any point in the hierarchy).
2973async fn set_channel_visibility(
2974    request: proto::SetChannelVisibility,
2975    response: Response<proto::SetChannelVisibility>,
2976    session: MessageContext,
2977) -> Result<()> {
2978    let db = session.db().await;
2979    let channel_id = ChannelId::from_proto(request.channel_id);
2980    let visibility = request.visibility().into();
2981
2982    let channel_model = db
2983        .set_channel_visibility(channel_id, visibility, session.user_id())
2984        .await?;
2985    let root_id = channel_model.root_id();
2986    let channel = Channel::from_model(channel_model);
2987
2988    let mut connection_pool = session.connection_pool().await;
2989    for (user_id, role) in connection_pool
2990        .channel_user_ids(root_id)
2991        .collect::<Vec<_>>()
2992        .into_iter()
2993    {
2994        let update = if role.can_see_channel(channel.visibility) {
2995            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2996            proto::UpdateChannels {
2997                channels: vec![channel.to_proto()],
2998                ..Default::default()
2999            }
3000        } else {
3001            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3002            proto::UpdateChannels {
3003                delete_channels: vec![channel.id.to_proto()],
3004                ..Default::default()
3005            }
3006        };
3007
3008        for connection_id in connection_pool.user_connection_ids(user_id) {
3009            session.peer.send(connection_id, update.clone())?;
3010        }
3011    }
3012
3013    response.send(proto::Ack {})?;
3014    Ok(())
3015}
3016
3017/// Alter the role for a user in the channel.
3018async fn set_channel_member_role(
3019    request: proto::SetChannelMemberRole,
3020    response: Response<proto::SetChannelMemberRole>,
3021    session: MessageContext,
3022) -> Result<()> {
3023    let db = session.db().await;
3024    let channel_id = ChannelId::from_proto(request.channel_id);
3025    let member_id = UserId::from_proto(request.user_id);
3026    let result = db
3027        .set_channel_member_role(
3028            channel_id,
3029            session.user_id(),
3030            member_id,
3031            request.role().into(),
3032        )
3033        .await?;
3034
3035    match result {
3036        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3037            let mut connection_pool = session.connection_pool().await;
3038            notify_membership_updated(
3039                &mut connection_pool,
3040                membership_update,
3041                member_id,
3042                &session.peer,
3043            )
3044        }
3045        db::SetMemberRoleResult::InviteUpdated(channel) => {
3046            let update = proto::UpdateChannels {
3047                channel_invitations: vec![channel.to_proto()],
3048                ..Default::default()
3049            };
3050
3051            for connection_id in session
3052                .connection_pool()
3053                .await
3054                .user_connection_ids(member_id)
3055            {
3056                session.peer.send(connection_id, update.clone())?;
3057            }
3058        }
3059    }
3060
3061    response.send(proto::Ack {})?;
3062    Ok(())
3063}
3064
3065/// Change the name of a channel
3066async fn rename_channel(
3067    request: proto::RenameChannel,
3068    response: Response<proto::RenameChannel>,
3069    session: MessageContext,
3070) -> Result<()> {
3071    let db = session.db().await;
3072    let channel_id = ChannelId::from_proto(request.channel_id);
3073    let channel_model = db
3074        .rename_channel(channel_id, session.user_id(), &request.name)
3075        .await?;
3076    let root_id = channel_model.root_id();
3077    let channel = Channel::from_model(channel_model);
3078
3079    response.send(proto::RenameChannelResponse {
3080        channel: Some(channel.to_proto()),
3081    })?;
3082
3083    let connection_pool = session.connection_pool().await;
3084    let update = proto::UpdateChannels {
3085        channels: vec![channel.to_proto()],
3086        ..Default::default()
3087    };
3088    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3089        if role.can_see_channel(channel.visibility) {
3090            session.peer.send(connection_id, update.clone())?;
3091        }
3092    }
3093
3094    Ok(())
3095}
3096
3097/// Move a channel to a new parent.
3098async fn move_channel(
3099    request: proto::MoveChannel,
3100    response: Response<proto::MoveChannel>,
3101    session: MessageContext,
3102) -> Result<()> {
3103    let channel_id = ChannelId::from_proto(request.channel_id);
3104    let to = ChannelId::from_proto(request.to);
3105
3106    let (root_id, channels) = session
3107        .db()
3108        .await
3109        .move_channel(channel_id, to, session.user_id())
3110        .await?;
3111
3112    let connection_pool = session.connection_pool().await;
3113    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3114        let channels = channels
3115            .iter()
3116            .filter_map(|channel| {
3117                if role.can_see_channel(channel.visibility) {
3118                    Some(channel.to_proto())
3119                } else {
3120                    None
3121                }
3122            })
3123            .collect::<Vec<_>>();
3124        if channels.is_empty() {
3125            continue;
3126        }
3127
3128        let update = proto::UpdateChannels {
3129            channels,
3130            ..Default::default()
3131        };
3132
3133        session.peer.send(connection_id, update.clone())?;
3134    }
3135
3136    response.send(Ack {})?;
3137    Ok(())
3138}
3139
3140async fn reorder_channel(
3141    request: proto::ReorderChannel,
3142    response: Response<proto::ReorderChannel>,
3143    session: MessageContext,
3144) -> Result<()> {
3145    let channel_id = ChannelId::from_proto(request.channel_id);
3146    let direction = request.direction();
3147
3148    let updated_channels = session
3149        .db()
3150        .await
3151        .reorder_channel(channel_id, direction, session.user_id())
3152        .await?;
3153
3154    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3155        let connection_pool = session.connection_pool().await;
3156        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3157            let channels = updated_channels
3158                .iter()
3159                .filter_map(|channel| {
3160                    if role.can_see_channel(channel.visibility) {
3161                        Some(channel.to_proto())
3162                    } else {
3163                        None
3164                    }
3165                })
3166                .collect::<Vec<_>>();
3167
3168            if channels.is_empty() {
3169                continue;
3170            }
3171
3172            let update = proto::UpdateChannels {
3173                channels,
3174                ..Default::default()
3175            };
3176
3177            session.peer.send(connection_id, update.clone())?;
3178        }
3179    }
3180
3181    response.send(Ack {})?;
3182    Ok(())
3183}
3184
3185/// Get the list of channel members
3186async fn get_channel_members(
3187    request: proto::GetChannelMembers,
3188    response: Response<proto::GetChannelMembers>,
3189    session: MessageContext,
3190) -> Result<()> {
3191    let db = session.db().await;
3192    let channel_id = ChannelId::from_proto(request.channel_id);
3193    let limit = if request.limit == 0 {
3194        u16::MAX as u64
3195    } else {
3196        request.limit
3197    };
3198    let (members, users) = db
3199        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3200        .await?;
3201    response.send(proto::GetChannelMembersResponse { members, users })?;
3202    Ok(())
3203}
3204
3205/// Accept or decline a channel invitation.
3206async fn respond_to_channel_invite(
3207    request: proto::RespondToChannelInvite,
3208    response: Response<proto::RespondToChannelInvite>,
3209    session: MessageContext,
3210) -> Result<()> {
3211    let db = session.db().await;
3212    let channel_id = ChannelId::from_proto(request.channel_id);
3213    let RespondToChannelInvite {
3214        membership_update,
3215        notifications,
3216    } = db
3217        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3218        .await?;
3219
3220    let mut connection_pool = session.connection_pool().await;
3221    if let Some(membership_update) = membership_update {
3222        notify_membership_updated(
3223            &mut connection_pool,
3224            membership_update,
3225            session.user_id(),
3226            &session.peer,
3227        );
3228    } else {
3229        let update = proto::UpdateChannels {
3230            remove_channel_invitations: vec![channel_id.to_proto()],
3231            ..Default::default()
3232        };
3233
3234        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3235            session.peer.send(connection_id, update.clone())?;
3236        }
3237    };
3238
3239    send_notifications(&connection_pool, &session.peer, notifications);
3240
3241    response.send(proto::Ack {})?;
3242
3243    Ok(())
3244}
3245
3246/// Join the channels' room
3247async fn join_channel(
3248    request: proto::JoinChannel,
3249    response: Response<proto::JoinChannel>,
3250    session: MessageContext,
3251) -> Result<()> {
3252    let channel_id = ChannelId::from_proto(request.channel_id);
3253    join_channel_internal(channel_id, Box::new(response), session).await
3254}
3255
3256trait JoinChannelInternalResponse {
3257    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3258}
3259impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3260    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3261        Response::<proto::JoinChannel>::send(self, result)
3262    }
3263}
3264impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3265    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3266        Response::<proto::JoinRoom>::send(self, result)
3267    }
3268}
3269
/// Shared implementation of the `JoinChannel` and `JoinRoom` handlers: joins the current
/// user to the room of `channel_id` and replies on `response`.
///
/// In order, this:
/// 1. removes any stale room connection the database still records for this user;
/// 2. joins the channel in the database, obtaining the room, an optional membership
///    change, and the user's role in the channel;
/// 3. when a LiveKit client is configured, mints a LiveKit token for the room;
/// 4. sends the `JoinRoomResponse`, applies any membership change to the connection pool,
///    and broadcasts the updated room to its participants;
/// 5. broadcasts the channel's updated participant list and pushes this user's updated
///    contact status to their contacts.
///
/// Returns an error if any database call fails, or if the joined room carries no channel
/// (in which case the response has already been sent).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard first: `leave_room_for_session` acquires it itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token with `can_publish: false`; every other role gets a room
        // token with `can_publish: true`. If minting a token fails, `trace_err` logs it and
        // `?` makes the closure yield `None`, so the response carries no LiveKit info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Reflect any change to the user's channel memberships in the pool's subscriptions
        // and notify the user's connections.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block releases both the database and connection-pool guards.
        joined_room
    };

    // The response was already sent above, so a missing channel only surfaces as an error
    // from this handler.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3363
3364/// Start editing the channel notes
3365async fn join_channel_buffer(
3366    request: proto::JoinChannelBuffer,
3367    response: Response<proto::JoinChannelBuffer>,
3368    session: MessageContext,
3369) -> Result<()> {
3370    let db = session.db().await;
3371    let channel_id = ChannelId::from_proto(request.channel_id);
3372
3373    let open_response = db
3374        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3375        .await?;
3376
3377    let collaborators = open_response.collaborators.clone();
3378    response.send(open_response)?;
3379
3380    let update = UpdateChannelBufferCollaborators {
3381        channel_id: channel_id.to_proto(),
3382        collaborators: collaborators.clone(),
3383    };
3384    channel_buffer_updated(
3385        session.connection_id,
3386        collaborators
3387            .iter()
3388            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3389        &update,
3390        &session.peer,
3391    );
3392
3393    Ok(())
3394}
3395
3396/// Edit the channel notes
3397async fn update_channel_buffer(
3398    request: proto::UpdateChannelBuffer,
3399    session: MessageContext,
3400) -> Result<()> {
3401    let db = session.db().await;
3402    let channel_id = ChannelId::from_proto(request.channel_id);
3403
3404    let (collaborators, epoch, version) = db
3405        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3406        .await?;
3407
3408    channel_buffer_updated(
3409        session.connection_id,
3410        collaborators.clone(),
3411        &proto::UpdateChannelBuffer {
3412            channel_id: channel_id.to_proto(),
3413            operations: request.operations,
3414        },
3415        &session.peer,
3416    );
3417
3418    let pool = &*session.connection_pool().await;
3419
3420    let non_collaborators =
3421        pool.channel_connection_ids(channel_id)
3422            .filter_map(|(connection_id, _)| {
3423                if collaborators.contains(&connection_id) {
3424                    None
3425                } else {
3426                    Some(connection_id)
3427                }
3428            });
3429
3430    broadcast(None, non_collaborators, |peer_id| {
3431        session.peer.send(
3432            peer_id,
3433            proto::UpdateChannels {
3434                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3435                    channel_id: channel_id.to_proto(),
3436                    epoch: epoch as u64,
3437                    version: version.clone(),
3438                }],
3439                ..Default::default()
3440            },
3441        )
3442    });
3443
3444    Ok(())
3445}
3446
3447/// Rejoin the channel notes after a connection blip
3448async fn rejoin_channel_buffers(
3449    request: proto::RejoinChannelBuffers,
3450    response: Response<proto::RejoinChannelBuffers>,
3451    session: MessageContext,
3452) -> Result<()> {
3453    let db = session.db().await;
3454    let buffers = db
3455        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3456        .await?;
3457
3458    for rejoined_buffer in &buffers {
3459        let collaborators_to_notify = rejoined_buffer
3460            .buffer
3461            .collaborators
3462            .iter()
3463            .filter_map(|c| Some(c.peer_id?.into()));
3464        channel_buffer_updated(
3465            session.connection_id,
3466            collaborators_to_notify,
3467            &proto::UpdateChannelBufferCollaborators {
3468                channel_id: rejoined_buffer.buffer.channel_id,
3469                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3470            },
3471            &session.peer,
3472        );
3473    }
3474
3475    response.send(proto::RejoinChannelBuffersResponse {
3476        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3477    })?;
3478
3479    Ok(())
3480}
3481
3482/// Stop editing the channel notes
3483async fn leave_channel_buffer(
3484    request: proto::LeaveChannelBuffer,
3485    response: Response<proto::LeaveChannelBuffer>,
3486    session: MessageContext,
3487) -> Result<()> {
3488    let db = session.db().await;
3489    let channel_id = ChannelId::from_proto(request.channel_id);
3490
3491    let left_buffer = db
3492        .leave_channel_buffer(channel_id, session.connection_id)
3493        .await?;
3494
3495    response.send(Ack {})?;
3496
3497    channel_buffer_updated(
3498        session.connection_id,
3499        left_buffer.connections,
3500        &proto::UpdateChannelBufferCollaborators {
3501            channel_id: channel_id.to_proto(),
3502            collaborators: left_buffer.collaborators,
3503        },
3504        &session.peer,
3505    );
3506
3507    Ok(())
3508}
3509
3510fn channel_buffer_updated<T: EnvelopedMessage>(
3511    sender_id: ConnectionId,
3512    collaborators: impl IntoIterator<Item = ConnectionId>,
3513    message: &T,
3514    peer: &Peer,
3515) {
3516    broadcast(Some(sender_id), collaborators, |peer_id| {
3517        peer.send(peer_id, message.clone())
3518    });
3519}
3520
3521fn send_notifications(
3522    connection_pool: &ConnectionPool,
3523    peer: &Peer,
3524    notifications: db::NotificationBatch,
3525) {
3526    for (user_id, notification) in notifications {
3527        for connection_id in connection_pool.user_connection_ids(user_id) {
3528            if let Err(error) = peer.send(
3529                connection_id,
3530                proto::AddNotification {
3531                    notification: Some(notification.clone()),
3532                },
3533            ) {
3534                tracing::error!(
3535                    "failed to send notification to {:?} {}",
3536                    connection_id,
3537                    error
3538                );
3539            }
3540        }
3541    }
3542}
3543
3544/// Send a message to the channel
3545async fn send_channel_message(
3546    _request: proto::SendChannelMessage,
3547    _response: Response<proto::SendChannelMessage>,
3548    _session: MessageContext,
3549) -> Result<()> {
3550    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3551}
3552
3553/// Delete a channel message
3554async fn remove_channel_message(
3555    _request: proto::RemoveChannelMessage,
3556    _response: Response<proto::RemoveChannelMessage>,
3557    _session: MessageContext,
3558) -> Result<()> {
3559    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3560}
3561
3562async fn update_channel_message(
3563    _request: proto::UpdateChannelMessage,
3564    _response: Response<proto::UpdateChannelMessage>,
3565    _session: MessageContext,
3566) -> Result<()> {
3567    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3568}
3569
3570/// Mark a channel message as read
3571async fn acknowledge_channel_message(
3572    _request: proto::AckChannelMessage,
3573    _session: MessageContext,
3574) -> Result<()> {
3575    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3576}
3577
3578/// Mark a buffer version as synced
3579async fn acknowledge_buffer_version(
3580    request: proto::AckBufferOperation,
3581    session: MessageContext,
3582) -> Result<()> {
3583    let buffer_id = BufferId::from_proto(request.buffer_id);
3584    session
3585        .db()
3586        .await
3587        .observe_buffer_version(
3588            buffer_id,
3589            session.user_id(),
3590            request.epoch as i32,
3591            &request.version,
3592        )
3593        .await?;
3594    Ok(())
3595}
3596
3597/// Start receiving chat updates for a channel
3598async fn join_channel_chat(
3599    _request: proto::JoinChannelChat,
3600    _response: Response<proto::JoinChannelChat>,
3601    _session: MessageContext,
3602) -> Result<()> {
3603    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3604}
3605
3606/// Stop receiving chat updates for a channel
3607async fn leave_channel_chat(
3608    _request: proto::LeaveChannelChat,
3609    _session: MessageContext,
3610) -> Result<()> {
3611    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3612}
3613
3614/// Retrieve the chat history for a channel
3615async fn get_channel_messages(
3616    _request: proto::GetChannelMessages,
3617    _response: Response<proto::GetChannelMessages>,
3618    _session: MessageContext,
3619) -> Result<()> {
3620    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3621}
3622
3623/// Retrieve specific chat messages
3624async fn get_channel_messages_by_id(
3625    _request: proto::GetChannelMessagesById,
3626    _response: Response<proto::GetChannelMessagesById>,
3627    _session: MessageContext,
3628) -> Result<()> {
3629    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3630}
3631
3632/// Retrieve the current users notifications
3633async fn get_notifications(
3634    request: proto::GetNotifications,
3635    response: Response<proto::GetNotifications>,
3636    session: MessageContext,
3637) -> Result<()> {
3638    let notifications = session
3639        .db()
3640        .await
3641        .get_notifications(
3642            session.user_id(),
3643            NOTIFICATION_COUNT_PER_PAGE,
3644            request.before_id.map(db::NotificationId::from_proto),
3645        )
3646        .await?;
3647    response.send(proto::GetNotificationsResponse {
3648        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3649        notifications,
3650    })?;
3651    Ok(())
3652}
3653
3654/// Mark notifications as read
3655async fn mark_notification_as_read(
3656    request: proto::MarkNotificationRead,
3657    response: Response<proto::MarkNotificationRead>,
3658    session: MessageContext,
3659) -> Result<()> {
3660    let database = &session.db().await;
3661    let notifications = database
3662        .mark_notification_as_read_by_id(
3663            session.user_id(),
3664            NotificationId::from_proto(request.notification_id),
3665        )
3666        .await?;
3667    send_notifications(
3668        &*session.connection_pool().await,
3669        &session.peer,
3670        notifications,
3671    );
3672    response.send(proto::Ack {})?;
3673    Ok(())
3674}
3675
3676fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3677    let message = match message {
3678        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3679        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3680        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3681        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3682        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3683            code: frame.code.into(),
3684            reason: frame.reason.as_str().to_owned().into(),
3685        })),
3686        // We should never receive a frame while reading the message, according
3687        // to the `tungstenite` maintainers:
3688        //
3689        // > It cannot occur when you read messages from the WebSocket, but it
3690        // > can be used when you want to send the raw frames (e.g. you want to
3691        // > send the frames to the WebSocket without composing the full message first).
3692        // >
3693        // > — https://github.com/snapview/tungstenite-rs/issues/268
3694        TungsteniteMessage::Frame(_) => {
3695            bail!("received an unexpected frame while reading the message")
3696        }
3697    };
3698
3699    Ok(message)
3700}
3701
3702fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3703    match message {
3704        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3705        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3706        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3707        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3708        AxumMessage::Close(frame) => {
3709            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3710                code: frame.code.into(),
3711                reason: frame.reason.as_ref().into(),
3712            }))
3713        }
3714    }
3715}
3716
3717fn notify_membership_updated(
3718    connection_pool: &mut ConnectionPool,
3719    result: MembershipUpdated,
3720    user_id: UserId,
3721    peer: &Peer,
3722) {
3723    for membership in &result.new_channels.channel_memberships {
3724        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3725    }
3726    for channel_id in &result.removed_channels {
3727        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3728    }
3729
3730    let user_channels_update = proto::UpdateUserChannels {
3731        channel_memberships: result
3732            .new_channels
3733            .channel_memberships
3734            .iter()
3735            .map(|cm| proto::ChannelMembership {
3736                channel_id: cm.channel_id.to_proto(),
3737                role: cm.role.into(),
3738            })
3739            .collect(),
3740        ..Default::default()
3741    };
3742
3743    let mut update = build_channels_update(result.new_channels);
3744    update.delete_channels = result
3745        .removed_channels
3746        .into_iter()
3747        .map(|id| id.to_proto())
3748        .collect();
3749    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3750
3751    for connection_id in connection_pool.user_connection_ids(user_id) {
3752        peer.send(connection_id, user_channels_update.clone())
3753            .trace_err();
3754        peer.send(connection_id, update.clone()).trace_err();
3755    }
3756}
3757
3758fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3759    proto::UpdateUserChannels {
3760        channel_memberships: channels
3761            .channel_memberships
3762            .iter()
3763            .map(|m| proto::ChannelMembership {
3764                channel_id: m.channel_id.to_proto(),
3765                role: m.role.into(),
3766            })
3767            .collect(),
3768        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3769    }
3770}
3771
3772fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3773    let mut update = proto::UpdateChannels::default();
3774
3775    for channel in channels.channels {
3776        update.channels.push(channel.to_proto());
3777    }
3778
3779    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3780
3781    for (channel_id, participants) in channels.channel_participants {
3782        update
3783            .channel_participants
3784            .push(proto::ChannelParticipants {
3785                channel_id: channel_id.to_proto(),
3786                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3787            });
3788    }
3789
3790    for channel in channels.invited_channels {
3791        update.channel_invitations.push(channel.to_proto());
3792    }
3793
3794    update
3795}
3796
3797fn build_initial_contacts_update(
3798    contacts: Vec<db::Contact>,
3799    pool: &ConnectionPool,
3800) -> proto::UpdateContacts {
3801    let mut update = proto::UpdateContacts::default();
3802
3803    for contact in contacts {
3804        match contact {
3805            db::Contact::Accepted { user_id, busy } => {
3806                update.contacts.push(contact_for_user(user_id, busy, pool));
3807            }
3808            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3809            db::Contact::Incoming { user_id } => {
3810                update
3811                    .incoming_requests
3812                    .push(proto::IncomingContactRequest {
3813                        requester_id: user_id.to_proto(),
3814                    })
3815            }
3816        }
3817    }
3818
3819    update
3820}
3821
3822fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3823    proto::Contact {
3824        user_id: user_id.to_proto(),
3825        online: pool.is_user_online(user_id),
3826        busy,
3827    }
3828}
3829
3830fn room_updated(room: &proto::Room, peer: &Peer) {
3831    broadcast(
3832        None,
3833        room.participants
3834            .iter()
3835            .filter_map(|participant| Some(participant.peer_id?.into())),
3836        |peer_id| {
3837            peer.send(
3838                peer_id,
3839                proto::RoomUpdated {
3840                    room: Some(room.clone()),
3841                },
3842            )
3843        },
3844    );
3845}
3846
3847fn channel_updated(
3848    channel: &db::channel::Model,
3849    room: &proto::Room,
3850    peer: &Peer,
3851    pool: &ConnectionPool,
3852) {
3853    let participants = room
3854        .participants
3855        .iter()
3856        .map(|p| p.user_id)
3857        .collect::<Vec<_>>();
3858
3859    broadcast(
3860        None,
3861        pool.channel_connection_ids(channel.root_id())
3862            .filter_map(|(channel_id, role)| {
3863                role.can_see_channel(channel.visibility)
3864                    .then_some(channel_id)
3865            }),
3866        |peer_id| {
3867            peer.send(
3868                peer_id,
3869                proto::UpdateChannels {
3870                    channel_participants: vec![proto::ChannelParticipants {
3871                        channel_id: channel.id.to_proto(),
3872                        participant_user_ids: participants.clone(),
3873                    }],
3874                    ..Default::default()
3875                },
3876            )
3877        },
3878    );
3879}
3880
3881async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3882    let db = session.db().await;
3883
3884    let contacts = db.get_contacts(user_id).await?;
3885    let busy = db.is_user_busy(user_id).await?;
3886
3887    let pool = session.connection_pool().await;
3888    let updated_contact = contact_for_user(user_id, busy, &pool);
3889    for contact in contacts {
3890        if let db::Contact::Accepted {
3891            user_id: contact_user_id,
3892            ..
3893        } = contact
3894        {
3895            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3896                session
3897                    .peer
3898                    .send(
3899                        contact_conn_id,
3900                        proto::UpdateContacts {
3901                            contacts: vec![updated_contact.clone()],
3902                            remove_contacts: Default::default(),
3903                            incoming_requests: Default::default(),
3904                            remove_incoming_requests: Default::default(),
3905                            outgoing_requests: Default::default(),
3906                            remove_outgoing_requests: Default::default(),
3907                        },
3908                    )
3909                    .trace_err();
3910            }
3911        }
3912    }
3913    Ok(())
3914}
3915
3916async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
3917    let mut contacts_to_update = HashSet::default();
3918
3919    let room_id;
3920    let canceled_calls_to_user_ids;
3921    let livekit_room;
3922    let delete_livekit_room;
3923    let room;
3924    let channel;
3925
3926    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
3927        contacts_to_update.insert(session.user_id());
3928
3929        for project in left_room.left_projects.values() {
3930            project_left(project, session);
3931        }
3932
3933        room_id = RoomId::from_proto(left_room.room.id);
3934        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3935        livekit_room = mem::take(&mut left_room.room.livekit_room);
3936        delete_livekit_room = left_room.deleted;
3937        room = mem::take(&mut left_room.room);
3938        channel = mem::take(&mut left_room.channel);
3939
3940        room_updated(&room, &session.peer);
3941    } else {
3942        return Ok(());
3943    }
3944
3945    if let Some(channel) = channel {
3946        channel_updated(
3947            &channel,
3948            &room,
3949            &session.peer,
3950            &*session.connection_pool().await,
3951        );
3952    }
3953
3954    {
3955        let pool = session.connection_pool().await;
3956        for canceled_user_id in canceled_calls_to_user_ids {
3957            for connection_id in pool.user_connection_ids(canceled_user_id) {
3958                session
3959                    .peer
3960                    .send(
3961                        connection_id,
3962                        proto::CallCanceled {
3963                            room_id: room_id.to_proto(),
3964                        },
3965                    )
3966                    .trace_err();
3967            }
3968            contacts_to_update.insert(canceled_user_id);
3969        }
3970    }
3971
3972    for contact_user_id in contacts_to_update {
3973        update_user_contacts(contact_user_id, session).await?;
3974    }
3975
3976    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
3977        live_kit
3978            .remove_participant(livekit_room.clone(), session.user_id().to_string())
3979            .await
3980            .trace_err();
3981
3982        if delete_livekit_room {
3983            live_kit.delete_room(livekit_room).await.trace_err();
3984        }
3985    }
3986
3987    Ok(())
3988}
3989
3990async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3991    let left_channel_buffers = session
3992        .db()
3993        .await
3994        .leave_channel_buffers(session.connection_id)
3995        .await?;
3996
3997    for left_buffer in left_channel_buffers {
3998        channel_buffer_updated(
3999            session.connection_id,
4000            left_buffer.connections,
4001            &proto::UpdateChannelBufferCollaborators {
4002                channel_id: left_buffer.channel_id.to_proto(),
4003                collaborators: left_buffer.collaborators,
4004            },
4005            &session.peer,
4006        );
4007    }
4008
4009    Ok(())
4010}
4011
4012fn project_left(project: &db::LeftProject, session: &Session) {
4013    for connection_id in &project.connection_ids {
4014        if project.should_unshare {
4015            session
4016                .peer
4017                .send(
4018                    *connection_id,
4019                    proto::UnshareProject {
4020                        project_id: project.id.to_proto(),
4021                    },
4022                )
4023                .trace_err();
4024        } else {
4025            session
4026                .peer
4027                .send(
4028                    *connection_id,
4029                    proto::RemoveProjectCollaborator {
4030                        project_id: project.id.to_proto(),
4031                        peer_id: Some(session.connection_id.into()),
4032                    },
4033                )
4034                .trace_err();
4035        }
4036    }
4037}
4038
4039async fn share_agent_thread(
4040    request: proto::ShareAgentThread,
4041    response: Response<proto::ShareAgentThread>,
4042    session: MessageContext,
4043) -> Result<()> {
4044    let user_id = session.user_id();
4045
4046    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4047        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4048
4049    session
4050        .db()
4051        .await
4052        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4053        .await?;
4054
4055    response.send(proto::Ack {})?;
4056
4057    Ok(())
4058}
4059
4060async fn get_shared_agent_thread(
4061    request: proto::GetSharedAgentThread,
4062    response: Response<proto::GetSharedAgentThread>,
4063    session: MessageContext,
4064) -> Result<()> {
4065    let share_id = SharedThreadId::from_proto(request.session_id)
4066        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4067
4068    let result = session.db().await.get_shared_thread(share_id).await?;
4069
4070    match result {
4071        Some((thread, username)) => {
4072            response.send(proto::GetSharedAgentThreadResponse {
4073                title: thread.title,
4074                thread_data: thread.data,
4075                sharer_username: username,
4076                created_at: thread.created_at.and_utc().to_rfc3339(),
4077            })?;
4078        }
4079        None => {
4080            return Err(anyhow!("Shared thread not found").into());
4081        }
4082    }
4083
4084    Ok(())
4085}
4086
4087pub trait ResultExt {
4088    type Ok;
4089
4090    fn trace_err(self) -> Option<Self::Ok>;
4091}
4092
4093impl<T, E> ResultExt for Result<T, E>
4094where
4095    E: std::fmt::Debug,
4096{
4097    type Ok = T;
4098
4099    #[track_caller]
4100    fn trace_err(self) -> Option<T> {
4101        match self {
4102            Ok(value) => Some(value),
4103            Err(error) => {
4104                tracing::error!("{:?}", error);
4105                None
4106            }
4107        }
4108    }
4109}