rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::split_repository_update;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semantic_version::SemanticVersion;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use tokio::sync::{Semaphore, watch};
  69use tower::ServiceBuilder;
  70use tracing::{
  71    Instrument,
  72    field::{self},
  73    info_span, instrument,
  74};
  75
  76pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  77
  78// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  79pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  80
  81const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  82const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  83
  84static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  85
  86const TOTAL_DURATION_MS: &str = "total_duration_ms";
  87const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  88const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  89const HOST_WAITING_MS: &str = "host_waiting_ms";
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  93
  94pub struct ConnectionGuard;
  95
  96impl ConnectionGuard {
  97    pub fn try_acquire() -> Result<Self, ()> {
  98        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  99        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 100            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 101            tracing::error!(
 102                "too many concurrent connections: {}",
 103                current_connections + 1
 104            );
 105            return Err(());
 106        }
 107        Ok(ConnectionGuard)
 108    }
 109}
 110
 111impl Drop for ConnectionGuard {
 112    fn drop(&mut self) {
 113        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114    }
 115}
 116
 117struct Response<R> {
 118    peer: Arc<Peer>,
 119    receipt: Receipt<R>,
 120    responded: Arc<AtomicBool>,
 121}
 122
 123impl<R: RequestMessage> Response<R> {
 124    fn send(self, payload: R::Response) -> Result<()> {
 125        self.responded.store(true, SeqCst);
 126        self.peer.respond(self.receipt, payload)?;
 127        Ok(())
 128    }
 129}
 130
 131#[derive(Clone, Debug)]
 132pub enum Principal {
 133    User(User),
 134    Impersonated { user: User, admin: User },
 135}
 136
 137impl Principal {
 138    fn update_span(&self, span: &tracing::Span) {
 139        match &self {
 140            Principal::User(user) => {
 141                span.record("user_id", user.id.0);
 142                span.record("login", &user.github_login);
 143            }
 144            Principal::Impersonated { user, admin } => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147                span.record("impersonator", &admin.github_login);
 148            }
 149        }
 150    }
 151}
 152
 153#[derive(Clone)]
 154struct MessageContext {
 155    session: Session,
 156    span: tracing::Span,
 157}
 158
 159impl Deref for MessageContext {
 160    type Target = Session;
 161
 162    fn deref(&self) -> &Self::Target {
 163        &self.session
 164    }
 165}
 166
 167impl MessageContext {
 168    pub fn forward_request<T: RequestMessage>(
 169        &self,
 170        receiver_id: ConnectionId,
 171        request: T,
 172    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 173        let request_start_time = Instant::now();
 174        let span = self.span.clone();
 175        tracing::info!("start forwarding request");
 176        self.peer
 177            .forward_request(self.connection_id, receiver_id, request)
 178            .inspect(move |_| {
 179                span.record(
 180                    HOST_WAITING_MS,
 181                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 182                );
 183            })
 184            .inspect_err(|_| tracing::error!("error forwarding request"))
 185            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 186    }
 187}
 188
/// Per-connection state; a clone of it is passed to every message handler (see
/// `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (possibly an admin impersonating a user).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle; access it through `Session::db`, which takes the lock.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared pool of live connections; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // The allow keeps the dead-code lint quiet if nothing reads this field.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Presumably the client's system id (cf. the `SystemIdHeader` import) — confirm.
    // The allow keeps the dead-code lint quiet if nothing reads this field.
    #[allow(unused)]
    system_id: Option<String>,
    // Held but not used directly here; presumably kept to tie the executor to the session.
    _executor: Executor,
}
 205
 206impl Session {
 207    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 208        #[cfg(test)]
 209        tokio::task::yield_now().await;
 210        let guard = self.db.lock().await;
 211        #[cfg(test)]
 212        tokio::task::yield_now().await;
 213        guard
 214    }
 215
 216    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 217        #[cfg(test)]
 218        tokio::task::yield_now().await;
 219        let guard = self.connection_pool.lock();
 220        ConnectionPoolGuard {
 221            guard,
 222            _not_send: PhantomData,
 223        }
 224    }
 225
 226    fn is_staff(&self) -> bool {
 227        match &self.principal {
 228            Principal::User(user) => user.admin,
 229            Principal::Impersonated { .. } => true,
 230        }
 231    }
 232
 233    fn user_id(&self) -> UserId {
 234        match &self.principal {
 235            Principal::User(user) => user.id,
 236            Principal::Impersonated { user, .. } => user.id,
 237        }
 238    }
 239
 240    pub fn email(&self) -> Option<String> {
 241        match &self.principal {
 242            Principal::User(user) => user.email_address.clone(),
 243            Principal::Impersonated { user, .. } => user.email_address.clone(),
 244        }
 245    }
 246}
 247
 248impl Debug for Session {
 249    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 250        let mut result = f.debug_struct("Session");
 251        match &self.principal {
 252            Principal::User(user) => {
 253                result.field("user", &user.github_login);
 254            }
 255            Principal::Impersonated { user, admin } => {
 256                result.field("user", &user.github_login);
 257                result.field("impersonator", &admin.github_login);
 258            }
 259        }
 260        result.field("connection_id", &self.connection_id).finish()
 261    }
 262}
 263
 264struct DbHandle(Arc<Database>);
 265
 266impl Deref for DbHandle {
 267    type Target = Database;
 268
 269    fn deref(&self) -> &Self::Target {
 270        self.0.as_ref()
 271    }
 272}
 273
/// The collaboration RPC server. It owns the peer, the pool of live connections, and the
/// table of message handlers keyed by message type.
pub struct Server {
    /// Id of this server instance; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; reset by `teardown`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Receives `true` from `teardown` and `false` again from the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 282
/// Lock guard over the connection pool that is deliberately `!Send` (via the
/// `PhantomData<Rc<()>>` marker). It therefore cannot be held across an `.await` inside a
/// future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server: the peer and the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized via `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 294
 295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 296where
 297    S: Serializer,
 298    T: Deref<Target = U>,
 299    U: Serialize,
 300{
 301    Serialize::serialize(value.deref(), serializer)
 302}
 303
 304impl Server {
 305    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 306        let mut server = Self {
 307            id: parking_lot::Mutex::new(id),
 308            peer: Peer::new(id.0 as u32),
 309            app_state,
 310            connection_pool: Default::default(),
 311            handlers: Default::default(),
 312            teardown: watch::channel(false).0,
 313        };
 314
 315        server
 316            .add_request_handler(ping)
 317            .add_request_handler(create_room)
 318            .add_request_handler(join_room)
 319            .add_request_handler(rejoin_room)
 320            .add_request_handler(leave_room)
 321            .add_request_handler(set_room_participant_role)
 322            .add_request_handler(call)
 323            .add_request_handler(cancel_call)
 324            .add_message_handler(decline_call)
 325            .add_request_handler(update_participant_location)
 326            .add_request_handler(share_project)
 327            .add_message_handler(unshare_project)
 328            .add_request_handler(join_project)
 329            .add_message_handler(leave_project)
 330            .add_request_handler(update_project)
 331            .add_request_handler(update_worktree)
 332            .add_request_handler(update_repository)
 333            .add_request_handler(remove_repository)
 334            .add_message_handler(start_language_server)
 335            .add_message_handler(update_language_server)
 336            .add_message_handler(update_diagnostic_summary)
 337            .add_message_handler(update_worktree_settings)
 338            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 342            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 344            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 345            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 346            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 348            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 350            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 352            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 354            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 355            .add_request_handler(
 356                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 357            )
 358            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 359            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 360            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 362            .add_request_handler(
 363                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 364            )
 365            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 366            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 367            .add_request_handler(
 368                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 369            )
 370            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 371            .add_request_handler(
 372                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 373            )
 374            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 375            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 376            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 377            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 378            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 379            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 381            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 382            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 383            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 384            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 385            .add_request_handler(
 386                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 387            )
 388            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 389            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 390            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 391            .add_request_handler(lsp_query)
 392            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 393            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 394            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 395            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 396            .add_message_handler(create_buffer_for_peer)
 397            .add_request_handler(update_buffer)
 398            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 399            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 400            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 404            .add_message_handler(
 405                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 406            )
 407            .add_request_handler(get_users)
 408            .add_request_handler(fuzzy_search_users)
 409            .add_request_handler(request_contact)
 410            .add_request_handler(remove_contact)
 411            .add_request_handler(respond_to_contact_request)
 412            .add_message_handler(subscribe_to_channels)
 413            .add_request_handler(create_channel)
 414            .add_request_handler(delete_channel)
 415            .add_request_handler(invite_channel_member)
 416            .add_request_handler(remove_channel_member)
 417            .add_request_handler(set_channel_member_role)
 418            .add_request_handler(set_channel_visibility)
 419            .add_request_handler(rename_channel)
 420            .add_request_handler(join_channel_buffer)
 421            .add_request_handler(leave_channel_buffer)
 422            .add_message_handler(update_channel_buffer)
 423            .add_request_handler(rejoin_channel_buffers)
 424            .add_request_handler(get_channel_members)
 425            .add_request_handler(respond_to_channel_invite)
 426            .add_request_handler(join_channel)
 427            .add_request_handler(join_channel_chat)
 428            .add_message_handler(leave_channel_chat)
 429            .add_request_handler(send_channel_message)
 430            .add_request_handler(remove_channel_message)
 431            .add_request_handler(update_channel_message)
 432            .add_request_handler(get_channel_messages)
 433            .add_request_handler(get_channel_messages_by_id)
 434            .add_request_handler(get_notifications)
 435            .add_request_handler(mark_notification_as_read)
 436            .add_request_handler(move_channel)
 437            .add_request_handler(reorder_channel)
 438            .add_request_handler(follow)
 439            .add_message_handler(unfollow)
 440            .add_message_handler(update_followers)
 441            .add_message_handler(acknowledge_channel_message)
 442            .add_message_handler(acknowledge_buffer_version)
 443            .add_request_handler(get_supermaven_api_key)
 444            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 445            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 446            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 447            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 448            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 449            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 450            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 451            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 452            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 453            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 454            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 455            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 456            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 457            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 458            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 459            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 460            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 461            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 463            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 465            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 466            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 467            .add_message_handler(update_context)
 468            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 469            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 470
 471        Arc::new(server)
 472    }
 473
    /// Schedules best-effort cleanup of resources left behind by previous server instances.
    ///
    /// A detached task (instrumented with the "start server" span) waits `CLEANUP_TIMEOUT`
    /// and then, using this server's `zed_environment` and id:
    /// - deletes stale channel-chat participants;
    /// - for each stale channel buffer, clears stale collaborators and sends the refreshed
    ///   collaborator list to the remaining connections;
    /// - for each stale room:
    ///   - clears stale participants and broadcasts the updated room (and its channel, if any);
    ///   - tells users whose calls were canceled;
    ///   - pushes refreshed contact entries to affected users' accepted contacts;
    ///   - deletes the LiveKit room once no participants remain;
    /// - deletes stale channel-chat participants again, clears old worktree entries, and
    ///   deletes stale servers.
    ///
    /// Every fallible step is passed through `trace_err`, which presumably logs and discards
    /// the error, so a failure in one step does not stop the others. The function itself
    /// only spawns the task and always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created now but only awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): the same cleanup runs again after the room/channel pass
                // below; one of the two calls looks redundant. Confirm before removing.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Send each stale channel buffer's refreshed collaborator list to the
                    // connections that remain on it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled from the room refresh below; they keep these defaults if
                        // the refresh fails, which turns the later steps into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move the values out; `refreshed_room` is not used afterwards.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to every
                        // connection of their accepted contacts. The DB queries are awaited
                        // before the pool lock is taken.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 644
 645    pub fn teardown(&self) {
 646        self.peer.teardown();
 647        self.connection_pool.lock().reset();
 648        let _ = self.teardown.send(true);
 649    }
 650
 651    #[cfg(test)]
 652    pub fn reset(&self, id: ServerId) {
 653        self.teardown();
 654        *self.id.lock() = id;
 655        self.peer.reset(id.0 as u32);
 656        let _ = self.teardown.send(false);
 657    }
 658
 659    #[cfg(test)]
 660    pub fn id(&self) -> ServerId {
 661        *self.id.lock()
 662    }
 663
    /// Registers `handler` as the dispatcher for messages of type `M`.
    ///
    /// The stored closure downcasts the type-erased envelope to
    /// `TypedEnvelope<M>`, invokes `handler` with a [`MessageContext`], and
    /// returns a boxed future that awaits the handler and then records on the
    /// message span:
    /// - `TOTAL_DURATION_MS`: time since the envelope's `received_at`;
    /// - `PROCESSING_DURATION_MS`: time since this closure was invoked;
    /// - `QUEUE_DURATION_MS`: the difference between the two.
    ///
    /// Handler errors are logged, not propagated.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are keyed by `TypeId::of::<M>()`; the unwrap assumes the
                // dispatcher only routes envelopes whose payload type matches that key.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Marks dispatch time; the handler future may run later than this.
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    // Time between receipt and dispatch.
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 707
 708    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 709    where
 710        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 711        Fut: 'static + Send + Future<Output = Result<()>>,
 712        M: EnvelopedMessage,
 713    {
 714        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 715        self
 716    }
 717
    /// Registers a handler for a request message of type `M`.
    ///
    /// The handler receives the payload, a [`Response`] it must use to reply, and
    /// the per-message context. After the handler completes:
    /// - `Ok(())` with a response sent: success.
    /// - `Ok(())` without a response: returns the error
    ///   "handler did not send a response" (logged by the registration wrapper).
    /// - `Err(error)`: an error reply is sent to the requester before the error
    ///   is returned. `Error::Internal` is converted with `to_proto`; any other
    ///   variant becomes `ErrorCode::Internal` carrying the error's display text.
    ///   If sending that error reply fails, the send error is returned instead.
    ///
    /// NOTE(review): in the "no response" case nothing is sent back to the
    /// requester from here; confirm clients time out or otherwise recover.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared so each dispatched message's future can own a handle to it.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Set by `Response` when the handler replies; checked after it returns.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 756
    /// Builds the future that serves one client connection for its whole lifetime.
    ///
    /// The tracing span is created and populated eagerly (address, principal,
    /// user agent, release channel, country code), and the teardown channel is
    /// subscribed to. Everything else happens when the returned future is polled:
    /// 1. If the server is already tearing down, log an error and stop.
    /// 2. Register `connection` with the peer and build the per-connection [`Session`].
    /// 3. Send the initial client update, then release `connection_guard`.
    /// 4. Dispatch incoming messages (at most `MAX_CONCURRENT_HANDLERS` at once)
    ///    until the I/O future finishes, the incoming stream ends, or the teardown
    ///    flag changes.
    /// 5. On I/O completion or stream end (not on teardown), cancel outstanding
    ///    foreground handlers and run `connection_lost`.
    ///
    /// `send_connection_id`, when provided, receives the assigned connection id
    /// right after the hello message is sent.
    ///
    /// NOTE(review): the early `return`s after `add_connection` (HTTP client
    /// creation failure, initial-update failure) skip `connection_lost` and never
    /// poll `handle_io`. If `send_initial_client_update` failed after adding the
    /// connection to the pool, that entry is not removed here. Confirm this
    /// cleanup happens elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly; checked on entry and raced against message handling below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection slot reserved before the websocket upgrade is released here.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of semaphore permits currently held (running handlers, plus the
            // permit held by a pending `next_message` future).
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is taken before reading, so no new message is read while
                // MAX_CONCURRENT_HANDLERS handlers are still running.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins, then I/O completion, then finishing handlers,
                // then reading a new message.
                futures::select_biased! {
                    // Teardown returns without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Entered only while the handler future is constructed; the
                            // future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight;
            // detached background handlers keep running.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 948
    /// Sends the messages a freshly connected client expects, in order.
    ///
    /// 1. A `Hello` carrying the client's peer id; `send_connection_id`, if
    ///    present, is then signalled with `connection_id` (a send failure is ignored).
    /// 2. For a user (or impersonated user) principal:
    ///    - on the user's first ever connection, `ShowContacts` is sent and the
    ///      user is marked as connected once in the database;
    ///    - the connection is added to the pool and the initial contacts update
    ///      is sent while the pool is locked;
    ///    - channels are auto-subscribed when `should_auto_subscribe_to_channels`
    ///      allows it for this `zed_version`;
    ///    - a pending incoming call, if any, is forwarded;
    ///    - `update_user_contacts` is run for the user (presumably notifying
    ///      contacts of the new online state; confirm at its definition).
    ///
    /// # Errors
    ///
    /// Returns the first send or database error. Later steps are then skipped,
    /// and the connection may already have been added to the pool.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Scoped so the pool lock is released before the next `.await`; the
                // contacts update is built from the pool that already has this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1004
1005    pub async fn invite_code_redeemed(
1006        self: &Arc<Self>,
1007        inviter_id: UserId,
1008        invitee_id: UserId,
1009    ) -> Result<()> {
1010        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1011            && let Some(code) = &user.invite_code
1012        {
1013            let pool = self.connection_pool.lock();
1014            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1015            for connection_id in pool.user_connection_ids(inviter_id) {
1016                self.peer.send(
1017                    connection_id,
1018                    proto::UpdateContacts {
1019                        contacts: vec![invitee_contact.clone()],
1020                        ..Default::default()
1021                    },
1022                )?;
1023                self.peer.send(
1024                    connection_id,
1025                    proto::UpdateInviteInfo {
1026                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1027                        count: user.invite_count as u32,
1028                    },
1029                )?;
1030            }
1031        }
1032        Ok(())
1033    }
1034
1035    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1036        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1037            && let Some(invite_code) = &user.invite_code
1038        {
1039            let pool = self.connection_pool.lock();
1040            for connection_id in pool.user_connection_ids(user_id) {
1041                self.peer.send(
1042                    connection_id,
1043                    proto::UpdateInviteInfo {
1044                        url: format!(
1045                            "{}{}",
1046                            self.app_state.config.invite_link_prefix, invite_code
1047                        ),
1048                        count: user.invite_count as u32,
1049                    },
1050                )?;
1051            }
1052        }
1053        Ok(())
1054    }
1055
1056    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1057        ServerSnapshot {
1058            connection_pool: ConnectionPoolGuard {
1059                guard: self.connection_pool.lock(),
1060                _not_send: PhantomData,
1061            },
1062            peer: &self.peer,
1063        }
1064    }
1065}
1066
1067impl Deref for ConnectionPoolGuard<'_> {
1068    type Target = ConnectionPool;
1069
1070    fn deref(&self) -> &Self::Target {
1071        &self.guard
1072    }
1073}
1074
1075impl DerefMut for ConnectionPoolGuard<'_> {
1076    fn deref_mut(&mut self) -> &mut Self::Target {
1077        &mut self.guard
1078    }
1079}
1080
1081impl Drop for ConnectionPoolGuard<'_> {
1082    fn drop(&mut self) {
1083        #[cfg(test)]
1084        self.check_invariants();
1085    }
1086}
1087
1088fn broadcast<F>(
1089    sender_id: Option<ConnectionId>,
1090    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1091    mut f: F,
1092) where
1093    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1094{
1095    for receiver_id in receiver_ids {
1096        if Some(receiver_id) != sender_id
1097            && let Err(error) = f(receiver_id)
1098        {
1099            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1100        }
1101    }
1102}
1103
1104pub struct ProtocolVersion(u32);
1105
1106impl Header for ProtocolVersion {
1107    fn name() -> &'static HeaderName {
1108        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1109        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1110    }
1111
1112    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1113    where
1114        Self: Sized,
1115        I: Iterator<Item = &'i axum::http::HeaderValue>,
1116    {
1117        let version = values
1118            .next()
1119            .ok_or_else(axum::headers::Error::invalid)?
1120            .to_str()
1121            .map_err(|_| axum::headers::Error::invalid())?
1122            .parse()
1123            .map_err(|_| axum::headers::Error::invalid())?;
1124        Ok(Self(version))
1125    }
1126
1127    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1128        values.extend([self.0.to_string().parse().unwrap()]);
1129    }
1130}
1131
1132pub struct AppVersionHeader(SemanticVersion);
1133impl Header for AppVersionHeader {
1134    fn name() -> &'static HeaderName {
1135        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1136        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1137    }
1138
1139    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1140    where
1141        Self: Sized,
1142        I: Iterator<Item = &'i axum::http::HeaderValue>,
1143    {
1144        let version = values
1145            .next()
1146            .ok_or_else(axum::headers::Error::invalid)?
1147            .to_str()
1148            .map_err(|_| axum::headers::Error::invalid())?
1149            .parse()
1150            .map_err(|_| axum::headers::Error::invalid())?;
1151        Ok(Self(version))
1152    }
1153
1154    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1155        values.extend([self.0.to_string().parse().unwrap()]);
1156    }
1157}
1158
1159#[derive(Debug)]
1160pub struct ReleaseChannelHeader(String);
1161
1162impl Header for ReleaseChannelHeader {
1163    fn name() -> &'static HeaderName {
1164        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1165        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1166    }
1167
1168    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1169    where
1170        Self: Sized,
1171        I: Iterator<Item = &'i axum::http::HeaderValue>,
1172    {
1173        Ok(Self(
1174            values
1175                .next()
1176                .ok_or_else(axum::headers::Error::invalid)?
1177                .to_str()
1178                .map_err(|_| axum::headers::Error::invalid())?
1179                .to_owned(),
1180        ))
1181    }
1182
1183    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1184        values.extend([self.0.parse().unwrap()]);
1185    }
1186}
1187
1188pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1189    Router::new()
1190        .route("/rpc", get(handle_websocket_request))
1191        .layer(
1192            ServiceBuilder::new()
1193                .layer(Extension(server.app_state.clone()))
1194                .layer(middleware::from_fn(auth::validate_header)),
1195        )
1196        .route("/metrics", get(handle_metrics))
1197        .layer(Extension(server))
1198}
1199
1200pub async fn handle_websocket_request(
1201    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1202    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1203    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1204    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1205    Extension(server): Extension<Arc<Server>>,
1206    Extension(principal): Extension<Principal>,
1207    user_agent: Option<TypedHeader<UserAgent>>,
1208    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1209    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1210    ws: WebSocketUpgrade,
1211) -> axum::response::Response {
1212    if protocol_version != rpc::PROTOCOL_VERSION {
1213        return (
1214            StatusCode::UPGRADE_REQUIRED,
1215            "client must be upgraded".to_string(),
1216        )
1217            .into_response();
1218    }
1219
1220    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1221        return (
1222            StatusCode::UPGRADE_REQUIRED,
1223            "no version header found".to_string(),
1224        )
1225            .into_response();
1226    };
1227
1228    let release_channel = release_channel_header.map(|header| header.0.0);
1229
1230    if !version.can_collaborate() {
1231        return (
1232            StatusCode::UPGRADE_REQUIRED,
1233            "client must be upgraded".to_string(),
1234        )
1235            .into_response();
1236    }
1237
1238    let socket_address = socket_address.to_string();
1239
1240    // Acquire connection guard before WebSocket upgrade
1241    let connection_guard = match ConnectionGuard::try_acquire() {
1242        Ok(guard) => guard,
1243        Err(()) => {
1244            return (
1245                StatusCode::SERVICE_UNAVAILABLE,
1246                "Too many concurrent connections",
1247            )
1248                .into_response();
1249        }
1250    };
1251
1252    ws.on_upgrade(move |socket| {
1253        let socket = socket
1254            .map_ok(to_tungstenite_message)
1255            .err_into()
1256            .with(|message| async move { to_axum_message(message) });
1257        let connection = Connection::new(Box::pin(socket));
1258        async move {
1259            server
1260                .handle_connection(
1261                    connection,
1262                    socket_address,
1263                    principal,
1264                    version,
1265                    release_channel,
1266                    user_agent.map(|header| header.to_string()),
1267                    country_code_header.map(|header| header.to_string()),
1268                    system_id_header.map(|header| header.to_string()),
1269                    None,
1270                    Executor::Production,
1271                    Some(connection_guard),
1272                )
1273                .await;
1274        }
1275    })
1276}
1277
1278pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1279    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1280    let connections_metric = CONNECTIONS_METRIC
1281        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1282
1283    let connections = server
1284        .connection_pool
1285        .lock()
1286        .connections()
1287        .filter(|connection| !connection.admin)
1288        .count();
1289    connections_metric.set(connections as _);
1290
1291    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1292    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1293        register_int_gauge!(
1294            "shared_projects",
1295            "number of open projects with one or more guests"
1296        )
1297        .unwrap()
1298    });
1299
1300    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1301    shared_projects_metric.set(shared_projects as _);
1302
1303    let encoder = prometheus::TextEncoder::new();
1304    let metric_families = prometheus::gather();
1305    let encoded_metrics = encoder
1306        .encode_to_string(&metric_families)
1307        .map_err(|err| anyhow!("{err}"))?;
1308    Ok(encoded_metrics)
1309}
1310
1311#[instrument(err, skip(executor))]
1312async fn connection_lost(
1313    session: Session,
1314    mut teardown: watch::Receiver<bool>,
1315    executor: Executor,
1316) -> Result<()> {
1317    session.peer.disconnect(session.connection_id);
1318    session
1319        .connection_pool()
1320        .await
1321        .remove_connection(session.connection_id)?;
1322
1323    session
1324        .db()
1325        .await
1326        .connection_lost(session.connection_id)
1327        .await
1328        .trace_err();
1329
1330    futures::select_biased! {
1331        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1332
1333            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1334            leave_room_for_session(&session, session.connection_id).await.trace_err();
1335            leave_channel_buffers_for_session(&session)
1336                .await
1337                .trace_err();
1338
1339            if !session
1340                .connection_pool()
1341                .await
1342                .is_user_online(session.user_id())
1343            {
1344                let db = session.db().await;
1345                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1346                    room_updated(&room, &session.peer);
1347                }
1348            }
1349
1350            update_user_contacts(session.user_id(), &session).await?;
1351        },
1352        _ = teardown.changed().fuse() => {}
1353    }
1354
1355    Ok(())
1356}
1357
1358/// Acknowledges a ping from a client, used to keep the connection alive.
1359async fn ping(
1360    _: proto::Ping,
1361    response: Response<proto::Ping>,
1362    _session: MessageContext,
1363) -> Result<()> {
1364    response.send(proto::Ack {})?;
1365    Ok(())
1366}
1367
1368/// Creates a new room for calling (outside of channels)
1369async fn create_room(
1370    _request: proto::CreateRoom,
1371    response: Response<proto::CreateRoom>,
1372    session: MessageContext,
1373) -> Result<()> {
1374    let livekit_room = nanoid::nanoid!(30);
1375
1376    let live_kit_connection_info = util::maybe!(async {
1377        let live_kit = session.app_state.livekit_client.as_ref();
1378        let live_kit = live_kit?;
1379        let user_id = session.user_id().to_string();
1380
1381        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1382
1383        Some(proto::LiveKitConnectionInfo {
1384            server_url: live_kit.url().into(),
1385            token,
1386            can_publish: true,
1387        })
1388    })
1389    .await;
1390
1391    let room = session
1392        .db()
1393        .await
1394        .create_room(session.user_id(), session.connection_id, &livekit_room)
1395        .await?;
1396
1397    response.send(proto::CreateRoomResponse {
1398        room: Some(room.clone()),
1399        live_kit_connection_info,
1400    })?;
1401
1402    update_user_contacts(session.user_id(), &session).await?;
1403    Ok(())
1404}
1405
1406/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1407async fn join_room(
1408    request: proto::JoinRoom,
1409    response: Response<proto::JoinRoom>,
1410    session: MessageContext,
1411) -> Result<()> {
1412    let room_id = RoomId::from_proto(request.id);
1413
1414    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1415
1416    if let Some(channel_id) = channel_id {
1417        return join_channel_internal(channel_id, Box::new(response), session).await;
1418    }
1419
1420    let joined_room = {
1421        let room = session
1422            .db()
1423            .await
1424            .join_room(room_id, session.user_id(), session.connection_id)
1425            .await?;
1426        room_updated(&room.room, &session.peer);
1427        room.into_inner()
1428    };
1429
1430    for connection_id in session
1431        .connection_pool()
1432        .await
1433        .user_connection_ids(session.user_id())
1434    {
1435        session
1436            .peer
1437            .send(
1438                connection_id,
1439                proto::CallCanceled {
1440                    room_id: room_id.to_proto(),
1441                },
1442            )
1443            .trace_err();
1444    }
1445
1446    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1447    {
1448        live_kit
1449            .room_token(
1450                &joined_room.room.livekit_room,
1451                &session.user_id().to_string(),
1452            )
1453            .trace_err()
1454            .map(|token| proto::LiveKitConnectionInfo {
1455                server_url: live_kit.url().into(),
1456                token,
1457                can_publish: true,
1458            })
1459    } else {
1460        None
1461    };
1462
1463    response.send(proto::JoinRoomResponse {
1464        room: Some(joined_room.room),
1465        channel_id: None,
1466        live_kit_connection_info,
1467    })?;
1468
1469    update_user_contacts(session.user_id(), &session).await?;
1470    Ok(())
1471}
1472
1473/// Rejoin room is used to reconnect to a room after connection errors.
1474async fn rejoin_room(
1475    request: proto::RejoinRoom,
1476    response: Response<proto::RejoinRoom>,
1477    session: MessageContext,
1478) -> Result<()> {
1479    let room;
1480    let channel;
1481    {
1482        let mut rejoined_room = session
1483            .db()
1484            .await
1485            .rejoin_room(request, session.user_id(), session.connection_id)
1486            .await?;
1487
1488        response.send(proto::RejoinRoomResponse {
1489            room: Some(rejoined_room.room.clone()),
1490            reshared_projects: rejoined_room
1491                .reshared_projects
1492                .iter()
1493                .map(|project| proto::ResharedProject {
1494                    id: project.id.to_proto(),
1495                    collaborators: project
1496                        .collaborators
1497                        .iter()
1498                        .map(|collaborator| collaborator.to_proto())
1499                        .collect(),
1500                })
1501                .collect(),
1502            rejoined_projects: rejoined_room
1503                .rejoined_projects
1504                .iter()
1505                .map(|rejoined_project| rejoined_project.to_proto())
1506                .collect(),
1507        })?;
1508        room_updated(&rejoined_room.room, &session.peer);
1509
1510        for project in &rejoined_room.reshared_projects {
1511            for collaborator in &project.collaborators {
1512                session
1513                    .peer
1514                    .send(
1515                        collaborator.connection_id,
1516                        proto::UpdateProjectCollaborator {
1517                            project_id: project.id.to_proto(),
1518                            old_peer_id: Some(project.old_connection_id.into()),
1519                            new_peer_id: Some(session.connection_id.into()),
1520                        },
1521                    )
1522                    .trace_err();
1523            }
1524
1525            broadcast(
1526                Some(session.connection_id),
1527                project
1528                    .collaborators
1529                    .iter()
1530                    .map(|collaborator| collaborator.connection_id),
1531                |connection_id| {
1532                    session.peer.forward_send(
1533                        session.connection_id,
1534                        connection_id,
1535                        proto::UpdateProject {
1536                            project_id: project.id.to_proto(),
1537                            worktrees: project.worktrees.clone(),
1538                        },
1539                    )
1540                },
1541            );
1542        }
1543
1544        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1545
1546        let rejoined_room = rejoined_room.into_inner();
1547
1548        room = rejoined_room.room;
1549        channel = rejoined_room.channel;
1550    }
1551
1552    if let Some(channel) = channel {
1553        channel_updated(
1554            &channel,
1555            &room,
1556            &session.peer,
1557            &*session.connection_pool().await,
1558        );
1559    }
1560
1561    update_user_contacts(session.user_id(), &session).await?;
1562    Ok(())
1563}
1564
1565fn notify_rejoined_projects(
1566    rejoined_projects: &mut Vec<RejoinedProject>,
1567    session: &Session,
1568) -> Result<()> {
1569    for project in rejoined_projects.iter() {
1570        for collaborator in &project.collaborators {
1571            session
1572                .peer
1573                .send(
1574                    collaborator.connection_id,
1575                    proto::UpdateProjectCollaborator {
1576                        project_id: project.id.to_proto(),
1577                        old_peer_id: Some(project.old_connection_id.into()),
1578                        new_peer_id: Some(session.connection_id.into()),
1579                    },
1580                )
1581                .trace_err();
1582        }
1583    }
1584
1585    for project in rejoined_projects {
1586        for worktree in mem::take(&mut project.worktrees) {
1587            // Stream this worktree's entries.
1588            let message = proto::UpdateWorktree {
1589                project_id: project.id.to_proto(),
1590                worktree_id: worktree.id,
1591                abs_path: worktree.abs_path.clone(),
1592                root_name: worktree.root_name,
1593                updated_entries: worktree.updated_entries,
1594                removed_entries: worktree.removed_entries,
1595                scan_id: worktree.scan_id,
1596                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1597                updated_repositories: worktree.updated_repositories,
1598                removed_repositories: worktree.removed_repositories,
1599            };
1600            for update in proto::split_worktree_update(message) {
1601                session.peer.send(session.connection_id, update)?;
1602            }
1603
1604            // Stream this worktree's diagnostics.
1605            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1606            if let Some(summary) = worktree_diagnostics.next() {
1607                let message = proto::UpdateDiagnosticSummary {
1608                    project_id: project.id.to_proto(),
1609                    worktree_id: worktree.id,
1610                    summary: Some(summary),
1611                    more_summaries: worktree_diagnostics.collect(),
1612                };
1613                session.peer.send(session.connection_id, message)?;
1614            }
1615
1616            for settings_file in worktree.settings_files {
1617                session.peer.send(
1618                    session.connection_id,
1619                    proto::UpdateWorktreeSettings {
1620                        project_id: project.id.to_proto(),
1621                        worktree_id: worktree.id,
1622                        path: settings_file.path,
1623                        content: Some(settings_file.content),
1624                        kind: Some(settings_file.kind.to_proto().into()),
1625                    },
1626                )?;
1627            }
1628        }
1629
1630        for repository in mem::take(&mut project.updated_repositories) {
1631            for update in split_repository_update(repository) {
1632                session.peer.send(session.connection_id, update)?;
1633            }
1634        }
1635
1636        for id in mem::take(&mut project.removed_repositories) {
1637            session.peer.send(
1638                session.connection_id,
1639                proto::RemoveRepository {
1640                    project_id: project.id.to_proto(),
1641                    id,
1642                },
1643            )?;
1644        }
1645    }
1646
1647    Ok(())
1648}
1649
1650/// leave room disconnects from the room.
1651async fn leave_room(
1652    _: proto::LeaveRoom,
1653    response: Response<proto::LeaveRoom>,
1654    session: MessageContext,
1655) -> Result<()> {
1656    leave_room_for_session(&session, session.connection_id).await?;
1657    response.send(proto::Ack {})?;
1658    Ok(())
1659}
1660
1661/// Updates the permissions of someone else in the room.
1662async fn set_room_participant_role(
1663    request: proto::SetRoomParticipantRole,
1664    response: Response<proto::SetRoomParticipantRole>,
1665    session: MessageContext,
1666) -> Result<()> {
1667    let user_id = UserId::from_proto(request.user_id);
1668    let role = ChannelRole::from(request.role());
1669
1670    let (livekit_room, can_publish) = {
1671        let room = session
1672            .db()
1673            .await
1674            .set_room_participant_role(
1675                session.user_id(),
1676                RoomId::from_proto(request.room_id),
1677                user_id,
1678                role,
1679            )
1680            .await?;
1681
1682        let livekit_room = room.livekit_room.clone();
1683        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1684        room_updated(&room, &session.peer);
1685        (livekit_room, can_publish)
1686    };
1687
1688    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1689        live_kit
1690            .update_participant(
1691                livekit_room.clone(),
1692                request.user_id.to_string(),
1693                livekit_api::proto::ParticipantPermission {
1694                    can_subscribe: true,
1695                    can_publish,
1696                    can_publish_data: can_publish,
1697                    hidden: false,
1698                    recorder: false,
1699                },
1700            )
1701            .await
1702            .trace_err();
1703    }
1704
1705    response.send(proto::Ack {})?;
1706    Ok(())
1707}
1708
1709/// Call someone else into the current room
1710async fn call(
1711    request: proto::Call,
1712    response: Response<proto::Call>,
1713    session: MessageContext,
1714) -> Result<()> {
1715    let room_id = RoomId::from_proto(request.room_id);
1716    let calling_user_id = session.user_id();
1717    let calling_connection_id = session.connection_id;
1718    let called_user_id = UserId::from_proto(request.called_user_id);
1719    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1720    if !session
1721        .db()
1722        .await
1723        .has_contact(calling_user_id, called_user_id)
1724        .await?
1725    {
1726        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1727    }
1728
1729    let incoming_call = {
1730        let (room, incoming_call) = &mut *session
1731            .db()
1732            .await
1733            .call(
1734                room_id,
1735                calling_user_id,
1736                calling_connection_id,
1737                called_user_id,
1738                initial_project_id,
1739            )
1740            .await?;
1741        room_updated(room, &session.peer);
1742        mem::take(incoming_call)
1743    };
1744    update_user_contacts(called_user_id, &session).await?;
1745
1746    let mut calls = session
1747        .connection_pool()
1748        .await
1749        .user_connection_ids(called_user_id)
1750        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1751        .collect::<FuturesUnordered<_>>();
1752
1753    while let Some(call_response) = calls.next().await {
1754        match call_response.as_ref() {
1755            Ok(_) => {
1756                response.send(proto::Ack {})?;
1757                return Ok(());
1758            }
1759            Err(_) => {
1760                call_response.trace_err();
1761            }
1762        }
1763    }
1764
1765    {
1766        let room = session
1767            .db()
1768            .await
1769            .call_failed(room_id, called_user_id)
1770            .await?;
1771        room_updated(&room, &session.peer);
1772    }
1773    update_user_contacts(called_user_id, &session).await?;
1774
1775    Err(anyhow!("failed to ring user"))?
1776}
1777
1778/// Cancel an outgoing call.
1779async fn cancel_call(
1780    request: proto::CancelCall,
1781    response: Response<proto::CancelCall>,
1782    session: MessageContext,
1783) -> Result<()> {
1784    let called_user_id = UserId::from_proto(request.called_user_id);
1785    let room_id = RoomId::from_proto(request.room_id);
1786    {
1787        let room = session
1788            .db()
1789            .await
1790            .cancel_call(room_id, session.connection_id, called_user_id)
1791            .await?;
1792        room_updated(&room, &session.peer);
1793    }
1794
1795    for connection_id in session
1796        .connection_pool()
1797        .await
1798        .user_connection_ids(called_user_id)
1799    {
1800        session
1801            .peer
1802            .send(
1803                connection_id,
1804                proto::CallCanceled {
1805                    room_id: room_id.to_proto(),
1806                },
1807            )
1808            .trace_err();
1809    }
1810    response.send(proto::Ack {})?;
1811
1812    update_user_contacts(called_user_id, &session).await?;
1813    Ok(())
1814}
1815
1816/// Decline an incoming call.
1817async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1818    let room_id = RoomId::from_proto(message.room_id);
1819    {
1820        let room = session
1821            .db()
1822            .await
1823            .decline_call(Some(room_id), session.user_id())
1824            .await?
1825            .context("declining call")?;
1826        room_updated(&room, &session.peer);
1827    }
1828
1829    for connection_id in session
1830        .connection_pool()
1831        .await
1832        .user_connection_ids(session.user_id())
1833    {
1834        session
1835            .peer
1836            .send(
1837                connection_id,
1838                proto::CallCanceled {
1839                    room_id: room_id.to_proto(),
1840                },
1841            )
1842            .trace_err();
1843    }
1844    update_user_contacts(session.user_id(), &session).await?;
1845    Ok(())
1846}
1847
1848/// Updates other participants in the room with your current location.
1849async fn update_participant_location(
1850    request: proto::UpdateParticipantLocation,
1851    response: Response<proto::UpdateParticipantLocation>,
1852    session: MessageContext,
1853) -> Result<()> {
1854    let room_id = RoomId::from_proto(request.room_id);
1855    let location = request.location.context("invalid location")?;
1856
1857    let db = session.db().await;
1858    let room = db
1859        .update_room_participant_location(room_id, session.connection_id, location)
1860        .await?;
1861
1862    room_updated(&room, &session.peer);
1863    response.send(proto::Ack {})?;
1864    Ok(())
1865}
1866
1867/// Share a project into the room.
1868async fn share_project(
1869    request: proto::ShareProject,
1870    response: Response<proto::ShareProject>,
1871    session: MessageContext,
1872) -> Result<()> {
1873    let (project_id, room) = &*session
1874        .db()
1875        .await
1876        .share_project(
1877            RoomId::from_proto(request.room_id),
1878            session.connection_id,
1879            &request.worktrees,
1880            request.is_ssh_project,
1881        )
1882        .await?;
1883    response.send(proto::ShareProjectResponse {
1884        project_id: project_id.to_proto(),
1885    })?;
1886    room_updated(room, &session.peer);
1887
1888    Ok(())
1889}
1890
1891/// Unshare a project from the room.
1892async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1893    let project_id = ProjectId::from_proto(message.project_id);
1894    unshare_project_internal(project_id, session.connection_id, &session).await
1895}
1896
1897async fn unshare_project_internal(
1898    project_id: ProjectId,
1899    connection_id: ConnectionId,
1900    session: &Session,
1901) -> Result<()> {
1902    let delete = {
1903        let room_guard = session
1904            .db()
1905            .await
1906            .unshare_project(project_id, connection_id)
1907            .await?;
1908
1909        let (delete, room, guest_connection_ids) = &*room_guard;
1910
1911        let message = proto::UnshareProject {
1912            project_id: project_id.to_proto(),
1913        };
1914
1915        broadcast(
1916            Some(connection_id),
1917            guest_connection_ids.iter().copied(),
1918            |conn_id| session.peer.send(conn_id, message.clone()),
1919        );
1920        if let Some(room) = room {
1921            room_updated(room, &session.peer);
1922        }
1923
1924        *delete
1925    };
1926
1927    if delete {
1928        let db = session.db().await;
1929        db.delete_project(project_id).await?;
1930    }
1931
1932    Ok(())
1933}
1934
1935/// Join someone elses shared project.
1936async fn join_project(
1937    request: proto::JoinProject,
1938    response: Response<proto::JoinProject>,
1939    session: MessageContext,
1940) -> Result<()> {
1941    let project_id = ProjectId::from_proto(request.project_id);
1942
1943    tracing::info!(%project_id, "join project");
1944
1945    let db = session.db().await;
1946    let (project, replica_id) = &mut *db
1947        .join_project(
1948            project_id,
1949            session.connection_id,
1950            session.user_id(),
1951            request.committer_name.clone(),
1952            request.committer_email.clone(),
1953        )
1954        .await?;
1955    drop(db);
1956    tracing::info!(%project_id, "join remote project");
1957    let collaborators = project
1958        .collaborators
1959        .iter()
1960        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1961        .map(|collaborator| collaborator.to_proto())
1962        .collect::<Vec<_>>();
1963    let project_id = project.id;
1964    let guest_user_id = session.user_id();
1965
1966    let worktrees = project
1967        .worktrees
1968        .iter()
1969        .map(|(id, worktree)| proto::WorktreeMetadata {
1970            id: *id,
1971            root_name: worktree.root_name.clone(),
1972            visible: worktree.visible,
1973            abs_path: worktree.abs_path.clone(),
1974        })
1975        .collect::<Vec<_>>();
1976
1977    let add_project_collaborator = proto::AddProjectCollaborator {
1978        project_id: project_id.to_proto(),
1979        collaborator: Some(proto::Collaborator {
1980            peer_id: Some(session.connection_id.into()),
1981            replica_id: replica_id.0 as u32,
1982            user_id: guest_user_id.to_proto(),
1983            is_host: false,
1984            committer_name: request.committer_name.clone(),
1985            committer_email: request.committer_email.clone(),
1986        }),
1987    };
1988
1989    for collaborator in &collaborators {
1990        session
1991            .peer
1992            .send(
1993                collaborator.peer_id.unwrap().into(),
1994                add_project_collaborator.clone(),
1995            )
1996            .trace_err();
1997    }
1998
1999    // First, we send the metadata associated with each worktree.
2000    let (language_servers, language_server_capabilities) = project
2001        .language_servers
2002        .clone()
2003        .into_iter()
2004        .map(|server| (server.server, server.capabilities))
2005        .unzip();
2006    response.send(proto::JoinProjectResponse {
2007        project_id: project.id.0 as u64,
2008        worktrees,
2009        replica_id: replica_id.0 as u32,
2010        collaborators,
2011        language_servers,
2012        language_server_capabilities,
2013        role: project.role.into(),
2014    })?;
2015
2016    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2017        // Stream this worktree's entries.
2018        let message = proto::UpdateWorktree {
2019            project_id: project_id.to_proto(),
2020            worktree_id,
2021            abs_path: worktree.abs_path.clone(),
2022            root_name: worktree.root_name,
2023            updated_entries: worktree.entries,
2024            removed_entries: Default::default(),
2025            scan_id: worktree.scan_id,
2026            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2027            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2028            removed_repositories: Default::default(),
2029        };
2030        for update in proto::split_worktree_update(message) {
2031            session.peer.send(session.connection_id, update.clone())?;
2032        }
2033
2034        // Stream this worktree's diagnostics.
2035        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2036        if let Some(summary) = worktree_diagnostics.next() {
2037            let message = proto::UpdateDiagnosticSummary {
2038                project_id: project.id.to_proto(),
2039                worktree_id: worktree.id,
2040                summary: Some(summary),
2041                more_summaries: worktree_diagnostics.collect(),
2042            };
2043            session.peer.send(session.connection_id, message)?;
2044        }
2045
2046        for settings_file in worktree.settings_files {
2047            session.peer.send(
2048                session.connection_id,
2049                proto::UpdateWorktreeSettings {
2050                    project_id: project_id.to_proto(),
2051                    worktree_id: worktree.id,
2052                    path: settings_file.path,
2053                    content: Some(settings_file.content),
2054                    kind: Some(settings_file.kind.to_proto() as i32),
2055                },
2056            )?;
2057        }
2058    }
2059
2060    for repository in mem::take(&mut project.repositories) {
2061        for update in split_repository_update(repository) {
2062            session.peer.send(session.connection_id, update)?;
2063        }
2064    }
2065
2066    for language_server in &project.language_servers {
2067        session.peer.send(
2068            session.connection_id,
2069            proto::UpdateLanguageServer {
2070                project_id: project_id.to_proto(),
2071                server_name: Some(language_server.server.name.clone()),
2072                language_server_id: language_server.server.id,
2073                variant: Some(
2074                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2075                        proto::LspDiskBasedDiagnosticsUpdated {},
2076                    ),
2077                ),
2078            },
2079        )?;
2080    }
2081
2082    Ok(())
2083}
2084
2085/// Leave someone elses shared project.
2086async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2087    let sender_id = session.connection_id;
2088    let project_id = ProjectId::from_proto(request.project_id);
2089    let db = session.db().await;
2090
2091    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2092    tracing::info!(
2093        %project_id,
2094        "leave project"
2095    );
2096
2097    project_left(project, &session);
2098    if let Some(room) = room {
2099        room_updated(room, &session.peer);
2100    }
2101
2102    Ok(())
2103}
2104
2105/// Updates other participants with changes to the project
2106async fn update_project(
2107    request: proto::UpdateProject,
2108    response: Response<proto::UpdateProject>,
2109    session: MessageContext,
2110) -> Result<()> {
2111    let project_id = ProjectId::from_proto(request.project_id);
2112    let (room, guest_connection_ids) = &*session
2113        .db()
2114        .await
2115        .update_project(project_id, session.connection_id, &request.worktrees)
2116        .await?;
2117    broadcast(
2118        Some(session.connection_id),
2119        guest_connection_ids.iter().copied(),
2120        |connection_id| {
2121            session
2122                .peer
2123                .forward_send(session.connection_id, connection_id, request.clone())
2124        },
2125    );
2126    if let Some(room) = room {
2127        room_updated(room, &session.peer);
2128    }
2129    response.send(proto::Ack {})?;
2130
2131    Ok(())
2132}
2133
2134/// Updates other participants with changes to the worktree
2135async fn update_worktree(
2136    request: proto::UpdateWorktree,
2137    response: Response<proto::UpdateWorktree>,
2138    session: MessageContext,
2139) -> Result<()> {
2140    let guest_connection_ids = session
2141        .db()
2142        .await
2143        .update_worktree(&request, session.connection_id)
2144        .await?;
2145
2146    broadcast(
2147        Some(session.connection_id),
2148        guest_connection_ids.iter().copied(),
2149        |connection_id| {
2150            session
2151                .peer
2152                .forward_send(session.connection_id, connection_id, request.clone())
2153        },
2154    );
2155    response.send(proto::Ack {})?;
2156    Ok(())
2157}
2158
2159async fn update_repository(
2160    request: proto::UpdateRepository,
2161    response: Response<proto::UpdateRepository>,
2162    session: MessageContext,
2163) -> Result<()> {
2164    let guest_connection_ids = session
2165        .db()
2166        .await
2167        .update_repository(&request, session.connection_id)
2168        .await?;
2169
2170    broadcast(
2171        Some(session.connection_id),
2172        guest_connection_ids.iter().copied(),
2173        |connection_id| {
2174            session
2175                .peer
2176                .forward_send(session.connection_id, connection_id, request.clone())
2177        },
2178    );
2179    response.send(proto::Ack {})?;
2180    Ok(())
2181}
2182
2183async fn remove_repository(
2184    request: proto::RemoveRepository,
2185    response: Response<proto::RemoveRepository>,
2186    session: MessageContext,
2187) -> Result<()> {
2188    let guest_connection_ids = session
2189        .db()
2190        .await
2191        .remove_repository(&request, session.connection_id)
2192        .await?;
2193
2194    broadcast(
2195        Some(session.connection_id),
2196        guest_connection_ids.iter().copied(),
2197        |connection_id| {
2198            session
2199                .peer
2200                .forward_send(session.connection_id, connection_id, request.clone())
2201        },
2202    );
2203    response.send(proto::Ack {})?;
2204    Ok(())
2205}
2206
2207/// Updates other participants with changes to the diagnostics
2208async fn update_diagnostic_summary(
2209    message: proto::UpdateDiagnosticSummary,
2210    session: MessageContext,
2211) -> Result<()> {
2212    let guest_connection_ids = session
2213        .db()
2214        .await
2215        .update_diagnostic_summary(&message, session.connection_id)
2216        .await?;
2217
2218    broadcast(
2219        Some(session.connection_id),
2220        guest_connection_ids.iter().copied(),
2221        |connection_id| {
2222            session
2223                .peer
2224                .forward_send(session.connection_id, connection_id, message.clone())
2225        },
2226    );
2227
2228    Ok(())
2229}
2230
2231/// Updates other participants with changes to the worktree settings
2232async fn update_worktree_settings(
2233    message: proto::UpdateWorktreeSettings,
2234    session: MessageContext,
2235) -> Result<()> {
2236    let guest_connection_ids = session
2237        .db()
2238        .await
2239        .update_worktree_settings(&message, session.connection_id)
2240        .await?;
2241
2242    broadcast(
2243        Some(session.connection_id),
2244        guest_connection_ids.iter().copied(),
2245        |connection_id| {
2246            session
2247                .peer
2248                .forward_send(session.connection_id, connection_id, message.clone())
2249        },
2250    );
2251
2252    Ok(())
2253}
2254
2255/// Notify other participants that a language server has started.
2256async fn start_language_server(
2257    request: proto::StartLanguageServer,
2258    session: MessageContext,
2259) -> Result<()> {
2260    let guest_connection_ids = session
2261        .db()
2262        .await
2263        .start_language_server(&request, session.connection_id)
2264        .await?;
2265
2266    broadcast(
2267        Some(session.connection_id),
2268        guest_connection_ids.iter().copied(),
2269        |connection_id| {
2270            session
2271                .peer
2272                .forward_send(session.connection_id, connection_id, request.clone())
2273        },
2274    );
2275    Ok(())
2276}
2277
2278/// Notify other participants that a language server has changed.
2279async fn update_language_server(
2280    request: proto::UpdateLanguageServer,
2281    session: MessageContext,
2282) -> Result<()> {
2283    let project_id = ProjectId::from_proto(request.project_id);
2284    let db = session.db().await;
2285
2286    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2287        && let Some(capabilities) = update.capabilities.clone()
2288    {
2289        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2290            .await?;
2291    }
2292
2293    let project_connection_ids = db
2294        .project_connection_ids(project_id, session.connection_id, true)
2295        .await?;
2296    broadcast(
2297        Some(session.connection_id),
2298        project_connection_ids.iter().copied(),
2299        |connection_id| {
2300            session
2301                .peer
2302                .forward_send(session.connection_id, connection_id, request.clone())
2303        },
2304    );
2305    Ok(())
2306}
2307
2308/// forward a project request to the host. These requests should be read only
2309/// as guests are allowed to send them.
2310async fn forward_read_only_project_request<T>(
2311    request: T,
2312    response: Response<T>,
2313    session: MessageContext,
2314) -> Result<()>
2315where
2316    T: EntityMessage + RequestMessage,
2317{
2318    let project_id = ProjectId::from_proto(request.remote_entity_id());
2319    let host_connection_id = session
2320        .db()
2321        .await
2322        .host_for_read_only_project_request(project_id, session.connection_id)
2323        .await?;
2324    let payload = session.forward_request(host_connection_id, request).await?;
2325    response.send(payload)?;
2326    Ok(())
2327}
2328
2329/// forward a project request to the host. These requests are disallowed
2330/// for guests.
2331async fn forward_mutating_project_request<T>(
2332    request: T,
2333    response: Response<T>,
2334    session: MessageContext,
2335) -> Result<()>
2336where
2337    T: EntityMessage + RequestMessage,
2338{
2339    let project_id = ProjectId::from_proto(request.remote_entity_id());
2340
2341    let host_connection_id = session
2342        .db()
2343        .await
2344        .host_for_mutating_project_request(project_id, session.connection_id)
2345        .await?;
2346    let payload = session.forward_request(host_connection_id, request).await?;
2347    response.send(payload)?;
2348    Ok(())
2349}
2350
2351async fn lsp_query(
2352    request: proto::LspQuery,
2353    response: Response<proto::LspQuery>,
2354    session: MessageContext,
2355) -> Result<()> {
2356    let (name, should_write) = request.query_name_and_write_permissions();
2357    tracing::Span::current().record("lsp_query_request", name);
2358    tracing::info!("lsp_query message received");
2359    if should_write {
2360        forward_mutating_project_request(request, response, session).await
2361    } else {
2362        forward_read_only_project_request(request, response, session).await
2363    }
2364}
2365
2366/// Notify other participants that a new buffer has been created
2367async fn create_buffer_for_peer(
2368    request: proto::CreateBufferForPeer,
2369    session: MessageContext,
2370) -> Result<()> {
2371    session
2372        .db()
2373        .await
2374        .check_user_is_project_host(
2375            ProjectId::from_proto(request.project_id),
2376            session.connection_id,
2377        )
2378        .await?;
2379    let peer_id = request.peer_id.context("invalid peer id")?;
2380    session
2381        .peer
2382        .forward_send(session.connection_id, peer_id.into(), request)?;
2383    Ok(())
2384}
2385
2386/// Notify other participants that a buffer has been updated. This is
2387/// allowed for guests as long as the update is limited to selections.
2388async fn update_buffer(
2389    request: proto::UpdateBuffer,
2390    response: Response<proto::UpdateBuffer>,
2391    session: MessageContext,
2392) -> Result<()> {
2393    let project_id = ProjectId::from_proto(request.project_id);
2394    let mut capability = Capability::ReadOnly;
2395
2396    for op in request.operations.iter() {
2397        match op.variant {
2398            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2399            Some(_) => capability = Capability::ReadWrite,
2400        }
2401    }
2402
2403    let host = {
2404        let guard = session
2405            .db()
2406            .await
2407            .connections_for_buffer_update(project_id, session.connection_id, capability)
2408            .await?;
2409
2410        let (host, guests) = &*guard;
2411
2412        broadcast(
2413            Some(session.connection_id),
2414            guests.clone(),
2415            |connection_id| {
2416                session
2417                    .peer
2418                    .forward_send(session.connection_id, connection_id, request.clone())
2419            },
2420        );
2421
2422        *host
2423    };
2424
2425    if host != session.connection_id {
2426        session.forward_request(host, request.clone()).await?;
2427    }
2428
2429    response.send(proto::Ack {})?;
2430    Ok(())
2431}
2432
2433async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2434    let project_id = ProjectId::from_proto(message.project_id);
2435
2436    let operation = message.operation.as_ref().context("invalid operation")?;
2437    let capability = match operation.variant.as_ref() {
2438        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2439            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2440                match buffer_op.variant {
2441                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2442                        Capability::ReadOnly
2443                    }
2444                    _ => Capability::ReadWrite,
2445                }
2446            } else {
2447                Capability::ReadWrite
2448            }
2449        }
2450        Some(_) => Capability::ReadWrite,
2451        None => Capability::ReadOnly,
2452    };
2453
2454    let guard = session
2455        .db()
2456        .await
2457        .connections_for_buffer_update(project_id, session.connection_id, capability)
2458        .await?;
2459
2460    let (host, guests) = &*guard;
2461
2462    broadcast(
2463        Some(session.connection_id),
2464        guests.iter().chain([host]).copied(),
2465        |connection_id| {
2466            session
2467                .peer
2468                .forward_send(session.connection_id, connection_id, message.clone())
2469        },
2470    );
2471
2472    Ok(())
2473}
2474
2475/// Notify other participants that a project has been updated.
2476async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2477    request: T,
2478    session: MessageContext,
2479) -> Result<()> {
2480    let project_id = ProjectId::from_proto(request.remote_entity_id());
2481    let project_connection_ids = session
2482        .db()
2483        .await
2484        .project_connection_ids(project_id, session.connection_id, false)
2485        .await?;
2486
2487    broadcast(
2488        Some(session.connection_id),
2489        project_connection_ids.iter().copied(),
2490        |connection_id| {
2491            session
2492                .peer
2493                .forward_send(session.connection_id, connection_id, request.clone())
2494        },
2495    );
2496    Ok(())
2497}
2498
2499/// Start following another user in a call.
2500async fn follow(
2501    request: proto::Follow,
2502    response: Response<proto::Follow>,
2503    session: MessageContext,
2504) -> Result<()> {
2505    let room_id = RoomId::from_proto(request.room_id);
2506    let project_id = request.project_id.map(ProjectId::from_proto);
2507    let leader_id = request.leader_id.context("invalid leader id")?.into();
2508    let follower_id = session.connection_id;
2509
2510    session
2511        .db()
2512        .await
2513        .check_room_participants(room_id, leader_id, session.connection_id)
2514        .await?;
2515
2516    let response_payload = session.forward_request(leader_id, request).await?;
2517    response.send(response_payload)?;
2518
2519    if let Some(project_id) = project_id {
2520        let room = session
2521            .db()
2522            .await
2523            .follow(room_id, project_id, leader_id, follower_id)
2524            .await?;
2525        room_updated(&room, &session.peer);
2526    }
2527
2528    Ok(())
2529}
2530
2531/// Stop following another user in a call.
2532async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2533    let room_id = RoomId::from_proto(request.room_id);
2534    let project_id = request.project_id.map(ProjectId::from_proto);
2535    let leader_id = request.leader_id.context("invalid leader id")?.into();
2536    let follower_id = session.connection_id;
2537
2538    session
2539        .db()
2540        .await
2541        .check_room_participants(room_id, leader_id, session.connection_id)
2542        .await?;
2543
2544    session
2545        .peer
2546        .forward_send(session.connection_id, leader_id, request)?;
2547
2548    if let Some(project_id) = project_id {
2549        let room = session
2550            .db()
2551            .await
2552            .unfollow(room_id, project_id, leader_id, follower_id)
2553            .await?;
2554        room_updated(&room, &session.peer);
2555    }
2556
2557    Ok(())
2558}
2559
2560/// Notify everyone following you of your current location.
2561async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2562    let room_id = RoomId::from_proto(request.room_id);
2563    let database = session.db.lock().await;
2564
2565    let connection_ids = if let Some(project_id) = request.project_id {
2566        let project_id = ProjectId::from_proto(project_id);
2567        database
2568            .project_connection_ids(project_id, session.connection_id, true)
2569            .await?
2570    } else {
2571        database
2572            .room_connection_ids(room_id, session.connection_id)
2573            .await?
2574    };
2575
2576    // For now, don't send view update messages back to that view's current leader.
2577    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2578        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2579        _ => None,
2580    });
2581
2582    for connection_id in connection_ids.iter().cloned() {
2583        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2584            session
2585                .peer
2586                .forward_send(session.connection_id, connection_id, request.clone())?;
2587        }
2588    }
2589    Ok(())
2590}
2591
2592/// Get public data about users.
2593async fn get_users(
2594    request: proto::GetUsers,
2595    response: Response<proto::GetUsers>,
2596    session: MessageContext,
2597) -> Result<()> {
2598    let user_ids = request
2599        .user_ids
2600        .into_iter()
2601        .map(UserId::from_proto)
2602        .collect();
2603    let users = session
2604        .db()
2605        .await
2606        .get_users_by_ids(user_ids)
2607        .await?
2608        .into_iter()
2609        .map(|user| proto::User {
2610            id: user.id.to_proto(),
2611            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2612            github_login: user.github_login,
2613            name: user.name,
2614        })
2615        .collect();
2616    response.send(proto::UsersResponse { users })?;
2617    Ok(())
2618}
2619
2620/// Search for users (to invite) buy Github login
2621async fn fuzzy_search_users(
2622    request: proto::FuzzySearchUsers,
2623    response: Response<proto::FuzzySearchUsers>,
2624    session: MessageContext,
2625) -> Result<()> {
2626    let query = request.query;
2627    let users = match query.len() {
2628        0 => vec![],
2629        1 | 2 => session
2630            .db()
2631            .await
2632            .get_user_by_github_login(&query)
2633            .await?
2634            .into_iter()
2635            .collect(),
2636        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2637    };
2638    let users = users
2639        .into_iter()
2640        .filter(|user| user.id != session.user_id())
2641        .map(|user| proto::User {
2642            id: user.id.to_proto(),
2643            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2644            github_login: user.github_login,
2645            name: user.name,
2646        })
2647        .collect();
2648    response.send(proto::UsersResponse { users })?;
2649    Ok(())
2650}
2651
2652/// Send a contact request to another user.
2653async fn request_contact(
2654    request: proto::RequestContact,
2655    response: Response<proto::RequestContact>,
2656    session: MessageContext,
2657) -> Result<()> {
2658    let requester_id = session.user_id();
2659    let responder_id = UserId::from_proto(request.responder_id);
2660    if requester_id == responder_id {
2661        return Err(anyhow!("cannot add yourself as a contact"))?;
2662    }
2663
2664    let notifications = session
2665        .db()
2666        .await
2667        .send_contact_request(requester_id, responder_id)
2668        .await?;
2669
2670    // Update outgoing contact requests of requester
2671    let mut update = proto::UpdateContacts::default();
2672    update.outgoing_requests.push(responder_id.to_proto());
2673    for connection_id in session
2674        .connection_pool()
2675        .await
2676        .user_connection_ids(requester_id)
2677    {
2678        session.peer.send(connection_id, update.clone())?;
2679    }
2680
2681    // Update incoming contact requests of responder
2682    let mut update = proto::UpdateContacts::default();
2683    update
2684        .incoming_requests
2685        .push(proto::IncomingContactRequest {
2686            requester_id: requester_id.to_proto(),
2687        });
2688    let connection_pool = session.connection_pool().await;
2689    for connection_id in connection_pool.user_connection_ids(responder_id) {
2690        session.peer.send(connection_id, update.clone())?;
2691    }
2692
2693    send_notifications(&connection_pool, &session.peer, notifications);
2694
2695    response.send(proto::Ack {})?;
2696    Ok(())
2697}
2698
2699/// Accept or decline a contact request
2700async fn respond_to_contact_request(
2701    request: proto::RespondToContactRequest,
2702    response: Response<proto::RespondToContactRequest>,
2703    session: MessageContext,
2704) -> Result<()> {
2705    let responder_id = session.user_id();
2706    let requester_id = UserId::from_proto(request.requester_id);
2707    let db = session.db().await;
2708    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2709        db.dismiss_contact_notification(responder_id, requester_id)
2710            .await?;
2711    } else {
2712        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2713
2714        let notifications = db
2715            .respond_to_contact_request(responder_id, requester_id, accept)
2716            .await?;
2717        let requester_busy = db.is_user_busy(requester_id).await?;
2718        let responder_busy = db.is_user_busy(responder_id).await?;
2719
2720        let pool = session.connection_pool().await;
2721        // Update responder with new contact
2722        let mut update = proto::UpdateContacts::default();
2723        if accept {
2724            update
2725                .contacts
2726                .push(contact_for_user(requester_id, requester_busy, &pool));
2727        }
2728        update
2729            .remove_incoming_requests
2730            .push(requester_id.to_proto());
2731        for connection_id in pool.user_connection_ids(responder_id) {
2732            session.peer.send(connection_id, update.clone())?;
2733        }
2734
2735        // Update requester with new contact
2736        let mut update = proto::UpdateContacts::default();
2737        if accept {
2738            update
2739                .contacts
2740                .push(contact_for_user(responder_id, responder_busy, &pool));
2741        }
2742        update
2743            .remove_outgoing_requests
2744            .push(responder_id.to_proto());
2745
2746        for connection_id in pool.user_connection_ids(requester_id) {
2747            session.peer.send(connection_id, update.clone())?;
2748        }
2749
2750        send_notifications(&pool, &session.peer, notifications);
2751    }
2752
2753    response.send(proto::Ack {})?;
2754    Ok(())
2755}
2756
2757/// Remove a contact.
2758async fn remove_contact(
2759    request: proto::RemoveContact,
2760    response: Response<proto::RemoveContact>,
2761    session: MessageContext,
2762) -> Result<()> {
2763    let requester_id = session.user_id();
2764    let responder_id = UserId::from_proto(request.user_id);
2765    let db = session.db().await;
2766    let (contact_accepted, deleted_notification_id) =
2767        db.remove_contact(requester_id, responder_id).await?;
2768
2769    let pool = session.connection_pool().await;
2770    // Update outgoing contact requests of requester
2771    let mut update = proto::UpdateContacts::default();
2772    if contact_accepted {
2773        update.remove_contacts.push(responder_id.to_proto());
2774    } else {
2775        update
2776            .remove_outgoing_requests
2777            .push(responder_id.to_proto());
2778    }
2779    for connection_id in pool.user_connection_ids(requester_id) {
2780        session.peer.send(connection_id, update.clone())?;
2781    }
2782
2783    // Update incoming contact requests of responder
2784    let mut update = proto::UpdateContacts::default();
2785    if contact_accepted {
2786        update.remove_contacts.push(requester_id.to_proto());
2787    } else {
2788        update
2789            .remove_incoming_requests
2790            .push(requester_id.to_proto());
2791    }
2792    for connection_id in pool.user_connection_ids(responder_id) {
2793        session.peer.send(connection_id, update.clone())?;
2794        if let Some(notification_id) = deleted_notification_id {
2795            session.peer.send(
2796                connection_id,
2797                proto::DeleteNotification {
2798                    notification_id: notification_id.to_proto(),
2799                },
2800            )?;
2801        }
2802    }
2803
2804    response.send(proto::Ack {})?;
2805    Ok(())
2806}
2807
2808fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2809    version.0.minor() < 139
2810}
2811
2812async fn subscribe_to_channels(
2813    _: proto::SubscribeToChannels,
2814    session: MessageContext,
2815) -> Result<()> {
2816    subscribe_user_to_channels(session.user_id(), &session).await?;
2817    Ok(())
2818}
2819
2820async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2821    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2822    let mut pool = session.connection_pool().await;
2823    for membership in &channels_for_user.channel_memberships {
2824        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2825    }
2826    session.peer.send(
2827        session.connection_id,
2828        build_update_user_channels(&channels_for_user),
2829    )?;
2830    session.peer.send(
2831        session.connection_id,
2832        build_channels_update(channels_for_user),
2833    )?;
2834    Ok(())
2835}
2836
2837/// Creates a new channel.
2838async fn create_channel(
2839    request: proto::CreateChannel,
2840    response: Response<proto::CreateChannel>,
2841    session: MessageContext,
2842) -> Result<()> {
2843    let db = session.db().await;
2844
2845    let parent_id = request.parent_id.map(ChannelId::from_proto);
2846    let (channel, membership) = db
2847        .create_channel(&request.name, parent_id, session.user_id())
2848        .await?;
2849
2850    let root_id = channel.root_id();
2851    let channel = Channel::from_model(channel);
2852
2853    response.send(proto::CreateChannelResponse {
2854        channel: Some(channel.to_proto()),
2855        parent_id: request.parent_id,
2856    })?;
2857
2858    let mut connection_pool = session.connection_pool().await;
2859    if let Some(membership) = membership {
2860        connection_pool.subscribe_to_channel(
2861            membership.user_id,
2862            membership.channel_id,
2863            membership.role,
2864        );
2865        let update = proto::UpdateUserChannels {
2866            channel_memberships: vec![proto::ChannelMembership {
2867                channel_id: membership.channel_id.to_proto(),
2868                role: membership.role.into(),
2869            }],
2870            ..Default::default()
2871        };
2872        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2873            session.peer.send(connection_id, update.clone())?;
2874        }
2875    }
2876
2877    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2878        if !role.can_see_channel(channel.visibility) {
2879            continue;
2880        }
2881
2882        let update = proto::UpdateChannels {
2883            channels: vec![channel.to_proto()],
2884            ..Default::default()
2885        };
2886        session.peer.send(connection_id, update.clone())?;
2887    }
2888
2889    Ok(())
2890}
2891
2892/// Delete a channel
2893async fn delete_channel(
2894    request: proto::DeleteChannel,
2895    response: Response<proto::DeleteChannel>,
2896    session: MessageContext,
2897) -> Result<()> {
2898    let db = session.db().await;
2899
2900    let channel_id = request.channel_id;
2901    let (root_channel, removed_channels) = db
2902        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2903        .await?;
2904    response.send(proto::Ack {})?;
2905
2906    // Notify members of removed channels
2907    let mut update = proto::UpdateChannels::default();
2908    update
2909        .delete_channels
2910        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2911
2912    let connection_pool = session.connection_pool().await;
2913    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2914        session.peer.send(connection_id, update.clone())?;
2915    }
2916
2917    Ok(())
2918}
2919
2920/// Invite someone to join a channel.
2921async fn invite_channel_member(
2922    request: proto::InviteChannelMember,
2923    response: Response<proto::InviteChannelMember>,
2924    session: MessageContext,
2925) -> Result<()> {
2926    let db = session.db().await;
2927    let channel_id = ChannelId::from_proto(request.channel_id);
2928    let invitee_id = UserId::from_proto(request.user_id);
2929    let InviteMemberResult {
2930        channel,
2931        notifications,
2932    } = db
2933        .invite_channel_member(
2934            channel_id,
2935            invitee_id,
2936            session.user_id(),
2937            request.role().into(),
2938        )
2939        .await?;
2940
2941    let update = proto::UpdateChannels {
2942        channel_invitations: vec![channel.to_proto()],
2943        ..Default::default()
2944    };
2945
2946    let connection_pool = session.connection_pool().await;
2947    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2948        session.peer.send(connection_id, update.clone())?;
2949    }
2950
2951    send_notifications(&connection_pool, &session.peer, notifications);
2952
2953    response.send(proto::Ack {})?;
2954    Ok(())
2955}
2956
2957/// remove someone from a channel
2958async fn remove_channel_member(
2959    request: proto::RemoveChannelMember,
2960    response: Response<proto::RemoveChannelMember>,
2961    session: MessageContext,
2962) -> Result<()> {
2963    let db = session.db().await;
2964    let channel_id = ChannelId::from_proto(request.channel_id);
2965    let member_id = UserId::from_proto(request.user_id);
2966
2967    let RemoveChannelMemberResult {
2968        membership_update,
2969        notification_id,
2970    } = db
2971        .remove_channel_member(channel_id, member_id, session.user_id())
2972        .await?;
2973
2974    let mut connection_pool = session.connection_pool().await;
2975    notify_membership_updated(
2976        &mut connection_pool,
2977        membership_update,
2978        member_id,
2979        &session.peer,
2980    );
2981    for connection_id in connection_pool.user_connection_ids(member_id) {
2982        if let Some(notification_id) = notification_id {
2983            session
2984                .peer
2985                .send(
2986                    connection_id,
2987                    proto::DeleteNotification {
2988                        notification_id: notification_id.to_proto(),
2989                    },
2990                )
2991                .trace_err();
2992        }
2993    }
2994
2995    response.send(proto::Ack {})?;
2996    Ok(())
2997}
2998
2999/// Toggle the channel between public and private.
3000/// Care is taken to maintain the invariant that public channels only descend from public channels,
3001/// (though members-only channels can appear at any point in the hierarchy).
3002async fn set_channel_visibility(
3003    request: proto::SetChannelVisibility,
3004    response: Response<proto::SetChannelVisibility>,
3005    session: MessageContext,
3006) -> Result<()> {
3007    let db = session.db().await;
3008    let channel_id = ChannelId::from_proto(request.channel_id);
3009    let visibility = request.visibility().into();
3010
3011    let channel_model = db
3012        .set_channel_visibility(channel_id, visibility, session.user_id())
3013        .await?;
3014    let root_id = channel_model.root_id();
3015    let channel = Channel::from_model(channel_model);
3016
3017    let mut connection_pool = session.connection_pool().await;
3018    for (user_id, role) in connection_pool
3019        .channel_user_ids(root_id)
3020        .collect::<Vec<_>>()
3021        .into_iter()
3022    {
3023        let update = if role.can_see_channel(channel.visibility) {
3024            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3025            proto::UpdateChannels {
3026                channels: vec![channel.to_proto()],
3027                ..Default::default()
3028            }
3029        } else {
3030            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3031            proto::UpdateChannels {
3032                delete_channels: vec![channel.id.to_proto()],
3033                ..Default::default()
3034            }
3035        };
3036
3037        for connection_id in connection_pool.user_connection_ids(user_id) {
3038            session.peer.send(connection_id, update.clone())?;
3039        }
3040    }
3041
3042    response.send(proto::Ack {})?;
3043    Ok(())
3044}
3045
3046/// Alter the role for a user in the channel.
3047async fn set_channel_member_role(
3048    request: proto::SetChannelMemberRole,
3049    response: Response<proto::SetChannelMemberRole>,
3050    session: MessageContext,
3051) -> Result<()> {
3052    let db = session.db().await;
3053    let channel_id = ChannelId::from_proto(request.channel_id);
3054    let member_id = UserId::from_proto(request.user_id);
3055    let result = db
3056        .set_channel_member_role(
3057            channel_id,
3058            session.user_id(),
3059            member_id,
3060            request.role().into(),
3061        )
3062        .await?;
3063
3064    match result {
3065        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3066            let mut connection_pool = session.connection_pool().await;
3067            notify_membership_updated(
3068                &mut connection_pool,
3069                membership_update,
3070                member_id,
3071                &session.peer,
3072            )
3073        }
3074        db::SetMemberRoleResult::InviteUpdated(channel) => {
3075            let update = proto::UpdateChannels {
3076                channel_invitations: vec![channel.to_proto()],
3077                ..Default::default()
3078            };
3079
3080            for connection_id in session
3081                .connection_pool()
3082                .await
3083                .user_connection_ids(member_id)
3084            {
3085                session.peer.send(connection_id, update.clone())?;
3086            }
3087        }
3088    }
3089
3090    response.send(proto::Ack {})?;
3091    Ok(())
3092}
3093
3094/// Change the name of a channel
3095async fn rename_channel(
3096    request: proto::RenameChannel,
3097    response: Response<proto::RenameChannel>,
3098    session: MessageContext,
3099) -> Result<()> {
3100    let db = session.db().await;
3101    let channel_id = ChannelId::from_proto(request.channel_id);
3102    let channel_model = db
3103        .rename_channel(channel_id, session.user_id(), &request.name)
3104        .await?;
3105    let root_id = channel_model.root_id();
3106    let channel = Channel::from_model(channel_model);
3107
3108    response.send(proto::RenameChannelResponse {
3109        channel: Some(channel.to_proto()),
3110    })?;
3111
3112    let connection_pool = session.connection_pool().await;
3113    let update = proto::UpdateChannels {
3114        channels: vec![channel.to_proto()],
3115        ..Default::default()
3116    };
3117    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118        if role.can_see_channel(channel.visibility) {
3119            session.peer.send(connection_id, update.clone())?;
3120        }
3121    }
3122
3123    Ok(())
3124}
3125
3126/// Move a channel to a new parent.
3127async fn move_channel(
3128    request: proto::MoveChannel,
3129    response: Response<proto::MoveChannel>,
3130    session: MessageContext,
3131) -> Result<()> {
3132    let channel_id = ChannelId::from_proto(request.channel_id);
3133    let to = ChannelId::from_proto(request.to);
3134
3135    let (root_id, channels) = session
3136        .db()
3137        .await
3138        .move_channel(channel_id, to, session.user_id())
3139        .await?;
3140
3141    let connection_pool = session.connection_pool().await;
3142    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3143        let channels = channels
3144            .iter()
3145            .filter_map(|channel| {
3146                if role.can_see_channel(channel.visibility) {
3147                    Some(channel.to_proto())
3148                } else {
3149                    None
3150                }
3151            })
3152            .collect::<Vec<_>>();
3153        if channels.is_empty() {
3154            continue;
3155        }
3156
3157        let update = proto::UpdateChannels {
3158            channels,
3159            ..Default::default()
3160        };
3161
3162        session.peer.send(connection_id, update.clone())?;
3163    }
3164
3165    response.send(Ack {})?;
3166    Ok(())
3167}
3168
3169async fn reorder_channel(
3170    request: proto::ReorderChannel,
3171    response: Response<proto::ReorderChannel>,
3172    session: MessageContext,
3173) -> Result<()> {
3174    let channel_id = ChannelId::from_proto(request.channel_id);
3175    let direction = request.direction();
3176
3177    let updated_channels = session
3178        .db()
3179        .await
3180        .reorder_channel(channel_id, direction, session.user_id())
3181        .await?;
3182
3183    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3184        let connection_pool = session.connection_pool().await;
3185        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3186            let channels = updated_channels
3187                .iter()
3188                .filter_map(|channel| {
3189                    if role.can_see_channel(channel.visibility) {
3190                        Some(channel.to_proto())
3191                    } else {
3192                        None
3193                    }
3194                })
3195                .collect::<Vec<_>>();
3196
3197            if channels.is_empty() {
3198                continue;
3199            }
3200
3201            let update = proto::UpdateChannels {
3202                channels,
3203                ..Default::default()
3204            };
3205
3206            session.peer.send(connection_id, update.clone())?;
3207        }
3208    }
3209
3210    response.send(Ack {})?;
3211    Ok(())
3212}
3213
3214/// Get the list of channel members
3215async fn get_channel_members(
3216    request: proto::GetChannelMembers,
3217    response: Response<proto::GetChannelMembers>,
3218    session: MessageContext,
3219) -> Result<()> {
3220    let db = session.db().await;
3221    let channel_id = ChannelId::from_proto(request.channel_id);
3222    let limit = if request.limit == 0 {
3223        u16::MAX as u64
3224    } else {
3225        request.limit
3226    };
3227    let (members, users) = db
3228        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3229        .await?;
3230    response.send(proto::GetChannelMembersResponse { members, users })?;
3231    Ok(())
3232}
3233
3234/// Accept or decline a channel invitation.
3235async fn respond_to_channel_invite(
3236    request: proto::RespondToChannelInvite,
3237    response: Response<proto::RespondToChannelInvite>,
3238    session: MessageContext,
3239) -> Result<()> {
3240    let db = session.db().await;
3241    let channel_id = ChannelId::from_proto(request.channel_id);
3242    let RespondToChannelInvite {
3243        membership_update,
3244        notifications,
3245    } = db
3246        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3247        .await?;
3248
3249    let mut connection_pool = session.connection_pool().await;
3250    if let Some(membership_update) = membership_update {
3251        notify_membership_updated(
3252            &mut connection_pool,
3253            membership_update,
3254            session.user_id(),
3255            &session.peer,
3256        );
3257    } else {
3258        let update = proto::UpdateChannels {
3259            remove_channel_invitations: vec![channel_id.to_proto()],
3260            ..Default::default()
3261        };
3262
3263        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3264            session.peer.send(connection_id, update.clone())?;
3265        }
3266    };
3267
3268    send_notifications(&connection_pool, &session.peer, notifications);
3269
3270    response.send(proto::Ack {})?;
3271
3272    Ok(())
3273}
3274
3275/// Join the channels' room
3276async fn join_channel(
3277    request: proto::JoinChannel,
3278    response: Response<proto::JoinChannel>,
3279    session: MessageContext,
3280) -> Result<()> {
3281    let channel_id = ChannelId::from_proto(request.channel_id);
3282    join_channel_internal(channel_id, Box::new(response), session).await
3283}
3284
/// Response handle for requests that reply with a `proto::JoinRoomResponse`.
///
/// It is implemented for both the `JoinChannel` and `JoinRoom` response
/// handles, so `join_channel_internal` can take either one.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully-qualified `Response::<T>::send` path resolves to `Response`'s own
// (inherent) `send`. This keeps the impl from recursing into the trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same delegation as above, for the `JoinRoom` response handle.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3298
/// Joins the current session's connection to the room of `channel_id`.
///
/// Steps, in order:
/// 1. If the db reports a stale room connection for this user, that connection is first
///    removed from its room via `leave_room_for_session`.
/// 2. `db.join_channel` adds this connection to the channel's room and returns the room,
///    an optional membership update, and the user's role in the channel.
/// 3. A `JoinRoomResponse` is sent through `response`. It includes LiveKit connection info
///    when a LiveKit client is configured: guests get a guest token with
///    `can_publish: false`, every other role gets a room token with `can_publish: true`.
///    If token generation fails, the error is logged and the info is omitted.
/// 4. Membership changes (if any) are pushed to the user, the room participants receive
///    `RoomUpdated`, channel observers receive the new participant list, and the user's
///    contacts are refreshed.
///
/// # Errors
/// Propagates db errors, reply-send errors, and a "channel not returned" error if the joined
/// room has no associated channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the db handle itself, so release ours first
            // and reacquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when token creation fails (the
        // `trace_err()?` logs the error and short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish; all other roles
                    // receive a full room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before fanning out notifications to others. Note that the
        // `db` handle is still alive here; it is released at the end of this block.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // The pool guard stays alive until the end of this block.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The db handle and the pool guard from the block above have been dropped; the pool is
    // reacquired just for this broadcast.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3392
3393/// Start editing the channel notes
3394async fn join_channel_buffer(
3395    request: proto::JoinChannelBuffer,
3396    response: Response<proto::JoinChannelBuffer>,
3397    session: MessageContext,
3398) -> Result<()> {
3399    let db = session.db().await;
3400    let channel_id = ChannelId::from_proto(request.channel_id);
3401
3402    let open_response = db
3403        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3404        .await?;
3405
3406    let collaborators = open_response.collaborators.clone();
3407    response.send(open_response)?;
3408
3409    let update = UpdateChannelBufferCollaborators {
3410        channel_id: channel_id.to_proto(),
3411        collaborators: collaborators.clone(),
3412    };
3413    channel_buffer_updated(
3414        session.connection_id,
3415        collaborators
3416            .iter()
3417            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3418        &update,
3419        &session.peer,
3420    );
3421
3422    Ok(())
3423}
3424
3425/// Edit the channel notes
3426async fn update_channel_buffer(
3427    request: proto::UpdateChannelBuffer,
3428    session: MessageContext,
3429) -> Result<()> {
3430    let db = session.db().await;
3431    let channel_id = ChannelId::from_proto(request.channel_id);
3432
3433    let (collaborators, epoch, version) = db
3434        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3435        .await?;
3436
3437    channel_buffer_updated(
3438        session.connection_id,
3439        collaborators.clone(),
3440        &proto::UpdateChannelBuffer {
3441            channel_id: channel_id.to_proto(),
3442            operations: request.operations,
3443        },
3444        &session.peer,
3445    );
3446
3447    let pool = &*session.connection_pool().await;
3448
3449    let non_collaborators =
3450        pool.channel_connection_ids(channel_id)
3451            .filter_map(|(connection_id, _)| {
3452                if collaborators.contains(&connection_id) {
3453                    None
3454                } else {
3455                    Some(connection_id)
3456                }
3457            });
3458
3459    broadcast(None, non_collaborators, |peer_id| {
3460        session.peer.send(
3461            peer_id,
3462            proto::UpdateChannels {
3463                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3464                    channel_id: channel_id.to_proto(),
3465                    epoch: epoch as u64,
3466                    version: version.clone(),
3467                }],
3468                ..Default::default()
3469            },
3470        )
3471    });
3472
3473    Ok(())
3474}
3475
3476/// Rejoin the channel notes after a connection blip
3477async fn rejoin_channel_buffers(
3478    request: proto::RejoinChannelBuffers,
3479    response: Response<proto::RejoinChannelBuffers>,
3480    session: MessageContext,
3481) -> Result<()> {
3482    let db = session.db().await;
3483    let buffers = db
3484        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3485        .await?;
3486
3487    for rejoined_buffer in &buffers {
3488        let collaborators_to_notify = rejoined_buffer
3489            .buffer
3490            .collaborators
3491            .iter()
3492            .filter_map(|c| Some(c.peer_id?.into()));
3493        channel_buffer_updated(
3494            session.connection_id,
3495            collaborators_to_notify,
3496            &proto::UpdateChannelBufferCollaborators {
3497                channel_id: rejoined_buffer.buffer.channel_id,
3498                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3499            },
3500            &session.peer,
3501        );
3502    }
3503
3504    response.send(proto::RejoinChannelBuffersResponse {
3505        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3506    })?;
3507
3508    Ok(())
3509}
3510
3511/// Stop editing the channel notes
3512async fn leave_channel_buffer(
3513    request: proto::LeaveChannelBuffer,
3514    response: Response<proto::LeaveChannelBuffer>,
3515    session: MessageContext,
3516) -> Result<()> {
3517    let db = session.db().await;
3518    let channel_id = ChannelId::from_proto(request.channel_id);
3519
3520    let left_buffer = db
3521        .leave_channel_buffer(channel_id, session.connection_id)
3522        .await?;
3523
3524    response.send(Ack {})?;
3525
3526    channel_buffer_updated(
3527        session.connection_id,
3528        left_buffer.connections,
3529        &proto::UpdateChannelBufferCollaborators {
3530            channel_id: channel_id.to_proto(),
3531            collaborators: left_buffer.collaborators,
3532        },
3533        &session.peer,
3534    );
3535
3536    Ok(())
3537}
3538
3539fn channel_buffer_updated<T: EnvelopedMessage>(
3540    sender_id: ConnectionId,
3541    collaborators: impl IntoIterator<Item = ConnectionId>,
3542    message: &T,
3543    peer: &Peer,
3544) {
3545    broadcast(Some(sender_id), collaborators, |peer_id| {
3546        peer.send(peer_id, message.clone())
3547    });
3548}
3549
3550fn send_notifications(
3551    connection_pool: &ConnectionPool,
3552    peer: &Peer,
3553    notifications: db::NotificationBatch,
3554) {
3555    for (user_id, notification) in notifications {
3556        for connection_id in connection_pool.user_connection_ids(user_id) {
3557            if let Err(error) = peer.send(
3558                connection_id,
3559                proto::AddNotification {
3560                    notification: Some(notification.clone()),
3561                },
3562            ) {
3563                tracing::error!(
3564                    "failed to send notification to {:?} {}",
3565                    connection_id,
3566                    error
3567                );
3568            }
3569        }
3570    }
3571}
3572
3573/// Send a message to the channel
3574async fn send_channel_message(
3575    _request: proto::SendChannelMessage,
3576    _response: Response<proto::SendChannelMessage>,
3577    _session: MessageContext,
3578) -> Result<()> {
3579    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3580}
3581
3582/// Delete a channel message
3583async fn remove_channel_message(
3584    _request: proto::RemoveChannelMessage,
3585    _response: Response<proto::RemoveChannelMessage>,
3586    _session: MessageContext,
3587) -> Result<()> {
3588    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3589}
3590
3591async fn update_channel_message(
3592    _request: proto::UpdateChannelMessage,
3593    _response: Response<proto::UpdateChannelMessage>,
3594    _session: MessageContext,
3595) -> Result<()> {
3596    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3597}
3598
3599/// Mark a channel message as read
3600async fn acknowledge_channel_message(
3601    _request: proto::AckChannelMessage,
3602    _session: MessageContext,
3603) -> Result<()> {
3604    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3605}
3606
3607/// Mark a buffer version as synced
3608async fn acknowledge_buffer_version(
3609    request: proto::AckBufferOperation,
3610    session: MessageContext,
3611) -> Result<()> {
3612    let buffer_id = BufferId::from_proto(request.buffer_id);
3613    session
3614        .db()
3615        .await
3616        .observe_buffer_version(
3617            buffer_id,
3618            session.user_id(),
3619            request.epoch as i32,
3620            &request.version,
3621        )
3622        .await?;
3623    Ok(())
3624}
3625
3626/// Get a Supermaven API key for the user
3627async fn get_supermaven_api_key(
3628    _request: proto::GetSupermavenApiKey,
3629    response: Response<proto::GetSupermavenApiKey>,
3630    session: MessageContext,
3631) -> Result<()> {
3632    let user_id: String = session.user_id().to_string();
3633    if !session.is_staff() {
3634        return Err(anyhow!("supermaven not enabled for this account"))?;
3635    }
3636
3637    let email = session.email().context("user must have an email")?;
3638
3639    let supermaven_admin_api = session
3640        .supermaven_client
3641        .as_ref()
3642        .context("supermaven not configured")?;
3643
3644    let result = supermaven_admin_api
3645        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3646        .await?;
3647
3648    response.send(proto::GetSupermavenApiKeyResponse {
3649        api_key: result.api_key,
3650    })?;
3651
3652    Ok(())
3653}
3654
3655/// Start receiving chat updates for a channel
3656async fn join_channel_chat(
3657    _request: proto::JoinChannelChat,
3658    _response: Response<proto::JoinChannelChat>,
3659    _session: MessageContext,
3660) -> Result<()> {
3661    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3662}
3663
3664/// Stop receiving chat updates for a channel
3665async fn leave_channel_chat(
3666    _request: proto::LeaveChannelChat,
3667    _session: MessageContext,
3668) -> Result<()> {
3669    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3670}
3671
3672/// Retrieve the chat history for a channel
3673async fn get_channel_messages(
3674    _request: proto::GetChannelMessages,
3675    _response: Response<proto::GetChannelMessages>,
3676    _session: MessageContext,
3677) -> Result<()> {
3678    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3679}
3680
3681/// Retrieve specific chat messages
3682async fn get_channel_messages_by_id(
3683    _request: proto::GetChannelMessagesById,
3684    _response: Response<proto::GetChannelMessagesById>,
3685    _session: MessageContext,
3686) -> Result<()> {
3687    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3688}
3689
3690/// Retrieve the current users notifications
3691async fn get_notifications(
3692    request: proto::GetNotifications,
3693    response: Response<proto::GetNotifications>,
3694    session: MessageContext,
3695) -> Result<()> {
3696    let notifications = session
3697        .db()
3698        .await
3699        .get_notifications(
3700            session.user_id(),
3701            NOTIFICATION_COUNT_PER_PAGE,
3702            request.before_id.map(db::NotificationId::from_proto),
3703        )
3704        .await?;
3705    response.send(proto::GetNotificationsResponse {
3706        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3707        notifications,
3708    })?;
3709    Ok(())
3710}
3711
3712/// Mark notifications as read
3713async fn mark_notification_as_read(
3714    request: proto::MarkNotificationRead,
3715    response: Response<proto::MarkNotificationRead>,
3716    session: MessageContext,
3717) -> Result<()> {
3718    let database = &session.db().await;
3719    let notifications = database
3720        .mark_notification_as_read_by_id(
3721            session.user_id(),
3722            NotificationId::from_proto(request.notification_id),
3723        )
3724        .await?;
3725    send_notifications(
3726        &*session.connection_pool().await,
3727        &session.peer,
3728        notifications,
3729    );
3730    response.send(proto::Ack {})?;
3731    Ok(())
3732}
3733
3734fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3735    let message = match message {
3736        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3737        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3738        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3739        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3740        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3741            code: frame.code.into(),
3742            reason: frame.reason.as_str().to_owned().into(),
3743        })),
3744        // We should never receive a frame while reading the message, according
3745        // to the `tungstenite` maintainers:
3746        //
3747        // > It cannot occur when you read messages from the WebSocket, but it
3748        // > can be used when you want to send the raw frames (e.g. you want to
3749        // > send the frames to the WebSocket without composing the full message first).
3750        // >
3751        // > — https://github.com/snapview/tungstenite-rs/issues/268
3752        TungsteniteMessage::Frame(_) => {
3753            bail!("received an unexpected frame while reading the message")
3754        }
3755    };
3756
3757    Ok(message)
3758}
3759
3760fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3761    match message {
3762        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3763        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3764        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3765        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3766        AxumMessage::Close(frame) => {
3767            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3768                code: frame.code.into(),
3769                reason: frame.reason.as_ref().into(),
3770            }))
3771        }
3772    }
3773}
3774
3775fn notify_membership_updated(
3776    connection_pool: &mut ConnectionPool,
3777    result: MembershipUpdated,
3778    user_id: UserId,
3779    peer: &Peer,
3780) {
3781    for membership in &result.new_channels.channel_memberships {
3782        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3783    }
3784    for channel_id in &result.removed_channels {
3785        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3786    }
3787
3788    let user_channels_update = proto::UpdateUserChannels {
3789        channel_memberships: result
3790            .new_channels
3791            .channel_memberships
3792            .iter()
3793            .map(|cm| proto::ChannelMembership {
3794                channel_id: cm.channel_id.to_proto(),
3795                role: cm.role.into(),
3796            })
3797            .collect(),
3798        ..Default::default()
3799    };
3800
3801    let mut update = build_channels_update(result.new_channels);
3802    update.delete_channels = result
3803        .removed_channels
3804        .into_iter()
3805        .map(|id| id.to_proto())
3806        .collect();
3807    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3808
3809    for connection_id in connection_pool.user_connection_ids(user_id) {
3810        peer.send(connection_id, user_channels_update.clone())
3811            .trace_err();
3812        peer.send(connection_id, update.clone()).trace_err();
3813    }
3814}
3815
3816fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3817    proto::UpdateUserChannels {
3818        channel_memberships: channels
3819            .channel_memberships
3820            .iter()
3821            .map(|m| proto::ChannelMembership {
3822                channel_id: m.channel_id.to_proto(),
3823                role: m.role.into(),
3824            })
3825            .collect(),
3826        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3827    }
3828}
3829
3830fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3831    let mut update = proto::UpdateChannels::default();
3832
3833    for channel in channels.channels {
3834        update.channels.push(channel.to_proto());
3835    }
3836
3837    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3838
3839    for (channel_id, participants) in channels.channel_participants {
3840        update
3841            .channel_participants
3842            .push(proto::ChannelParticipants {
3843                channel_id: channel_id.to_proto(),
3844                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3845            });
3846    }
3847
3848    for channel in channels.invited_channels {
3849        update.channel_invitations.push(channel.to_proto());
3850    }
3851
3852    update
3853}
3854
3855fn build_initial_contacts_update(
3856    contacts: Vec<db::Contact>,
3857    pool: &ConnectionPool,
3858) -> proto::UpdateContacts {
3859    let mut update = proto::UpdateContacts::default();
3860
3861    for contact in contacts {
3862        match contact {
3863            db::Contact::Accepted { user_id, busy } => {
3864                update.contacts.push(contact_for_user(user_id, busy, pool));
3865            }
3866            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3867            db::Contact::Incoming { user_id } => {
3868                update
3869                    .incoming_requests
3870                    .push(proto::IncomingContactRequest {
3871                        requester_id: user_id.to_proto(),
3872                    })
3873            }
3874        }
3875    }
3876
3877    update
3878}
3879
3880fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3881    proto::Contact {
3882        user_id: user_id.to_proto(),
3883        online: pool.is_user_online(user_id),
3884        busy,
3885    }
3886}
3887
3888fn room_updated(room: &proto::Room, peer: &Peer) {
3889    broadcast(
3890        None,
3891        room.participants
3892            .iter()
3893            .filter_map(|participant| Some(participant.peer_id?.into())),
3894        |peer_id| {
3895            peer.send(
3896                peer_id,
3897                proto::RoomUpdated {
3898                    room: Some(room.clone()),
3899                },
3900            )
3901        },
3902    );
3903}
3904
3905fn channel_updated(
3906    channel: &db::channel::Model,
3907    room: &proto::Room,
3908    peer: &Peer,
3909    pool: &ConnectionPool,
3910) {
3911    let participants = room
3912        .participants
3913        .iter()
3914        .map(|p| p.user_id)
3915        .collect::<Vec<_>>();
3916
3917    broadcast(
3918        None,
3919        pool.channel_connection_ids(channel.root_id())
3920            .filter_map(|(channel_id, role)| {
3921                role.can_see_channel(channel.visibility)
3922                    .then_some(channel_id)
3923            }),
3924        |peer_id| {
3925            peer.send(
3926                peer_id,
3927                proto::UpdateChannels {
3928                    channel_participants: vec![proto::ChannelParticipants {
3929                        channel_id: channel.id.to_proto(),
3930                        participant_user_ids: participants.clone(),
3931                    }],
3932                    ..Default::default()
3933                },
3934            )
3935        },
3936    );
3937}
3938
3939async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3940    let db = session.db().await;
3941
3942    let contacts = db.get_contacts(user_id).await?;
3943    let busy = db.is_user_busy(user_id).await?;
3944
3945    let pool = session.connection_pool().await;
3946    let updated_contact = contact_for_user(user_id, busy, &pool);
3947    for contact in contacts {
3948        if let db::Contact::Accepted {
3949            user_id: contact_user_id,
3950            ..
3951        } = contact
3952        {
3953            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3954                session
3955                    .peer
3956                    .send(
3957                        contact_conn_id,
3958                        proto::UpdateContacts {
3959                            contacts: vec![updated_contact.clone()],
3960                            remove_contacts: Default::default(),
3961                            incoming_requests: Default::default(),
3962                            remove_incoming_requests: Default::default(),
3963                            outgoing_requests: Default::default(),
3964                            remove_outgoing_requests: Default::default(),
3965                        },
3966                    )
3967                    .trace_err();
3968            }
3969        }
3970    }
3971    Ok(())
3972}
3973
3974async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
3975    let mut contacts_to_update = HashSet::default();
3976
3977    let room_id;
3978    let canceled_calls_to_user_ids;
3979    let livekit_room;
3980    let delete_livekit_room;
3981    let room;
3982    let channel;
3983
3984    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
3985        contacts_to_update.insert(session.user_id());
3986
3987        for project in left_room.left_projects.values() {
3988            project_left(project, session);
3989        }
3990
3991        room_id = RoomId::from_proto(left_room.room.id);
3992        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3993        livekit_room = mem::take(&mut left_room.room.livekit_room);
3994        delete_livekit_room = left_room.deleted;
3995        room = mem::take(&mut left_room.room);
3996        channel = mem::take(&mut left_room.channel);
3997
3998        room_updated(&room, &session.peer);
3999    } else {
4000        return Ok(());
4001    }
4002
4003    if let Some(channel) = channel {
4004        channel_updated(
4005            &channel,
4006            &room,
4007            &session.peer,
4008            &*session.connection_pool().await,
4009        );
4010    }
4011
4012    {
4013        let pool = session.connection_pool().await;
4014        for canceled_user_id in canceled_calls_to_user_ids {
4015            for connection_id in pool.user_connection_ids(canceled_user_id) {
4016                session
4017                    .peer
4018                    .send(
4019                        connection_id,
4020                        proto::CallCanceled {
4021                            room_id: room_id.to_proto(),
4022                        },
4023                    )
4024                    .trace_err();
4025            }
4026            contacts_to_update.insert(canceled_user_id);
4027        }
4028    }
4029
4030    for contact_user_id in contacts_to_update {
4031        update_user_contacts(contact_user_id, session).await?;
4032    }
4033
4034    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4035        live_kit
4036            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4037            .await
4038            .trace_err();
4039
4040        if delete_livekit_room {
4041            live_kit.delete_room(livekit_room).await.trace_err();
4042        }
4043    }
4044
4045    Ok(())
4046}
4047
4048async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4049    let left_channel_buffers = session
4050        .db()
4051        .await
4052        .leave_channel_buffers(session.connection_id)
4053        .await?;
4054
4055    for left_buffer in left_channel_buffers {
4056        channel_buffer_updated(
4057            session.connection_id,
4058            left_buffer.connections,
4059            &proto::UpdateChannelBufferCollaborators {
4060                channel_id: left_buffer.channel_id.to_proto(),
4061                collaborators: left_buffer.collaborators,
4062            },
4063            &session.peer,
4064        );
4065    }
4066
4067    Ok(())
4068}
4069
4070fn project_left(project: &db::LeftProject, session: &Session) {
4071    for connection_id in &project.connection_ids {
4072        if project.should_unshare {
4073            session
4074                .peer
4075                .send(
4076                    *connection_id,
4077                    proto::UnshareProject {
4078                        project_id: project.id.to_proto(),
4079                    },
4080                )
4081                .trace_err();
4082        } else {
4083            session
4084                .peer
4085                .send(
4086                    *connection_id,
4087                    proto::RemoveProjectCollaborator {
4088                        project_id: project.id.to_proto(),
4089                        peer_id: Some(session.connection_id.into()),
4090                    },
4091                )
4092                .trace_err();
4093        }
4094    }
4095}
4096
4097pub trait ResultExt {
4098    type Ok;
4099
4100    fn trace_err(self) -> Option<Self::Ok>;
4101}
4102
4103impl<T, E> ResultExt for Result<T, E>
4104where
4105    E: std::fmt::Debug,
4106{
4107    type Ok = T;
4108
4109    #[track_caller]
4110    fn trace_err(self) -> Option<T> {
4111        match self {
4112            Ok(value) => Some(value),
4113            Err(error) => {
4114                tracing::error!("{:?}", error);
4115                None
4116            }
4117        }
4118    }
4119}