rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use tokio::sync::{Semaphore, watch};
  69use tower::ServiceBuilder;
  70use tracing::{
  71    Instrument,
  72    field::{self},
  73    info_span, instrument,
  74};
  75
  76pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  77
  78// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  79pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  80
  81const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  82const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  83
  84static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  85
  86const TOTAL_DURATION_MS: &str = "total_duration_ms";
  87const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  88const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  89const HOST_WAITING_MS: &str = "host_waiting_ms";
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  93
  94pub struct ConnectionGuard;
  95
  96impl ConnectionGuard {
  97    pub fn try_acquire() -> Result<Self, ()> {
  98        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  99        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 100            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 101            tracing::error!(
 102                "too many concurrent connections: {}",
 103                current_connections + 1
 104            );
 105            return Err(());
 106        }
 107        Ok(ConnectionGuard)
 108    }
 109}
 110
 111impl Drop for ConnectionGuard {
 112    fn drop(&mut self) {
 113        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114    }
 115}
 116
 117struct Response<R> {
 118    peer: Arc<Peer>,
 119    receipt: Receipt<R>,
 120    responded: Arc<AtomicBool>,
 121}
 122
 123impl<R: RequestMessage> Response<R> {
 124    fn send(self, payload: R::Response) -> Result<()> {
 125        self.responded.store(true, SeqCst);
 126        self.peer.respond(self.receipt, payload)?;
 127        Ok(())
 128    }
 129}
 130
 131#[derive(Clone, Debug)]
 132pub enum Principal {
 133    User(User),
 134    Impersonated { user: User, admin: User },
 135}
 136
 137impl Principal {
 138    fn update_span(&self, span: &tracing::Span) {
 139        match &self {
 140            Principal::User(user) => {
 141                span.record("user_id", user.id.0);
 142                span.record("login", &user.github_login);
 143            }
 144            Principal::Impersonated { user, admin } => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147                span.record("impersonator", &admin.github_login);
 148            }
 149        }
 150    }
 151}
 152
 153#[derive(Clone)]
 154struct MessageContext {
 155    session: Session,
 156    span: tracing::Span,
 157}
 158
 159impl Deref for MessageContext {
 160    type Target = Session;
 161
 162    fn deref(&self) -> &Self::Target {
 163        &self.session
 164    }
 165}
 166
 167impl MessageContext {
 168    pub fn forward_request<T: RequestMessage>(
 169        &self,
 170        receiver_id: ConnectionId,
 171        request: T,
 172    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 173        let request_start_time = Instant::now();
 174        let span = self.span.clone();
 175        tracing::info!("start forwarding request");
 176        self.peer
 177            .forward_request(self.connection_id, receiver_id, request)
 178            .inspect(move |_| {
 179                span.record(
 180                    HOST_WAITING_MS,
 181                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 182                );
 183            })
 184            .inspect_err(|_| tracing::error!("error forwarding request"))
 185            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 186    }
 187}
 188
/// Per-connection state handed to every message handler invoked for that connection.
///
/// The database handle, peer, connection pool and app state are behind `Arc`s, so clones of a
/// `Session` share them.
#[derive(Clone)]
struct Session {
    /// The identity this session acts as (a user, or a user impersonated by an admin).
    principal: Principal,
    /// This session's connection id, used as the sender when forwarding through `peer`.
    connection_id: ConnectionId,
    /// Tokio (async) mutex; `Session::db` awaits it, and its guard may be held across `.await`s.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide connection pool behind a synchronous `parking_lot` mutex.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` because this field is not read here; confirm it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // NOTE(review): presumably the client-reported system id (cf. `SystemIdHeader`) — confirm.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
 204
 205impl Session {
 206    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 207        #[cfg(test)]
 208        tokio::task::yield_now().await;
 209        let guard = self.db.lock().await;
 210        #[cfg(test)]
 211        tokio::task::yield_now().await;
 212        guard
 213    }
 214
 215    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 216        #[cfg(test)]
 217        tokio::task::yield_now().await;
 218        let guard = self.connection_pool.lock();
 219        ConnectionPoolGuard {
 220            guard,
 221            _not_send: PhantomData,
 222        }
 223    }
 224
 225    #[expect(dead_code)]
 226    fn is_staff(&self) -> bool {
 227        match &self.principal {
 228            Principal::User(user) => user.admin,
 229            Principal::Impersonated { .. } => true,
 230        }
 231    }
 232
 233    fn user_id(&self) -> UserId {
 234        match &self.principal {
 235            Principal::User(user) => user.id,
 236            Principal::Impersonated { user, .. } => user.id,
 237        }
 238    }
 239}
 240
 241impl Debug for Session {
 242    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 243        let mut result = f.debug_struct("Session");
 244        match &self.principal {
 245            Principal::User(user) => {
 246                result.field("user", &user.github_login);
 247            }
 248            Principal::Impersonated { user, admin } => {
 249                result.field("user", &user.github_login);
 250                result.field("impersonator", &admin.github_login);
 251            }
 252        }
 253        result.field("connection_id", &self.connection_id).finish()
 254    }
 255}
 256
 257struct DbHandle(Arc<Database>);
 258
 259impl Deref for DbHandle {
 260    type Target = Database;
 261
 262    fn deref(&self) -> &Self::Target {
 263        self.0.as_ref()
 264    }
 265}
 266
/// The collab RPC server: owns the peer, the connection pool and the message-handler table.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown()` and back to `false` by the test-only `reset()`.
    teardown: watch::Sender<bool>,
}
 275
/// Lock guard over the server's `ConnectionPool`, as returned by `Session::connection_pool`.
///
/// `_not_send` holds a `PhantomData<Rc<()>>`, which makes the guard `!Send`.
// NOTE(review): presumably this stops the synchronous pool lock from being held across an
// `.await` inside a handler future that must be `Send` — confirm.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 280
/// Serializable view of the server: its peer plus a locked connection pool.
///
/// The pool is serialized through its `Deref` target via `serialize_deref`. The pool lock is
/// held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 287
 288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 289where
 290    S: Serializer,
 291    T: Deref<Target = U>,
 292    U: Serialize,
 293{
 294    Serialize::serialize(value.deref(), serializer)
 295}
 296
 297impl Server {
 298    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 299        let mut server = Self {
 300            id: parking_lot::Mutex::new(id),
 301            peer: Peer::new(id.0 as u32),
 302            app_state,
 303            connection_pool: Default::default(),
 304            handlers: Default::default(),
 305            teardown: watch::channel(false).0,
 306        };
 307
 308        server
 309            .add_request_handler(ping)
 310            .add_request_handler(create_room)
 311            .add_request_handler(join_room)
 312            .add_request_handler(rejoin_room)
 313            .add_request_handler(leave_room)
 314            .add_request_handler(set_room_participant_role)
 315            .add_request_handler(call)
 316            .add_request_handler(cancel_call)
 317            .add_message_handler(decline_call)
 318            .add_request_handler(update_participant_location)
 319            .add_request_handler(share_project)
 320            .add_message_handler(unshare_project)
 321            .add_request_handler(join_project)
 322            .add_message_handler(leave_project)
 323            .add_request_handler(update_project)
 324            .add_request_handler(update_worktree)
 325            .add_request_handler(update_repository)
 326            .add_request_handler(remove_repository)
 327            .add_message_handler(start_language_server)
 328            .add_message_handler(update_language_server)
 329            .add_message_handler(update_diagnostic_summary)
 330            .add_message_handler(update_worktree_settings)
 331            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 332            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 333            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 334            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 335            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 336            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 337            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 338            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 340            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 341            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 346            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 347            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 348            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 349            .add_request_handler(
 350                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 351            )
 352            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 355            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 356            .add_request_handler(
 357                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 358            )
 359            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 360            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 361            .add_request_handler(
 362                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 363            )
 364            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 365            .add_request_handler(
 366                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 367            )
 368            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 369            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 370            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 371            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 372            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 373            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 374            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 375            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 376            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 377            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 378            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 379            .add_request_handler(
 380                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 381            )
 382            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 383            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 384            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 385            .add_request_handler(lsp_query)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 387            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 388            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 390            .add_message_handler(create_buffer_for_peer)
 391            .add_message_handler(create_image_for_peer)
 392            .add_request_handler(update_buffer)
 393            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 395            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 396            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 397            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 398            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 399            .add_message_handler(
 400                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 401            )
 402            .add_request_handler(get_users)
 403            .add_request_handler(fuzzy_search_users)
 404            .add_request_handler(request_contact)
 405            .add_request_handler(remove_contact)
 406            .add_request_handler(respond_to_contact_request)
 407            .add_message_handler(subscribe_to_channels)
 408            .add_request_handler(create_channel)
 409            .add_request_handler(delete_channel)
 410            .add_request_handler(invite_channel_member)
 411            .add_request_handler(remove_channel_member)
 412            .add_request_handler(set_channel_member_role)
 413            .add_request_handler(set_channel_visibility)
 414            .add_request_handler(rename_channel)
 415            .add_request_handler(join_channel_buffer)
 416            .add_request_handler(leave_channel_buffer)
 417            .add_message_handler(update_channel_buffer)
 418            .add_request_handler(rejoin_channel_buffers)
 419            .add_request_handler(get_channel_members)
 420            .add_request_handler(respond_to_channel_invite)
 421            .add_request_handler(join_channel)
 422            .add_request_handler(join_channel_chat)
 423            .add_message_handler(leave_channel_chat)
 424            .add_request_handler(send_channel_message)
 425            .add_request_handler(remove_channel_message)
 426            .add_request_handler(update_channel_message)
 427            .add_request_handler(get_channel_messages)
 428            .add_request_handler(get_channel_messages_by_id)
 429            .add_request_handler(get_notifications)
 430            .add_request_handler(mark_notification_as_read)
 431            .add_request_handler(move_channel)
 432            .add_request_handler(reorder_channel)
 433            .add_request_handler(follow)
 434            .add_message_handler(unfollow)
 435            .add_message_handler(update_followers)
 436            .add_message_handler(acknowledge_channel_message)
 437            .add_message_handler(acknowledge_buffer_version)
 438            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 439            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 440            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 441            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 442            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 443            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 444            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 445            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 446            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 447            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 449            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 450            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 451            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 452            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 453            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 454            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 455            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 456            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 457            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 458            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 459            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 460            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 461            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 463            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 465            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 466            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 467            .add_message_handler(update_context)
 468            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 469            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 470            .add_request_handler(share_agent_thread)
 471            .add_request_handler(get_shared_agent_thread)
 472            .add_request_handler(forward_project_search_chunk);
 473
 474        Arc::new(server)
 475    }
 476
    /// Schedules cleanup of resources left behind by earlier server instances.
    ///
    /// This only spawns a detached task and returns `Ok(())` immediately. The task sleeps for
    /// `CLEANUP_TIMEOUT` and then, scoped to `app_state.config.zed_environment` and this
    /// server's id, calls the database to:
    /// - delete stale channel-chat participants;
    /// - fetch stale room/channel ids, then clear stale channel-buffer collaborators and stale
    ///   room participants. Affected connections are notified (`UpdateChannelBufferCollaborators`,
    ///   room/channel updates, `CallCanceled`, `UpdateContacts`), and the LiveKit room is deleted
    ///   for any room left with no participants;
    /// - clear old worktree entries and delete stale servers.
    ///
    /// Failures are passed through `trace_err()` (which yields an `Option`) instead of being
    /// propagated, so one failed step does not abort the remaining cleanup.
    pub async fn start(&self) -> Result<()> {
        // Copy/clone everything the task needs so the spawned future does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Tell remaining buffer collaborators about the refreshed collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults apply if clearing the room failed: nothing to notify or delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the connection-pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their accepted
                        // contacts; users whose busy/contacts lookup failed are skipped.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and the room
                        // has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): repeats the call made at the top of the task; presumably to catch
                // entries that became stale while the cleanup above ran — confirm it is intended.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 647
 648    pub fn teardown(&self) {
 649        self.peer.teardown();
 650        self.connection_pool.lock().reset();
 651        let _ = self.teardown.send(true);
 652    }
 653
 654    #[cfg(test)]
 655    pub fn reset(&self, id: ServerId) {
 656        self.teardown();
 657        *self.id.lock() = id;
 658        self.peer.reset(id.0 as u32);
 659        let _ = self.teardown.send(false);
 660    }
 661
 662    #[cfg(test)]
 663    pub fn id(&self) -> ServerId {
 664        *self.id.lock()
 665    }
 666
 667    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 668    where
 669        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 670        Fut: 'static + Send + Future<Output = Result<()>>,
 671        M: EnvelopedMessage,
 672    {
 673        let prev_handler = self.handlers.insert(
 674            TypeId::of::<M>(),
 675            Box::new(move |envelope, session, span| {
 676                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 677                let received_at = envelope.received_at;
 678                tracing::info!("message received");
 679                let start_time = Instant::now();
 680                let future = (handler)(
 681                    *envelope,
 682                    MessageContext {
 683                        session,
 684                        span: span.clone(),
 685                    },
 686                );
 687                async move {
 688                    let result = future.await;
 689                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 690                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 691                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 692                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 693                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 694                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 695                    match result {
 696                        Err(error) => {
 697                            tracing::error!(?error, "error handling message")
 698                        }
 699                        Ok(()) => tracing::info!("finished handling message"),
 700                    }
 701                }
 702                .boxed()
 703            }),
 704        );
 705        if prev_handler.is_some() {
 706            panic!("registered a handler for the same message twice");
 707        }
 708        self
 709    }
 710
 711    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 712    where
 713        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 714        Fut: 'static + Send + Future<Output = Result<()>>,
 715        M: EnvelopedMessage,
 716    {
 717        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 718        self
 719    }
 720
 721    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 722    where
 723        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 724        Fut: Send + Future<Output = Result<()>>,
 725        M: RequestMessage,
 726    {
 727        let handler = Arc::new(handler);
 728        self.add_handler(move |envelope, session| {
 729            let receipt = envelope.receipt();
 730            let handler = handler.clone();
 731            async move {
 732                let peer = session.peer.clone();
 733                let responded = Arc::new(AtomicBool::default());
 734                let response = Response {
 735                    peer: peer.clone(),
 736                    responded: responded.clone(),
 737                    receipt,
 738                };
 739                match (handler)(envelope.payload, response, session).await {
 740                    Ok(()) => {
 741                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 742                            Ok(())
 743                        } else {
 744                            Err(anyhow!("handler did not send a response"))?
 745                        }
 746                    }
 747                    Err(error) => {
 748                        let proto_err = match &error {
 749                            Error::Internal(err) => err.to_proto(),
 750                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 751                        };
 752                        peer.respond_with_error(receipt, proto_err)?;
 753                        Err(error)
 754                    }
 755                }
 756            }
 757        })
 758    }
 759
    /// Returns a future that drives a single client connection to completion.
    ///
    /// When polled, the future does the following in order:
    /// 1. Registers `connection` with the peer.
    /// 2. Builds the connection's `Session`.
    /// 3. Sends the initial client update (see `send_initial_client_update`).
    /// 4. Dispatches incoming messages to the registered handlers.
    ///
    /// Dispatching continues until the incoming stream ends, the I/O future
    /// completes, or the teardown channel changes. The future then behaves as follows:
    /// - Stream end or I/O completion: cancels in-flight foreground handlers and
    ///   calls `connection_lost`.
    /// - Teardown: returns immediately, without `connection_lost`.
    ///
    /// The future also returns early in two cases: the server is already tearing
    /// down when it starts, or the initial client update fails.
    ///
    /// `send_connection_id`, when provided, receives the new connection's id
    /// right after the `Hello` message is sent. `connection_guard`, when
    /// provided, is held until the initial client update has succeeded.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` here are filled in below or once the connection id is known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed outside the returned future so it can check the current
        // teardown state as soon as it starts.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The handshake succeeded; the connection slot reserved by the caller
            // is no longer held past this point.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each dispatched message holds one permit until its handler finishes,
            // capping in-flight handlers (foreground and background) per connection.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Rebuilt every iteration. If another branch wins the select below,
                // this future (and any permit it already acquired) is dropped.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order, in four priorities:
                // 1. teardown;
                // 2. I/O completion;
                // 3. progress on in-flight foreground handlers;
                // 4. new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler is constructed; the
                            // resulting future is instrumented with `span` instead.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; all others
                                // are driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 932
 933    async fn send_initial_client_update(
 934        &self,
 935        connection_id: ConnectionId,
 936        zed_version: ZedVersion,
 937        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 938        session: &Session,
 939    ) -> Result<()> {
 940        self.peer.send(
 941            connection_id,
 942            proto::Hello {
 943                peer_id: Some(connection_id.into()),
 944            },
 945        )?;
 946        tracing::info!("sent hello message");
 947        if let Some(send_connection_id) = send_connection_id.take() {
 948            let _ = send_connection_id.send(connection_id);
 949        }
 950
 951        match &session.principal {
 952            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 953                if !user.connected_once {
 954                    self.peer.send(connection_id, proto::ShowContacts {})?;
 955                    self.app_state
 956                        .db
 957                        .set_user_connected_once(user.id, true)
 958                        .await?;
 959                }
 960
 961                let contacts = self.app_state.db.get_contacts(user.id).await?;
 962
 963                {
 964                    let mut pool = self.connection_pool.lock();
 965                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 966                    self.peer.send(
 967                        connection_id,
 968                        build_initial_contacts_update(contacts, &pool),
 969                    )?;
 970                }
 971
 972                if should_auto_subscribe_to_channels(&zed_version) {
 973                    subscribe_user_to_channels(user.id, session).await?;
 974                }
 975
 976                if let Some(incoming_call) =
 977                    self.app_state.db.incoming_call_for_user(user.id).await?
 978                {
 979                    self.peer.send(connection_id, incoming_call)?;
 980                }
 981
 982                update_user_contacts(user.id, session).await?;
 983            }
 984        }
 985
 986        Ok(())
 987    }
 988
 989    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
 990        ServerSnapshot {
 991            connection_pool: ConnectionPoolGuard {
 992                guard: self.connection_pool.lock(),
 993                _not_send: PhantomData,
 994            },
 995            peer: &self.peer,
 996        }
 997    }
 998}
 999
1000impl Deref for ConnectionPoolGuard<'_> {
1001    type Target = ConnectionPool;
1002
1003    fn deref(&self) -> &Self::Target {
1004        &self.guard
1005    }
1006}
1007
1008impl DerefMut for ConnectionPoolGuard<'_> {
1009    fn deref_mut(&mut self) -> &mut Self::Target {
1010        &mut self.guard
1011    }
1012}
1013
1014impl Drop for ConnectionPoolGuard<'_> {
1015    fn drop(&mut self) {
1016        #[cfg(test)]
1017        self.check_invariants();
1018    }
1019}
1020
1021fn broadcast<F>(
1022    sender_id: Option<ConnectionId>,
1023    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1024    mut f: F,
1025) where
1026    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1027{
1028    for receiver_id in receiver_ids {
1029        if Some(receiver_id) != sender_id
1030            && let Err(error) = f(receiver_id)
1031        {
1032            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1033        }
1034    }
1035}
1036
1037pub struct ProtocolVersion(u32);
1038
1039impl Header for ProtocolVersion {
1040    fn name() -> &'static HeaderName {
1041        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1042        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1043    }
1044
1045    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1046    where
1047        Self: Sized,
1048        I: Iterator<Item = &'i axum::http::HeaderValue>,
1049    {
1050        let version = values
1051            .next()
1052            .ok_or_else(axum::headers::Error::invalid)?
1053            .to_str()
1054            .map_err(|_| axum::headers::Error::invalid())?
1055            .parse()
1056            .map_err(|_| axum::headers::Error::invalid())?;
1057        Ok(Self(version))
1058    }
1059
1060    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1061        values.extend([self.0.to_string().parse().unwrap()]);
1062    }
1063}
1064
1065pub struct AppVersionHeader(Version);
1066impl Header for AppVersionHeader {
1067    fn name() -> &'static HeaderName {
1068        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1069        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1070    }
1071
1072    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1073    where
1074        Self: Sized,
1075        I: Iterator<Item = &'i axum::http::HeaderValue>,
1076    {
1077        let version = values
1078            .next()
1079            .ok_or_else(axum::headers::Error::invalid)?
1080            .to_str()
1081            .map_err(|_| axum::headers::Error::invalid())?
1082            .parse()
1083            .map_err(|_| axum::headers::Error::invalid())?;
1084        Ok(Self(version))
1085    }
1086
1087    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1088        values.extend([self.0.to_string().parse().unwrap()]);
1089    }
1090}
1091
1092#[derive(Debug)]
1093pub struct ReleaseChannelHeader(String);
1094
1095impl Header for ReleaseChannelHeader {
1096    fn name() -> &'static HeaderName {
1097        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1098        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1099    }
1100
1101    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1102    where
1103        Self: Sized,
1104        I: Iterator<Item = &'i axum::http::HeaderValue>,
1105    {
1106        Ok(Self(
1107            values
1108                .next()
1109                .ok_or_else(axum::headers::Error::invalid)?
1110                .to_str()
1111                .map_err(|_| axum::headers::Error::invalid())?
1112                .to_owned(),
1113        ))
1114    }
1115
1116    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1117        values.extend([self.0.parse().unwrap()]);
1118    }
1119}
1120
1121pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1122    Router::new()
1123        .route("/rpc", get(handle_websocket_request))
1124        .layer(
1125            ServiceBuilder::new()
1126                .layer(Extension(server.app_state.clone()))
1127                .layer(middleware::from_fn(auth::validate_header)),
1128        )
1129        .route("/metrics", get(handle_metrics))
1130        .layer(Extension(server))
1131}
1132
1133pub async fn handle_websocket_request(
1134    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1135    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1136    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1137    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1138    Extension(server): Extension<Arc<Server>>,
1139    Extension(principal): Extension<Principal>,
1140    user_agent: Option<TypedHeader<UserAgent>>,
1141    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1142    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1143    ws: WebSocketUpgrade,
1144) -> axum::response::Response {
1145    if protocol_version != rpc::PROTOCOL_VERSION {
1146        return (
1147            StatusCode::UPGRADE_REQUIRED,
1148            "client must be upgraded".to_string(),
1149        )
1150            .into_response();
1151    }
1152
1153    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1154        return (
1155            StatusCode::UPGRADE_REQUIRED,
1156            "no version header found".to_string(),
1157        )
1158            .into_response();
1159    };
1160
1161    let release_channel = release_channel_header.map(|header| header.0.0);
1162
1163    if !version.can_collaborate() {
1164        return (
1165            StatusCode::UPGRADE_REQUIRED,
1166            "client must be upgraded".to_string(),
1167        )
1168            .into_response();
1169    }
1170
1171    let socket_address = socket_address.to_string();
1172
1173    // Acquire connection guard before WebSocket upgrade
1174    let connection_guard = match ConnectionGuard::try_acquire() {
1175        Ok(guard) => guard,
1176        Err(()) => {
1177            return (
1178                StatusCode::SERVICE_UNAVAILABLE,
1179                "Too many concurrent connections",
1180            )
1181                .into_response();
1182        }
1183    };
1184
1185    ws.on_upgrade(move |socket| {
1186        let socket = socket
1187            .map_ok(to_tungstenite_message)
1188            .err_into()
1189            .with(|message| async move { to_axum_message(message) });
1190        let connection = Connection::new(Box::pin(socket));
1191        async move {
1192            server
1193                .handle_connection(
1194                    connection,
1195                    socket_address,
1196                    principal,
1197                    version,
1198                    release_channel,
1199                    user_agent.map(|header| header.to_string()),
1200                    country_code_header.map(|header| header.to_string()),
1201                    system_id_header.map(|header| header.to_string()),
1202                    None,
1203                    Executor::Production,
1204                    Some(connection_guard),
1205                )
1206                .await;
1207        }
1208    })
1209}
1210
1211pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1212    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1213    let connections_metric = CONNECTIONS_METRIC
1214        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1215
1216    let connections = server
1217        .connection_pool
1218        .lock()
1219        .connections()
1220        .filter(|connection| !connection.admin)
1221        .count();
1222    connections_metric.set(connections as _);
1223
1224    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1225    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1226        register_int_gauge!(
1227            "shared_projects",
1228            "number of open projects with one or more guests"
1229        )
1230        .unwrap()
1231    });
1232
1233    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1234    shared_projects_metric.set(shared_projects as _);
1235
1236    let encoder = prometheus::TextEncoder::new();
1237    let metric_families = prometheus::gather();
1238    let encoded_metrics = encoder
1239        .encode_to_string(&metric_families)
1240        .map_err(|err| anyhow!("{err}"))?;
1241    Ok(encoded_metrics)
1242}
1243
1244#[instrument(err, skip(executor))]
1245async fn connection_lost(
1246    session: Session,
1247    mut teardown: watch::Receiver<bool>,
1248    executor: Executor,
1249) -> Result<()> {
1250    session.peer.disconnect(session.connection_id);
1251    session
1252        .connection_pool()
1253        .await
1254        .remove_connection(session.connection_id)?;
1255
1256    session
1257        .db()
1258        .await
1259        .connection_lost(session.connection_id)
1260        .await
1261        .trace_err();
1262
1263    futures::select_biased! {
1264        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1265
1266            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1267            leave_room_for_session(&session, session.connection_id).await.trace_err();
1268            leave_channel_buffers_for_session(&session)
1269                .await
1270                .trace_err();
1271
1272            if !session
1273                .connection_pool()
1274                .await
1275                .is_user_online(session.user_id())
1276            {
1277                let db = session.db().await;
1278                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1279                    room_updated(&room, &session.peer);
1280                }
1281            }
1282
1283            update_user_contacts(session.user_id(), &session).await?;
1284        },
1285        _ = teardown.changed().fuse() => {}
1286    }
1287
1288    Ok(())
1289}
1290
1291/// Acknowledges a ping from a client, used to keep the connection alive.
1292async fn ping(
1293    _: proto::Ping,
1294    response: Response<proto::Ping>,
1295    _session: MessageContext,
1296) -> Result<()> {
1297    response.send(proto::Ack {})?;
1298    Ok(())
1299}
1300
1301/// Creates a new room for calling (outside of channels)
1302async fn create_room(
1303    _request: proto::CreateRoom,
1304    response: Response<proto::CreateRoom>,
1305    session: MessageContext,
1306) -> Result<()> {
1307    let livekit_room = nanoid::nanoid!(30);
1308
1309    let live_kit_connection_info = util::maybe!(async {
1310        let live_kit = session.app_state.livekit_client.as_ref();
1311        let live_kit = live_kit?;
1312        let user_id = session.user_id().to_string();
1313
1314        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1315
1316        Some(proto::LiveKitConnectionInfo {
1317            server_url: live_kit.url().into(),
1318            token,
1319            can_publish: true,
1320        })
1321    })
1322    .await;
1323
1324    let room = session
1325        .db()
1326        .await
1327        .create_room(session.user_id(), session.connection_id, &livekit_room)
1328        .await?;
1329
1330    response.send(proto::CreateRoomResponse {
1331        room: Some(room.clone()),
1332        live_kit_connection_info,
1333    })?;
1334
1335    update_user_contacts(session.user_id(), &session).await?;
1336    Ok(())
1337}
1338
1339/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1340async fn join_room(
1341    request: proto::JoinRoom,
1342    response: Response<proto::JoinRoom>,
1343    session: MessageContext,
1344) -> Result<()> {
1345    let room_id = RoomId::from_proto(request.id);
1346
1347    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1348
1349    if let Some(channel_id) = channel_id {
1350        return join_channel_internal(channel_id, Box::new(response), session).await;
1351    }
1352
1353    let joined_room = {
1354        let room = session
1355            .db()
1356            .await
1357            .join_room(room_id, session.user_id(), session.connection_id)
1358            .await?;
1359        room_updated(&room.room, &session.peer);
1360        room.into_inner()
1361    };
1362
1363    for connection_id in session
1364        .connection_pool()
1365        .await
1366        .user_connection_ids(session.user_id())
1367    {
1368        session
1369            .peer
1370            .send(
1371                connection_id,
1372                proto::CallCanceled {
1373                    room_id: room_id.to_proto(),
1374                },
1375            )
1376            .trace_err();
1377    }
1378
1379    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1380    {
1381        live_kit
1382            .room_token(
1383                &joined_room.room.livekit_room,
1384                &session.user_id().to_string(),
1385            )
1386            .trace_err()
1387            .map(|token| proto::LiveKitConnectionInfo {
1388                server_url: live_kit.url().into(),
1389                token,
1390                can_publish: true,
1391            })
1392    } else {
1393        None
1394    };
1395
1396    response.send(proto::JoinRoomResponse {
1397        room: Some(joined_room.room),
1398        channel_id: None,
1399        live_kit_connection_info,
1400    })?;
1401
1402    update_user_contacts(session.user_id(), &session).await?;
1403    Ok(())
1404}
1405
/// Reconnects the caller to its room after a connection error.
///
/// The steps are:
/// 1. Send `RejoinRoomResponse` with the room, the reshared projects (and
///    their collaborators), and the rejoined projects.
/// 2. Broadcast the room update.
/// 3. For each reshared project, tell collaborators about the caller's new
///    peer id and forward the project's worktrees.
/// 4. Run `notify_rejoined_projects`.
/// 5. If the room belongs to a channel, call `channel_updated`.
/// 6. Run `update_user_contacts`.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Declared outside the scope below so they outlive `rejoined_room`.
    let room;
    let channel;
    {
        // NOTE(review): `rejoined_room` is unwrapped via `into_inner` at the end of
        // this scope. It is presumably a database guard held while the notifications
        // below are sent; confirm before reordering.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For every project re-shared by this connection, do two things:
        // - tell each collaborator that the sharer's peer id moved from the old
        //   connection to this one;
        // - forward the project's worktrees to everyone except this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes `&mut` because it consumes the rejoined projects' worktrees while
        // streaming them.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1497
1498fn notify_rejoined_projects(
1499    rejoined_projects: &mut Vec<RejoinedProject>,
1500    session: &Session,
1501) -> Result<()> {
1502    for project in rejoined_projects.iter() {
1503        for collaborator in &project.collaborators {
1504            session
1505                .peer
1506                .send(
1507                    collaborator.connection_id,
1508                    proto::UpdateProjectCollaborator {
1509                        project_id: project.id.to_proto(),
1510                        old_peer_id: Some(project.old_connection_id.into()),
1511                        new_peer_id: Some(session.connection_id.into()),
1512                    },
1513                )
1514                .trace_err();
1515        }
1516    }
1517
1518    for project in rejoined_projects {
1519        for worktree in mem::take(&mut project.worktrees) {
1520            // Stream this worktree's entries.
1521            let message = proto::UpdateWorktree {
1522                project_id: project.id.to_proto(),
1523                worktree_id: worktree.id,
1524                abs_path: worktree.abs_path.clone(),
1525                root_name: worktree.root_name,
1526                updated_entries: worktree.updated_entries,
1527                removed_entries: worktree.removed_entries,
1528                scan_id: worktree.scan_id,
1529                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1530                updated_repositories: worktree.updated_repositories,
1531                removed_repositories: worktree.removed_repositories,
1532            };
1533            for update in proto::split_worktree_update(message) {
1534                session.peer.send(session.connection_id, update)?;
1535            }
1536
1537            // Stream this worktree's diagnostics.
1538            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1539            if let Some(summary) = worktree_diagnostics.next() {
1540                let message = proto::UpdateDiagnosticSummary {
1541                    project_id: project.id.to_proto(),
1542                    worktree_id: worktree.id,
1543                    summary: Some(summary),
1544                    more_summaries: worktree_diagnostics.collect(),
1545                };
1546                session.peer.send(session.connection_id, message)?;
1547            }
1548
1549            for settings_file in worktree.settings_files {
1550                session.peer.send(
1551                    session.connection_id,
1552                    proto::UpdateWorktreeSettings {
1553                        project_id: project.id.to_proto(),
1554                        worktree_id: worktree.id,
1555                        path: settings_file.path,
1556                        content: Some(settings_file.content),
1557                        kind: Some(settings_file.kind.to_proto().into()),
1558                    },
1559                )?;
1560            }
1561        }
1562
1563        for repository in mem::take(&mut project.updated_repositories) {
1564            for update in split_repository_update(repository) {
1565                session.peer.send(session.connection_id, update)?;
1566            }
1567        }
1568
1569        for id in mem::take(&mut project.removed_repositories) {
1570            session.peer.send(
1571                session.connection_id,
1572                proto::RemoveRepository {
1573                    project_id: project.id.to_proto(),
1574                    id,
1575                },
1576            )?;
1577        }
1578    }
1579
1580    Ok(())
1581}
1582
1583/// leave room disconnects from the room.
1584async fn leave_room(
1585    _: proto::LeaveRoom,
1586    response: Response<proto::LeaveRoom>,
1587    session: MessageContext,
1588) -> Result<()> {
1589    leave_room_for_session(&session, session.connection_id).await?;
1590    response.send(proto::Ack {})?;
1591    Ok(())
1592}
1593
1594/// Updates the permissions of someone else in the room.
1595async fn set_room_participant_role(
1596    request: proto::SetRoomParticipantRole,
1597    response: Response<proto::SetRoomParticipantRole>,
1598    session: MessageContext,
1599) -> Result<()> {
1600    let user_id = UserId::from_proto(request.user_id);
1601    let role = ChannelRole::from(request.role());
1602
1603    let (livekit_room, can_publish) = {
1604        let room = session
1605            .db()
1606            .await
1607            .set_room_participant_role(
1608                session.user_id(),
1609                RoomId::from_proto(request.room_id),
1610                user_id,
1611                role,
1612            )
1613            .await?;
1614
1615        let livekit_room = room.livekit_room.clone();
1616        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1617        room_updated(&room, &session.peer);
1618        (livekit_room, can_publish)
1619    };
1620
1621    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1622        live_kit
1623            .update_participant(
1624                livekit_room.clone(),
1625                request.user_id.to_string(),
1626                livekit_api::proto::ParticipantPermission {
1627                    can_subscribe: true,
1628                    can_publish,
1629                    can_publish_data: can_publish,
1630                    hidden: false,
1631                    recorder: false,
1632                },
1633            )
1634            .await
1635            .trace_err();
1636    }
1637
1638    response.send(proto::Ack {})?;
1639    Ok(())
1640}
1641
1642/// Call someone else into the current room
1643async fn call(
1644    request: proto::Call,
1645    response: Response<proto::Call>,
1646    session: MessageContext,
1647) -> Result<()> {
1648    let room_id = RoomId::from_proto(request.room_id);
1649    let calling_user_id = session.user_id();
1650    let calling_connection_id = session.connection_id;
1651    let called_user_id = UserId::from_proto(request.called_user_id);
1652    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1653    if !session
1654        .db()
1655        .await
1656        .has_contact(calling_user_id, called_user_id)
1657        .await?
1658    {
1659        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1660    }
1661
1662    let incoming_call = {
1663        let (room, incoming_call) = &mut *session
1664            .db()
1665            .await
1666            .call(
1667                room_id,
1668                calling_user_id,
1669                calling_connection_id,
1670                called_user_id,
1671                initial_project_id,
1672            )
1673            .await?;
1674        room_updated(room, &session.peer);
1675        mem::take(incoming_call)
1676    };
1677    update_user_contacts(called_user_id, &session).await?;
1678
1679    let mut calls = session
1680        .connection_pool()
1681        .await
1682        .user_connection_ids(called_user_id)
1683        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1684        .collect::<FuturesUnordered<_>>();
1685
1686    while let Some(call_response) = calls.next().await {
1687        match call_response.as_ref() {
1688            Ok(_) => {
1689                response.send(proto::Ack {})?;
1690                return Ok(());
1691            }
1692            Err(_) => {
1693                call_response.trace_err();
1694            }
1695        }
1696    }
1697
1698    {
1699        let room = session
1700            .db()
1701            .await
1702            .call_failed(room_id, called_user_id)
1703            .await?;
1704        room_updated(&room, &session.peer);
1705    }
1706    update_user_contacts(called_user_id, &session).await?;
1707
1708    Err(anyhow!("failed to ring user"))?
1709}
1710
1711/// Cancel an outgoing call.
1712async fn cancel_call(
1713    request: proto::CancelCall,
1714    response: Response<proto::CancelCall>,
1715    session: MessageContext,
1716) -> Result<()> {
1717    let called_user_id = UserId::from_proto(request.called_user_id);
1718    let room_id = RoomId::from_proto(request.room_id);
1719    {
1720        let room = session
1721            .db()
1722            .await
1723            .cancel_call(room_id, session.connection_id, called_user_id)
1724            .await?;
1725        room_updated(&room, &session.peer);
1726    }
1727
1728    for connection_id in session
1729        .connection_pool()
1730        .await
1731        .user_connection_ids(called_user_id)
1732    {
1733        session
1734            .peer
1735            .send(
1736                connection_id,
1737                proto::CallCanceled {
1738                    room_id: room_id.to_proto(),
1739                },
1740            )
1741            .trace_err();
1742    }
1743    response.send(proto::Ack {})?;
1744
1745    update_user_contacts(called_user_id, &session).await?;
1746    Ok(())
1747}
1748
1749/// Decline an incoming call.
1750async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1751    let room_id = RoomId::from_proto(message.room_id);
1752    {
1753        let room = session
1754            .db()
1755            .await
1756            .decline_call(Some(room_id), session.user_id())
1757            .await?
1758            .context("declining call")?;
1759        room_updated(&room, &session.peer);
1760    }
1761
1762    for connection_id in session
1763        .connection_pool()
1764        .await
1765        .user_connection_ids(session.user_id())
1766    {
1767        session
1768            .peer
1769            .send(
1770                connection_id,
1771                proto::CallCanceled {
1772                    room_id: room_id.to_proto(),
1773                },
1774            )
1775            .trace_err();
1776    }
1777    update_user_contacts(session.user_id(), &session).await?;
1778    Ok(())
1779}
1780
1781/// Updates other participants in the room with your current location.
1782async fn update_participant_location(
1783    request: proto::UpdateParticipantLocation,
1784    response: Response<proto::UpdateParticipantLocation>,
1785    session: MessageContext,
1786) -> Result<()> {
1787    let room_id = RoomId::from_proto(request.room_id);
1788    let location = request.location.context("invalid location")?;
1789
1790    let db = session.db().await;
1791    let room = db
1792        .update_room_participant_location(room_id, session.connection_id, location)
1793        .await?;
1794
1795    room_updated(&room, &session.peer);
1796    response.send(proto::Ack {})?;
1797    Ok(())
1798}
1799
1800/// Share a project into the room.
1801async fn share_project(
1802    request: proto::ShareProject,
1803    response: Response<proto::ShareProject>,
1804    session: MessageContext,
1805) -> Result<()> {
1806    let (project_id, room) = &*session
1807        .db()
1808        .await
1809        .share_project(
1810            RoomId::from_proto(request.room_id),
1811            session.connection_id,
1812            &request.worktrees,
1813            request.is_ssh_project,
1814            request.windows_paths.unwrap_or(false),
1815        )
1816        .await?;
1817    response.send(proto::ShareProjectResponse {
1818        project_id: project_id.to_proto(),
1819    })?;
1820    room_updated(room, &session.peer);
1821
1822    Ok(())
1823}
1824
1825/// Unshare a project from the room.
1826async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1827    let project_id = ProjectId::from_proto(message.project_id);
1828    unshare_project_internal(project_id, session.connection_id, &session).await
1829}
1830
1831async fn unshare_project_internal(
1832    project_id: ProjectId,
1833    connection_id: ConnectionId,
1834    session: &Session,
1835) -> Result<()> {
1836    let delete = {
1837        let room_guard = session
1838            .db()
1839            .await
1840            .unshare_project(project_id, connection_id)
1841            .await?;
1842
1843        let (delete, room, guest_connection_ids) = &*room_guard;
1844
1845        let message = proto::UnshareProject {
1846            project_id: project_id.to_proto(),
1847        };
1848
1849        broadcast(
1850            Some(connection_id),
1851            guest_connection_ids.iter().copied(),
1852            |conn_id| session.peer.send(conn_id, message.clone()),
1853        );
1854        if let Some(room) = room {
1855            room_updated(room, &session.peer);
1856        }
1857
1858        *delete
1859    };
1860
1861    if delete {
1862        let db = session.db().await;
1863        db.delete_project(project_id).await?;
1864    }
1865
1866    Ok(())
1867}
1868
1869/// Join someone elses shared project.
1870async fn join_project(
1871    request: proto::JoinProject,
1872    response: Response<proto::JoinProject>,
1873    session: MessageContext,
1874) -> Result<()> {
1875    let project_id = ProjectId::from_proto(request.project_id);
1876
1877    tracing::info!(%project_id, "join project");
1878
1879    let db = session.db().await;
1880    let (project, replica_id) = &mut *db
1881        .join_project(
1882            project_id,
1883            session.connection_id,
1884            session.user_id(),
1885            request.committer_name.clone(),
1886            request.committer_email.clone(),
1887        )
1888        .await?;
1889    drop(db);
1890    tracing::info!(%project_id, "join remote project");
1891    let collaborators = project
1892        .collaborators
1893        .iter()
1894        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1895        .map(|collaborator| collaborator.to_proto())
1896        .collect::<Vec<_>>();
1897    let project_id = project.id;
1898    let guest_user_id = session.user_id();
1899
1900    let worktrees = project
1901        .worktrees
1902        .iter()
1903        .map(|(id, worktree)| proto::WorktreeMetadata {
1904            id: *id,
1905            root_name: worktree.root_name.clone(),
1906            visible: worktree.visible,
1907            abs_path: worktree.abs_path.clone(),
1908        })
1909        .collect::<Vec<_>>();
1910
1911    let add_project_collaborator = proto::AddProjectCollaborator {
1912        project_id: project_id.to_proto(),
1913        collaborator: Some(proto::Collaborator {
1914            peer_id: Some(session.connection_id.into()),
1915            replica_id: replica_id.0 as u32,
1916            user_id: guest_user_id.to_proto(),
1917            is_host: false,
1918            committer_name: request.committer_name.clone(),
1919            committer_email: request.committer_email.clone(),
1920        }),
1921    };
1922
1923    for collaborator in &collaborators {
1924        session
1925            .peer
1926            .send(
1927                collaborator.peer_id.unwrap().into(),
1928                add_project_collaborator.clone(),
1929            )
1930            .trace_err();
1931    }
1932
1933    // First, we send the metadata associated with each worktree.
1934    let (language_servers, language_server_capabilities) = project
1935        .language_servers
1936        .clone()
1937        .into_iter()
1938        .map(|server| (server.server, server.capabilities))
1939        .unzip();
1940    response.send(proto::JoinProjectResponse {
1941        project_id: project.id.0 as u64,
1942        worktrees,
1943        replica_id: replica_id.0 as u32,
1944        collaborators,
1945        language_servers,
1946        language_server_capabilities,
1947        role: project.role.into(),
1948        windows_paths: project.path_style == PathStyle::Windows,
1949    })?;
1950
1951    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1952        // Stream this worktree's entries.
1953        let message = proto::UpdateWorktree {
1954            project_id: project_id.to_proto(),
1955            worktree_id,
1956            abs_path: worktree.abs_path.clone(),
1957            root_name: worktree.root_name,
1958            updated_entries: worktree.entries,
1959            removed_entries: Default::default(),
1960            scan_id: worktree.scan_id,
1961            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1962            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1963            removed_repositories: Default::default(),
1964        };
1965        for update in proto::split_worktree_update(message) {
1966            session.peer.send(session.connection_id, update.clone())?;
1967        }
1968
1969        // Stream this worktree's diagnostics.
1970        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1971        if let Some(summary) = worktree_diagnostics.next() {
1972            let message = proto::UpdateDiagnosticSummary {
1973                project_id: project.id.to_proto(),
1974                worktree_id: worktree.id,
1975                summary: Some(summary),
1976                more_summaries: worktree_diagnostics.collect(),
1977            };
1978            session.peer.send(session.connection_id, message)?;
1979        }
1980
1981        for settings_file in worktree.settings_files {
1982            session.peer.send(
1983                session.connection_id,
1984                proto::UpdateWorktreeSettings {
1985                    project_id: project_id.to_proto(),
1986                    worktree_id: worktree.id,
1987                    path: settings_file.path,
1988                    content: Some(settings_file.content),
1989                    kind: Some(settings_file.kind.to_proto() as i32),
1990                },
1991            )?;
1992        }
1993    }
1994
1995    for repository in mem::take(&mut project.repositories) {
1996        for update in split_repository_update(repository) {
1997            session.peer.send(session.connection_id, update)?;
1998        }
1999    }
2000
2001    for language_server in &project.language_servers {
2002        session.peer.send(
2003            session.connection_id,
2004            proto::UpdateLanguageServer {
2005                project_id: project_id.to_proto(),
2006                server_name: Some(language_server.server.name.clone()),
2007                language_server_id: language_server.server.id,
2008                variant: Some(
2009                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2010                        proto::LspDiskBasedDiagnosticsUpdated {},
2011                    ),
2012                ),
2013            },
2014        )?;
2015    }
2016
2017    Ok(())
2018}
2019
2020/// Leave someone elses shared project.
2021async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2022    let sender_id = session.connection_id;
2023    let project_id = ProjectId::from_proto(request.project_id);
2024    let db = session.db().await;
2025
2026    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2027    tracing::info!(
2028        %project_id,
2029        "leave project"
2030    );
2031
2032    project_left(project, &session);
2033    if let Some(room) = room {
2034        room_updated(room, &session.peer);
2035    }
2036
2037    Ok(())
2038}
2039
2040/// Updates other participants with changes to the project
2041async fn update_project(
2042    request: proto::UpdateProject,
2043    response: Response<proto::UpdateProject>,
2044    session: MessageContext,
2045) -> Result<()> {
2046    let project_id = ProjectId::from_proto(request.project_id);
2047    let (room, guest_connection_ids) = &*session
2048        .db()
2049        .await
2050        .update_project(project_id, session.connection_id, &request.worktrees)
2051        .await?;
2052    broadcast(
2053        Some(session.connection_id),
2054        guest_connection_ids.iter().copied(),
2055        |connection_id| {
2056            session
2057                .peer
2058                .forward_send(session.connection_id, connection_id, request.clone())
2059        },
2060    );
2061    if let Some(room) = room {
2062        room_updated(room, &session.peer);
2063    }
2064    response.send(proto::Ack {})?;
2065
2066    Ok(())
2067}
2068
2069/// Updates other participants with changes to the worktree
2070async fn update_worktree(
2071    request: proto::UpdateWorktree,
2072    response: Response<proto::UpdateWorktree>,
2073    session: MessageContext,
2074) -> Result<()> {
2075    let guest_connection_ids = session
2076        .db()
2077        .await
2078        .update_worktree(&request, session.connection_id)
2079        .await?;
2080
2081    broadcast(
2082        Some(session.connection_id),
2083        guest_connection_ids.iter().copied(),
2084        |connection_id| {
2085            session
2086                .peer
2087                .forward_send(session.connection_id, connection_id, request.clone())
2088        },
2089    );
2090    response.send(proto::Ack {})?;
2091    Ok(())
2092}
2093
2094async fn update_repository(
2095    request: proto::UpdateRepository,
2096    response: Response<proto::UpdateRepository>,
2097    session: MessageContext,
2098) -> Result<()> {
2099    let guest_connection_ids = session
2100        .db()
2101        .await
2102        .update_repository(&request, session.connection_id)
2103        .await?;
2104
2105    broadcast(
2106        Some(session.connection_id),
2107        guest_connection_ids.iter().copied(),
2108        |connection_id| {
2109            session
2110                .peer
2111                .forward_send(session.connection_id, connection_id, request.clone())
2112        },
2113    );
2114    response.send(proto::Ack {})?;
2115    Ok(())
2116}
2117
2118async fn remove_repository(
2119    request: proto::RemoveRepository,
2120    response: Response<proto::RemoveRepository>,
2121    session: MessageContext,
2122) -> Result<()> {
2123    let guest_connection_ids = session
2124        .db()
2125        .await
2126        .remove_repository(&request, session.connection_id)
2127        .await?;
2128
2129    broadcast(
2130        Some(session.connection_id),
2131        guest_connection_ids.iter().copied(),
2132        |connection_id| {
2133            session
2134                .peer
2135                .forward_send(session.connection_id, connection_id, request.clone())
2136        },
2137    );
2138    response.send(proto::Ack {})?;
2139    Ok(())
2140}
2141
2142/// Updates other participants with changes to the diagnostics
2143async fn update_diagnostic_summary(
2144    message: proto::UpdateDiagnosticSummary,
2145    session: MessageContext,
2146) -> Result<()> {
2147    let guest_connection_ids = session
2148        .db()
2149        .await
2150        .update_diagnostic_summary(&message, session.connection_id)
2151        .await?;
2152
2153    broadcast(
2154        Some(session.connection_id),
2155        guest_connection_ids.iter().copied(),
2156        |connection_id| {
2157            session
2158                .peer
2159                .forward_send(session.connection_id, connection_id, message.clone())
2160        },
2161    );
2162
2163    Ok(())
2164}
2165
2166/// Updates other participants with changes to the worktree settings
2167async fn update_worktree_settings(
2168    message: proto::UpdateWorktreeSettings,
2169    session: MessageContext,
2170) -> Result<()> {
2171    let guest_connection_ids = session
2172        .db()
2173        .await
2174        .update_worktree_settings(&message, session.connection_id)
2175        .await?;
2176
2177    broadcast(
2178        Some(session.connection_id),
2179        guest_connection_ids.iter().copied(),
2180        |connection_id| {
2181            session
2182                .peer
2183                .forward_send(session.connection_id, connection_id, message.clone())
2184        },
2185    );
2186
2187    Ok(())
2188}
2189
2190/// Notify other participants that a language server has started.
2191async fn start_language_server(
2192    request: proto::StartLanguageServer,
2193    session: MessageContext,
2194) -> Result<()> {
2195    let guest_connection_ids = session
2196        .db()
2197        .await
2198        .start_language_server(&request, session.connection_id)
2199        .await?;
2200
2201    broadcast(
2202        Some(session.connection_id),
2203        guest_connection_ids.iter().copied(),
2204        |connection_id| {
2205            session
2206                .peer
2207                .forward_send(session.connection_id, connection_id, request.clone())
2208        },
2209    );
2210    Ok(())
2211}
2212
2213/// Notify other participants that a language server has changed.
2214async fn update_language_server(
2215    request: proto::UpdateLanguageServer,
2216    session: MessageContext,
2217) -> Result<()> {
2218    let project_id = ProjectId::from_proto(request.project_id);
2219    let db = session.db().await;
2220
2221    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2222        && let Some(capabilities) = update.capabilities.clone()
2223    {
2224        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2225            .await?;
2226    }
2227
2228    let project_connection_ids = db
2229        .project_connection_ids(project_id, session.connection_id, true)
2230        .await?;
2231    broadcast(
2232        Some(session.connection_id),
2233        project_connection_ids.iter().copied(),
2234        |connection_id| {
2235            session
2236                .peer
2237                .forward_send(session.connection_id, connection_id, request.clone())
2238        },
2239    );
2240    Ok(())
2241}
2242
2243/// forward a project request to the host. These requests should be read only
2244/// as guests are allowed to send them.
2245async fn forward_read_only_project_request<T>(
2246    request: T,
2247    response: Response<T>,
2248    session: MessageContext,
2249) -> Result<()>
2250where
2251    T: EntityMessage + RequestMessage,
2252{
2253    let project_id = ProjectId::from_proto(request.remote_entity_id());
2254    let host_connection_id = session
2255        .db()
2256        .await
2257        .host_for_read_only_project_request(project_id, session.connection_id)
2258        .await?;
2259    let payload = session.forward_request(host_connection_id, request).await?;
2260    response.send(payload)?;
2261    Ok(())
2262}
2263
2264/// forward a project request to the host. These requests are disallowed
2265/// for guests.
2266async fn forward_mutating_project_request<T>(
2267    request: T,
2268    response: Response<T>,
2269    session: MessageContext,
2270) -> Result<()>
2271where
2272    T: EntityMessage + RequestMessage,
2273{
2274    let project_id = ProjectId::from_proto(request.remote_entity_id());
2275
2276    let host_connection_id = session
2277        .db()
2278        .await
2279        .host_for_mutating_project_request(project_id, session.connection_id)
2280        .await?;
2281    let payload = session.forward_request(host_connection_id, request).await?;
2282    response.send(payload)?;
2283    Ok(())
2284}
2285
2286async fn lsp_query(
2287    request: proto::LspQuery,
2288    response: Response<proto::LspQuery>,
2289    session: MessageContext,
2290) -> Result<()> {
2291    let (name, should_write) = request.query_name_and_write_permissions();
2292    tracing::Span::current().record("lsp_query_request", name);
2293    tracing::info!("lsp_query message received");
2294    if should_write {
2295        forward_mutating_project_request(request, response, session).await
2296    } else {
2297        forward_read_only_project_request(request, response, session).await
2298    }
2299}
2300
2301/// Notify other participants that a new buffer has been created
2302async fn create_buffer_for_peer(
2303    request: proto::CreateBufferForPeer,
2304    session: MessageContext,
2305) -> Result<()> {
2306    session
2307        .db()
2308        .await
2309        .check_user_is_project_host(
2310            ProjectId::from_proto(request.project_id),
2311            session.connection_id,
2312        )
2313        .await?;
2314    let peer_id = request.peer_id.context("invalid peer id")?;
2315    session
2316        .peer
2317        .forward_send(session.connection_id, peer_id.into(), request)?;
2318    Ok(())
2319}
2320
2321/// Notify other participants that a new image has been created
2322async fn create_image_for_peer(
2323    request: proto::CreateImageForPeer,
2324    session: MessageContext,
2325) -> Result<()> {
2326    session
2327        .db()
2328        .await
2329        .check_user_is_project_host(
2330            ProjectId::from_proto(request.project_id),
2331            session.connection_id,
2332        )
2333        .await?;
2334    let peer_id = request.peer_id.context("invalid peer id")?;
2335    session
2336        .peer
2337        .forward_send(session.connection_id, peer_id.into(), request)?;
2338    Ok(())
2339}
2340
2341/// Notify other participants that a buffer has been updated. This is
2342/// allowed for guests as long as the update is limited to selections.
2343async fn update_buffer(
2344    request: proto::UpdateBuffer,
2345    response: Response<proto::UpdateBuffer>,
2346    session: MessageContext,
2347) -> Result<()> {
2348    let project_id = ProjectId::from_proto(request.project_id);
2349    let mut capability = Capability::ReadOnly;
2350
2351    for op in request.operations.iter() {
2352        match op.variant {
2353            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2354            Some(_) => capability = Capability::ReadWrite,
2355        }
2356    }
2357
2358    let host = {
2359        let guard = session
2360            .db()
2361            .await
2362            .connections_for_buffer_update(project_id, session.connection_id, capability)
2363            .await?;
2364
2365        let (host, guests) = &*guard;
2366
2367        broadcast(
2368            Some(session.connection_id),
2369            guests.clone(),
2370            |connection_id| {
2371                session
2372                    .peer
2373                    .forward_send(session.connection_id, connection_id, request.clone())
2374            },
2375        );
2376
2377        *host
2378    };
2379
2380    if host != session.connection_id {
2381        session.forward_request(host, request.clone()).await?;
2382    }
2383
2384    response.send(proto::Ack {})?;
2385    Ok(())
2386}
2387
2388async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2389    let project_id = ProjectId::from_proto(message.project_id);
2390
2391    let operation = message.operation.as_ref().context("invalid operation")?;
2392    let capability = match operation.variant.as_ref() {
2393        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2394            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2395                match buffer_op.variant {
2396                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2397                        Capability::ReadOnly
2398                    }
2399                    _ => Capability::ReadWrite,
2400                }
2401            } else {
2402                Capability::ReadWrite
2403            }
2404        }
2405        Some(_) => Capability::ReadWrite,
2406        None => Capability::ReadOnly,
2407    };
2408
2409    let guard = session
2410        .db()
2411        .await
2412        .connections_for_buffer_update(project_id, session.connection_id, capability)
2413        .await?;
2414
2415    let (host, guests) = &*guard;
2416
2417    broadcast(
2418        Some(session.connection_id),
2419        guests.iter().chain([host]).copied(),
2420        |connection_id| {
2421            session
2422                .peer
2423                .forward_send(session.connection_id, connection_id, message.clone())
2424        },
2425    );
2426
2427    Ok(())
2428}
2429
2430async fn forward_project_search_chunk(
2431    message: proto::FindSearchCandidatesChunk,
2432    response: Response<proto::FindSearchCandidatesChunk>,
2433    session: MessageContext,
2434) -> Result<()> {
2435    let peer_id = message.peer_id.context("missing peer_id")?;
2436    let payload = session
2437        .peer
2438        .forward_request(session.connection_id, peer_id.into(), message)
2439        .await?;
2440    response.send(payload)?;
2441    Ok(())
2442}
2443
2444/// Notify other participants that a project has been updated.
2445async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2446    request: T,
2447    session: MessageContext,
2448) -> Result<()> {
2449    let project_id = ProjectId::from_proto(request.remote_entity_id());
2450    let project_connection_ids = session
2451        .db()
2452        .await
2453        .project_connection_ids(project_id, session.connection_id, false)
2454        .await?;
2455
2456    broadcast(
2457        Some(session.connection_id),
2458        project_connection_ids.iter().copied(),
2459        |connection_id| {
2460            session
2461                .peer
2462                .forward_send(session.connection_id, connection_id, request.clone())
2463        },
2464    );
2465    Ok(())
2466}
2467
2468/// Start following another user in a call.
2469async fn follow(
2470    request: proto::Follow,
2471    response: Response<proto::Follow>,
2472    session: MessageContext,
2473) -> Result<()> {
2474    let room_id = RoomId::from_proto(request.room_id);
2475    let project_id = request.project_id.map(ProjectId::from_proto);
2476    let leader_id = request.leader_id.context("invalid leader id")?.into();
2477    let follower_id = session.connection_id;
2478
2479    session
2480        .db()
2481        .await
2482        .check_room_participants(room_id, leader_id, session.connection_id)
2483        .await?;
2484
2485    let response_payload = session.forward_request(leader_id, request).await?;
2486    response.send(response_payload)?;
2487
2488    if let Some(project_id) = project_id {
2489        let room = session
2490            .db()
2491            .await
2492            .follow(room_id, project_id, leader_id, follower_id)
2493            .await?;
2494        room_updated(&room, &session.peer);
2495    }
2496
2497    Ok(())
2498}
2499
2500/// Stop following another user in a call.
2501async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2502    let room_id = RoomId::from_proto(request.room_id);
2503    let project_id = request.project_id.map(ProjectId::from_proto);
2504    let leader_id = request.leader_id.context("invalid leader id")?.into();
2505    let follower_id = session.connection_id;
2506
2507    session
2508        .db()
2509        .await
2510        .check_room_participants(room_id, leader_id, session.connection_id)
2511        .await?;
2512
2513    session
2514        .peer
2515        .forward_send(session.connection_id, leader_id, request)?;
2516
2517    if let Some(project_id) = project_id {
2518        let room = session
2519            .db()
2520            .await
2521            .unfollow(room_id, project_id, leader_id, follower_id)
2522            .await?;
2523        room_updated(&room, &session.peer);
2524    }
2525
2526    Ok(())
2527}
2528
2529/// Notify everyone following you of your current location.
2530async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2531    let room_id = RoomId::from_proto(request.room_id);
2532    let database = session.db.lock().await;
2533
2534    let connection_ids = if let Some(project_id) = request.project_id {
2535        let project_id = ProjectId::from_proto(project_id);
2536        database
2537            .project_connection_ids(project_id, session.connection_id, true)
2538            .await?
2539    } else {
2540        database
2541            .room_connection_ids(room_id, session.connection_id)
2542            .await?
2543    };
2544
2545    // For now, don't send view update messages back to that view's current leader.
2546    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2547        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2548        _ => None,
2549    });
2550
2551    for connection_id in connection_ids.iter().cloned() {
2552        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2553            session
2554                .peer
2555                .forward_send(session.connection_id, connection_id, request.clone())?;
2556        }
2557    }
2558    Ok(())
2559}
2560
2561/// Get public data about users.
2562async fn get_users(
2563    request: proto::GetUsers,
2564    response: Response<proto::GetUsers>,
2565    session: MessageContext,
2566) -> Result<()> {
2567    let user_ids = request
2568        .user_ids
2569        .into_iter()
2570        .map(UserId::from_proto)
2571        .collect();
2572    let users = session
2573        .db()
2574        .await
2575        .get_users_by_ids(user_ids)
2576        .await?
2577        .into_iter()
2578        .map(|user| proto::User {
2579            id: user.id.to_proto(),
2580            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2581            github_login: user.github_login,
2582            name: user.name,
2583        })
2584        .collect();
2585    response.send(proto::UsersResponse { users })?;
2586    Ok(())
2587}
2588
2589/// Search for users (to invite) buy Github login
2590async fn fuzzy_search_users(
2591    request: proto::FuzzySearchUsers,
2592    response: Response<proto::FuzzySearchUsers>,
2593    session: MessageContext,
2594) -> Result<()> {
2595    let query = request.query;
2596    let users = match query.len() {
2597        0 => vec![],
2598        1 | 2 => session
2599            .db()
2600            .await
2601            .get_user_by_github_login(&query)
2602            .await?
2603            .into_iter()
2604            .collect(),
2605        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2606    };
2607    let users = users
2608        .into_iter()
2609        .filter(|user| user.id != session.user_id())
2610        .map(|user| proto::User {
2611            id: user.id.to_proto(),
2612            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2613            github_login: user.github_login,
2614            name: user.name,
2615        })
2616        .collect();
2617    response.send(proto::UsersResponse { users })?;
2618    Ok(())
2619}
2620
2621/// Send a contact request to another user.
2622async fn request_contact(
2623    request: proto::RequestContact,
2624    response: Response<proto::RequestContact>,
2625    session: MessageContext,
2626) -> Result<()> {
2627    let requester_id = session.user_id();
2628    let responder_id = UserId::from_proto(request.responder_id);
2629    if requester_id == responder_id {
2630        return Err(anyhow!("cannot add yourself as a contact"))?;
2631    }
2632
2633    let notifications = session
2634        .db()
2635        .await
2636        .send_contact_request(requester_id, responder_id)
2637        .await?;
2638
2639    // Update outgoing contact requests of requester
2640    let mut update = proto::UpdateContacts::default();
2641    update.outgoing_requests.push(responder_id.to_proto());
2642    for connection_id in session
2643        .connection_pool()
2644        .await
2645        .user_connection_ids(requester_id)
2646    {
2647        session.peer.send(connection_id, update.clone())?;
2648    }
2649
2650    // Update incoming contact requests of responder
2651    let mut update = proto::UpdateContacts::default();
2652    update
2653        .incoming_requests
2654        .push(proto::IncomingContactRequest {
2655            requester_id: requester_id.to_proto(),
2656        });
2657    let connection_pool = session.connection_pool().await;
2658    for connection_id in connection_pool.user_connection_ids(responder_id) {
2659        session.peer.send(connection_id, update.clone())?;
2660    }
2661
2662    send_notifications(&connection_pool, &session.peer, notifications);
2663
2664    response.send(proto::Ack {})?;
2665    Ok(())
2666}
2667
2668/// Accept or decline a contact request
2669async fn respond_to_contact_request(
2670    request: proto::RespondToContactRequest,
2671    response: Response<proto::RespondToContactRequest>,
2672    session: MessageContext,
2673) -> Result<()> {
2674    let responder_id = session.user_id();
2675    let requester_id = UserId::from_proto(request.requester_id);
2676    let db = session.db().await;
2677    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2678        db.dismiss_contact_notification(responder_id, requester_id)
2679            .await?;
2680    } else {
2681        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2682
2683        let notifications = db
2684            .respond_to_contact_request(responder_id, requester_id, accept)
2685            .await?;
2686        let requester_busy = db.is_user_busy(requester_id).await?;
2687        let responder_busy = db.is_user_busy(responder_id).await?;
2688
2689        let pool = session.connection_pool().await;
2690        // Update responder with new contact
2691        let mut update = proto::UpdateContacts::default();
2692        if accept {
2693            update
2694                .contacts
2695                .push(contact_for_user(requester_id, requester_busy, &pool));
2696        }
2697        update
2698            .remove_incoming_requests
2699            .push(requester_id.to_proto());
2700        for connection_id in pool.user_connection_ids(responder_id) {
2701            session.peer.send(connection_id, update.clone())?;
2702        }
2703
2704        // Update requester with new contact
2705        let mut update = proto::UpdateContacts::default();
2706        if accept {
2707            update
2708                .contacts
2709                .push(contact_for_user(responder_id, responder_busy, &pool));
2710        }
2711        update
2712            .remove_outgoing_requests
2713            .push(responder_id.to_proto());
2714
2715        for connection_id in pool.user_connection_ids(requester_id) {
2716            session.peer.send(connection_id, update.clone())?;
2717        }
2718
2719        send_notifications(&pool, &session.peer, notifications);
2720    }
2721
2722    response.send(proto::Ack {})?;
2723    Ok(())
2724}
2725
2726/// Remove a contact.
2727async fn remove_contact(
2728    request: proto::RemoveContact,
2729    response: Response<proto::RemoveContact>,
2730    session: MessageContext,
2731) -> Result<()> {
2732    let requester_id = session.user_id();
2733    let responder_id = UserId::from_proto(request.user_id);
2734    let db = session.db().await;
2735    let (contact_accepted, deleted_notification_id) =
2736        db.remove_contact(requester_id, responder_id).await?;
2737
2738    let pool = session.connection_pool().await;
2739    // Update outgoing contact requests of requester
2740    let mut update = proto::UpdateContacts::default();
2741    if contact_accepted {
2742        update.remove_contacts.push(responder_id.to_proto());
2743    } else {
2744        update
2745            .remove_outgoing_requests
2746            .push(responder_id.to_proto());
2747    }
2748    for connection_id in pool.user_connection_ids(requester_id) {
2749        session.peer.send(connection_id, update.clone())?;
2750    }
2751
2752    // Update incoming contact requests of responder
2753    let mut update = proto::UpdateContacts::default();
2754    if contact_accepted {
2755        update.remove_contacts.push(requester_id.to_proto());
2756    } else {
2757        update
2758            .remove_incoming_requests
2759            .push(requester_id.to_proto());
2760    }
2761    for connection_id in pool.user_connection_ids(responder_id) {
2762        session.peer.send(connection_id, update.clone())?;
2763        if let Some(notification_id) = deleted_notification_id {
2764            session.peer.send(
2765                connection_id,
2766                proto::DeleteNotification {
2767                    notification_id: notification_id.to_proto(),
2768                },
2769            )?;
2770        }
2771    }
2772
2773    response.send(proto::Ack {})?;
2774    Ok(())
2775}
2776
2777fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2778    version.0.minor < 139
2779}
2780
2781async fn subscribe_to_channels(
2782    _: proto::SubscribeToChannels,
2783    session: MessageContext,
2784) -> Result<()> {
2785    subscribe_user_to_channels(session.user_id(), &session).await?;
2786    Ok(())
2787}
2788
2789async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2790    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2791    let mut pool = session.connection_pool().await;
2792    for membership in &channels_for_user.channel_memberships {
2793        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2794    }
2795    session.peer.send(
2796        session.connection_id,
2797        build_update_user_channels(&channels_for_user),
2798    )?;
2799    session.peer.send(
2800        session.connection_id,
2801        build_channels_update(channels_for_user),
2802    )?;
2803    Ok(())
2804}
2805
2806/// Creates a new channel.
2807async fn create_channel(
2808    request: proto::CreateChannel,
2809    response: Response<proto::CreateChannel>,
2810    session: MessageContext,
2811) -> Result<()> {
2812    let db = session.db().await;
2813
2814    let parent_id = request.parent_id.map(ChannelId::from_proto);
2815    let (channel, membership) = db
2816        .create_channel(&request.name, parent_id, session.user_id())
2817        .await?;
2818
2819    let root_id = channel.root_id();
2820    let channel = Channel::from_model(channel);
2821
2822    response.send(proto::CreateChannelResponse {
2823        channel: Some(channel.to_proto()),
2824        parent_id: request.parent_id,
2825    })?;
2826
2827    let mut connection_pool = session.connection_pool().await;
2828    if let Some(membership) = membership {
2829        connection_pool.subscribe_to_channel(
2830            membership.user_id,
2831            membership.channel_id,
2832            membership.role,
2833        );
2834        let update = proto::UpdateUserChannels {
2835            channel_memberships: vec![proto::ChannelMembership {
2836                channel_id: membership.channel_id.to_proto(),
2837                role: membership.role.into(),
2838            }],
2839            ..Default::default()
2840        };
2841        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2842            session.peer.send(connection_id, update.clone())?;
2843        }
2844    }
2845
2846    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2847        if !role.can_see_channel(channel.visibility) {
2848            continue;
2849        }
2850
2851        let update = proto::UpdateChannels {
2852            channels: vec![channel.to_proto()],
2853            ..Default::default()
2854        };
2855        session.peer.send(connection_id, update.clone())?;
2856    }
2857
2858    Ok(())
2859}
2860
2861/// Delete a channel
2862async fn delete_channel(
2863    request: proto::DeleteChannel,
2864    response: Response<proto::DeleteChannel>,
2865    session: MessageContext,
2866) -> Result<()> {
2867    let db = session.db().await;
2868
2869    let channel_id = request.channel_id;
2870    let (root_channel, removed_channels) = db
2871        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2872        .await?;
2873    response.send(proto::Ack {})?;
2874
2875    // Notify members of removed channels
2876    let mut update = proto::UpdateChannels::default();
2877    update
2878        .delete_channels
2879        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2880
2881    let connection_pool = session.connection_pool().await;
2882    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2883        session.peer.send(connection_id, update.clone())?;
2884    }
2885
2886    Ok(())
2887}
2888
2889/// Invite someone to join a channel.
2890async fn invite_channel_member(
2891    request: proto::InviteChannelMember,
2892    response: Response<proto::InviteChannelMember>,
2893    session: MessageContext,
2894) -> Result<()> {
2895    let db = session.db().await;
2896    let channel_id = ChannelId::from_proto(request.channel_id);
2897    let invitee_id = UserId::from_proto(request.user_id);
2898    let InviteMemberResult {
2899        channel,
2900        notifications,
2901    } = db
2902        .invite_channel_member(
2903            channel_id,
2904            invitee_id,
2905            session.user_id(),
2906            request.role().into(),
2907        )
2908        .await?;
2909
2910    let update = proto::UpdateChannels {
2911        channel_invitations: vec![channel.to_proto()],
2912        ..Default::default()
2913    };
2914
2915    let connection_pool = session.connection_pool().await;
2916    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2917        session.peer.send(connection_id, update.clone())?;
2918    }
2919
2920    send_notifications(&connection_pool, &session.peer, notifications);
2921
2922    response.send(proto::Ack {})?;
2923    Ok(())
2924}
2925
2926/// remove someone from a channel
2927async fn remove_channel_member(
2928    request: proto::RemoveChannelMember,
2929    response: Response<proto::RemoveChannelMember>,
2930    session: MessageContext,
2931) -> Result<()> {
2932    let db = session.db().await;
2933    let channel_id = ChannelId::from_proto(request.channel_id);
2934    let member_id = UserId::from_proto(request.user_id);
2935
2936    let RemoveChannelMemberResult {
2937        membership_update,
2938        notification_id,
2939    } = db
2940        .remove_channel_member(channel_id, member_id, session.user_id())
2941        .await?;
2942
2943    let mut connection_pool = session.connection_pool().await;
2944    notify_membership_updated(
2945        &mut connection_pool,
2946        membership_update,
2947        member_id,
2948        &session.peer,
2949    );
2950    for connection_id in connection_pool.user_connection_ids(member_id) {
2951        if let Some(notification_id) = notification_id {
2952            session
2953                .peer
2954                .send(
2955                    connection_id,
2956                    proto::DeleteNotification {
2957                        notification_id: notification_id.to_proto(),
2958                    },
2959                )
2960                .trace_err();
2961        }
2962    }
2963
2964    response.send(proto::Ack {})?;
2965    Ok(())
2966}
2967
2968/// Toggle the channel between public and private.
2969/// Care is taken to maintain the invariant that public channels only descend from public channels,
2970/// (though members-only channels can appear at any point in the hierarchy).
2971async fn set_channel_visibility(
2972    request: proto::SetChannelVisibility,
2973    response: Response<proto::SetChannelVisibility>,
2974    session: MessageContext,
2975) -> Result<()> {
2976    let db = session.db().await;
2977    let channel_id = ChannelId::from_proto(request.channel_id);
2978    let visibility = request.visibility().into();
2979
2980    let channel_model = db
2981        .set_channel_visibility(channel_id, visibility, session.user_id())
2982        .await?;
2983    let root_id = channel_model.root_id();
2984    let channel = Channel::from_model(channel_model);
2985
2986    let mut connection_pool = session.connection_pool().await;
2987    for (user_id, role) in connection_pool
2988        .channel_user_ids(root_id)
2989        .collect::<Vec<_>>()
2990        .into_iter()
2991    {
2992        let update = if role.can_see_channel(channel.visibility) {
2993            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2994            proto::UpdateChannels {
2995                channels: vec![channel.to_proto()],
2996                ..Default::default()
2997            }
2998        } else {
2999            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3000            proto::UpdateChannels {
3001                delete_channels: vec![channel.id.to_proto()],
3002                ..Default::default()
3003            }
3004        };
3005
3006        for connection_id in connection_pool.user_connection_ids(user_id) {
3007            session.peer.send(connection_id, update.clone())?;
3008        }
3009    }
3010
3011    response.send(proto::Ack {})?;
3012    Ok(())
3013}
3014
3015/// Alter the role for a user in the channel.
3016async fn set_channel_member_role(
3017    request: proto::SetChannelMemberRole,
3018    response: Response<proto::SetChannelMemberRole>,
3019    session: MessageContext,
3020) -> Result<()> {
3021    let db = session.db().await;
3022    let channel_id = ChannelId::from_proto(request.channel_id);
3023    let member_id = UserId::from_proto(request.user_id);
3024    let result = db
3025        .set_channel_member_role(
3026            channel_id,
3027            session.user_id(),
3028            member_id,
3029            request.role().into(),
3030        )
3031        .await?;
3032
3033    match result {
3034        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3035            let mut connection_pool = session.connection_pool().await;
3036            notify_membership_updated(
3037                &mut connection_pool,
3038                membership_update,
3039                member_id,
3040                &session.peer,
3041            )
3042        }
3043        db::SetMemberRoleResult::InviteUpdated(channel) => {
3044            let update = proto::UpdateChannels {
3045                channel_invitations: vec![channel.to_proto()],
3046                ..Default::default()
3047            };
3048
3049            for connection_id in session
3050                .connection_pool()
3051                .await
3052                .user_connection_ids(member_id)
3053            {
3054                session.peer.send(connection_id, update.clone())?;
3055            }
3056        }
3057    }
3058
3059    response.send(proto::Ack {})?;
3060    Ok(())
3061}
3062
3063/// Change the name of a channel
3064async fn rename_channel(
3065    request: proto::RenameChannel,
3066    response: Response<proto::RenameChannel>,
3067    session: MessageContext,
3068) -> Result<()> {
3069    let db = session.db().await;
3070    let channel_id = ChannelId::from_proto(request.channel_id);
3071    let channel_model = db
3072        .rename_channel(channel_id, session.user_id(), &request.name)
3073        .await?;
3074    let root_id = channel_model.root_id();
3075    let channel = Channel::from_model(channel_model);
3076
3077    response.send(proto::RenameChannelResponse {
3078        channel: Some(channel.to_proto()),
3079    })?;
3080
3081    let connection_pool = session.connection_pool().await;
3082    let update = proto::UpdateChannels {
3083        channels: vec![channel.to_proto()],
3084        ..Default::default()
3085    };
3086    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3087        if role.can_see_channel(channel.visibility) {
3088            session.peer.send(connection_id, update.clone())?;
3089        }
3090    }
3091
3092    Ok(())
3093}
3094
3095/// Move a channel to a new parent.
3096async fn move_channel(
3097    request: proto::MoveChannel,
3098    response: Response<proto::MoveChannel>,
3099    session: MessageContext,
3100) -> Result<()> {
3101    let channel_id = ChannelId::from_proto(request.channel_id);
3102    let to = ChannelId::from_proto(request.to);
3103
3104    let (root_id, channels) = session
3105        .db()
3106        .await
3107        .move_channel(channel_id, to, session.user_id())
3108        .await?;
3109
3110    let connection_pool = session.connection_pool().await;
3111    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3112        let channels = channels
3113            .iter()
3114            .filter_map(|channel| {
3115                if role.can_see_channel(channel.visibility) {
3116                    Some(channel.to_proto())
3117                } else {
3118                    None
3119                }
3120            })
3121            .collect::<Vec<_>>();
3122        if channels.is_empty() {
3123            continue;
3124        }
3125
3126        let update = proto::UpdateChannels {
3127            channels,
3128            ..Default::default()
3129        };
3130
3131        session.peer.send(connection_id, update.clone())?;
3132    }
3133
3134    response.send(Ack {})?;
3135    Ok(())
3136}
3137
3138async fn reorder_channel(
3139    request: proto::ReorderChannel,
3140    response: Response<proto::ReorderChannel>,
3141    session: MessageContext,
3142) -> Result<()> {
3143    let channel_id = ChannelId::from_proto(request.channel_id);
3144    let direction = request.direction();
3145
3146    let updated_channels = session
3147        .db()
3148        .await
3149        .reorder_channel(channel_id, direction, session.user_id())
3150        .await?;
3151
3152    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3153        let connection_pool = session.connection_pool().await;
3154        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3155            let channels = updated_channels
3156                .iter()
3157                .filter_map(|channel| {
3158                    if role.can_see_channel(channel.visibility) {
3159                        Some(channel.to_proto())
3160                    } else {
3161                        None
3162                    }
3163                })
3164                .collect::<Vec<_>>();
3165
3166            if channels.is_empty() {
3167                continue;
3168            }
3169
3170            let update = proto::UpdateChannels {
3171                channels,
3172                ..Default::default()
3173            };
3174
3175            session.peer.send(connection_id, update.clone())?;
3176        }
3177    }
3178
3179    response.send(Ack {})?;
3180    Ok(())
3181}
3182
3183/// Get the list of channel members
3184async fn get_channel_members(
3185    request: proto::GetChannelMembers,
3186    response: Response<proto::GetChannelMembers>,
3187    session: MessageContext,
3188) -> Result<()> {
3189    let db = session.db().await;
3190    let channel_id = ChannelId::from_proto(request.channel_id);
3191    let limit = if request.limit == 0 {
3192        u16::MAX as u64
3193    } else {
3194        request.limit
3195    };
3196    let (members, users) = db
3197        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3198        .await?;
3199    response.send(proto::GetChannelMembersResponse { members, users })?;
3200    Ok(())
3201}
3202
3203/// Accept or decline a channel invitation.
3204async fn respond_to_channel_invite(
3205    request: proto::RespondToChannelInvite,
3206    response: Response<proto::RespondToChannelInvite>,
3207    session: MessageContext,
3208) -> Result<()> {
3209    let db = session.db().await;
3210    let channel_id = ChannelId::from_proto(request.channel_id);
3211    let RespondToChannelInvite {
3212        membership_update,
3213        notifications,
3214    } = db
3215        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3216        .await?;
3217
3218    let mut connection_pool = session.connection_pool().await;
3219    if let Some(membership_update) = membership_update {
3220        notify_membership_updated(
3221            &mut connection_pool,
3222            membership_update,
3223            session.user_id(),
3224            &session.peer,
3225        );
3226    } else {
3227        let update = proto::UpdateChannels {
3228            remove_channel_invitations: vec![channel_id.to_proto()],
3229            ..Default::default()
3230        };
3231
3232        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3233            session.peer.send(connection_id, update.clone())?;
3234        }
3235    };
3236
3237    send_notifications(&connection_pool, &session.peer, notifications);
3238
3239    response.send(proto::Ack {})?;
3240
3241    Ok(())
3242}
3243
3244/// Join the channels' room
3245async fn join_channel(
3246    request: proto::JoinChannel,
3247    response: Response<proto::JoinChannel>,
3248    session: MessageContext,
3249) -> Result<()> {
3250    let channel_id = ChannelId::from_proto(request.channel_id);
3251    join_channel_internal(channel_id, Box::new(response), session).await
3252}
3253
/// A response handle that can complete a channel join with a `JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and `JoinRoom` response types, so
/// `join_channel_internal` can be called with either one.
trait JoinChannelInternalResponse {
    /// Send `result` to the client, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3267
/// Joins the current session's user to the room of channel `channel_id`,
/// replies on `response`, and notifies affected peers.
///
/// In order, this:
/// 1. evicts the user from any room still held by a stale connection of theirs;
/// 2. joins the channel's room in the database;
/// 3. replies with the room, the channel id, and LiveKit connection info (the
///    latter only when a LiveKit client is configured and a token was minted);
/// 4. applies any membership change to the connection pool and broadcasts the
///    room update to its participants;
/// 5. broadcasts the channel's participant list and pushes the user's refreshed
///    contact entry to their contacts.
///
/// # Errors
/// Propagates database, send, and contact-update errors, and fails with
/// `"channel not returned"` when the joined room carries no channel.
/// NOTE(review): that last check runs after the reply has already been sent,
/// so the client may observe success while this handler reports an error.
/// NOTE(review): `response` is boxed even though the parameter is already
/// generic; the box looks unnecessary.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Scoped so the `db` and `connection_pool` guards taken inside are released
    // before `channel_updated` and `update_user_contacts` below lock them again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` takes the db guard itself, so release
            // ours first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting a token
        // fails (`trace_err` logs the error and `?` exits the closure).
        // Guests get a guest token without publish rights; every other role
        // gets a room token and may publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // The reply gets a clone of the room; `joined_room` is still needed
        // for the broadcasts below.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // When the join changed the user's channel memberships, update the
        // pool's channel subscriptions and notify the user's connections.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Takes a fresh connection-pool lock; fails if the database returned no
    // channel for the joined room.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3361
3362/// Start editing the channel notes
3363async fn join_channel_buffer(
3364    request: proto::JoinChannelBuffer,
3365    response: Response<proto::JoinChannelBuffer>,
3366    session: MessageContext,
3367) -> Result<()> {
3368    let db = session.db().await;
3369    let channel_id = ChannelId::from_proto(request.channel_id);
3370
3371    let open_response = db
3372        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3373        .await?;
3374
3375    let collaborators = open_response.collaborators.clone();
3376    response.send(open_response)?;
3377
3378    let update = UpdateChannelBufferCollaborators {
3379        channel_id: channel_id.to_proto(),
3380        collaborators: collaborators.clone(),
3381    };
3382    channel_buffer_updated(
3383        session.connection_id,
3384        collaborators
3385            .iter()
3386            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3387        &update,
3388        &session.peer,
3389    );
3390
3391    Ok(())
3392}
3393
3394/// Edit the channel notes
3395async fn update_channel_buffer(
3396    request: proto::UpdateChannelBuffer,
3397    session: MessageContext,
3398) -> Result<()> {
3399    let db = session.db().await;
3400    let channel_id = ChannelId::from_proto(request.channel_id);
3401
3402    let (collaborators, epoch, version) = db
3403        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3404        .await?;
3405
3406    channel_buffer_updated(
3407        session.connection_id,
3408        collaborators.clone(),
3409        &proto::UpdateChannelBuffer {
3410            channel_id: channel_id.to_proto(),
3411            operations: request.operations,
3412        },
3413        &session.peer,
3414    );
3415
3416    let pool = &*session.connection_pool().await;
3417
3418    let non_collaborators =
3419        pool.channel_connection_ids(channel_id)
3420            .filter_map(|(connection_id, _)| {
3421                if collaborators.contains(&connection_id) {
3422                    None
3423                } else {
3424                    Some(connection_id)
3425                }
3426            });
3427
3428    broadcast(None, non_collaborators, |peer_id| {
3429        session.peer.send(
3430            peer_id,
3431            proto::UpdateChannels {
3432                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3433                    channel_id: channel_id.to_proto(),
3434                    epoch: epoch as u64,
3435                    version: version.clone(),
3436                }],
3437                ..Default::default()
3438            },
3439        )
3440    });
3441
3442    Ok(())
3443}
3444
3445/// Rejoin the channel notes after a connection blip
3446async fn rejoin_channel_buffers(
3447    request: proto::RejoinChannelBuffers,
3448    response: Response<proto::RejoinChannelBuffers>,
3449    session: MessageContext,
3450) -> Result<()> {
3451    let db = session.db().await;
3452    let buffers = db
3453        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3454        .await?;
3455
3456    for rejoined_buffer in &buffers {
3457        let collaborators_to_notify = rejoined_buffer
3458            .buffer
3459            .collaborators
3460            .iter()
3461            .filter_map(|c| Some(c.peer_id?.into()));
3462        channel_buffer_updated(
3463            session.connection_id,
3464            collaborators_to_notify,
3465            &proto::UpdateChannelBufferCollaborators {
3466                channel_id: rejoined_buffer.buffer.channel_id,
3467                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3468            },
3469            &session.peer,
3470        );
3471    }
3472
3473    response.send(proto::RejoinChannelBuffersResponse {
3474        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3475    })?;
3476
3477    Ok(())
3478}
3479
3480/// Stop editing the channel notes
3481async fn leave_channel_buffer(
3482    request: proto::LeaveChannelBuffer,
3483    response: Response<proto::LeaveChannelBuffer>,
3484    session: MessageContext,
3485) -> Result<()> {
3486    let db = session.db().await;
3487    let channel_id = ChannelId::from_proto(request.channel_id);
3488
3489    let left_buffer = db
3490        .leave_channel_buffer(channel_id, session.connection_id)
3491        .await?;
3492
3493    response.send(Ack {})?;
3494
3495    channel_buffer_updated(
3496        session.connection_id,
3497        left_buffer.connections,
3498        &proto::UpdateChannelBufferCollaborators {
3499            channel_id: channel_id.to_proto(),
3500            collaborators: left_buffer.collaborators,
3501        },
3502        &session.peer,
3503    );
3504
3505    Ok(())
3506}
3507
3508fn channel_buffer_updated<T: EnvelopedMessage>(
3509    sender_id: ConnectionId,
3510    collaborators: impl IntoIterator<Item = ConnectionId>,
3511    message: &T,
3512    peer: &Peer,
3513) {
3514    broadcast(Some(sender_id), collaborators, |peer_id| {
3515        peer.send(peer_id, message.clone())
3516    });
3517}
3518
3519fn send_notifications(
3520    connection_pool: &ConnectionPool,
3521    peer: &Peer,
3522    notifications: db::NotificationBatch,
3523) {
3524    for (user_id, notification) in notifications {
3525        for connection_id in connection_pool.user_connection_ids(user_id) {
3526            if let Err(error) = peer.send(
3527                connection_id,
3528                proto::AddNotification {
3529                    notification: Some(notification.clone()),
3530                },
3531            ) {
3532                tracing::error!(
3533                    "failed to send notification to {:?} {}",
3534                    connection_id,
3535                    error
3536                );
3537            }
3538        }
3539    }
3540}
3541
3542/// Send a message to the channel
3543async fn send_channel_message(
3544    _request: proto::SendChannelMessage,
3545    _response: Response<proto::SendChannelMessage>,
3546    _session: MessageContext,
3547) -> Result<()> {
3548    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3549}
3550
3551/// Delete a channel message
3552async fn remove_channel_message(
3553    _request: proto::RemoveChannelMessage,
3554    _response: Response<proto::RemoveChannelMessage>,
3555    _session: MessageContext,
3556) -> Result<()> {
3557    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3558}
3559
3560async fn update_channel_message(
3561    _request: proto::UpdateChannelMessage,
3562    _response: Response<proto::UpdateChannelMessage>,
3563    _session: MessageContext,
3564) -> Result<()> {
3565    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3566}
3567
3568/// Mark a channel message as read
3569async fn acknowledge_channel_message(
3570    _request: proto::AckChannelMessage,
3571    _session: MessageContext,
3572) -> Result<()> {
3573    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3574}
3575
3576/// Mark a buffer version as synced
3577async fn acknowledge_buffer_version(
3578    request: proto::AckBufferOperation,
3579    session: MessageContext,
3580) -> Result<()> {
3581    let buffer_id = BufferId::from_proto(request.buffer_id);
3582    session
3583        .db()
3584        .await
3585        .observe_buffer_version(
3586            buffer_id,
3587            session.user_id(),
3588            request.epoch as i32,
3589            &request.version,
3590        )
3591        .await?;
3592    Ok(())
3593}
3594
3595/// Start receiving chat updates for a channel
3596async fn join_channel_chat(
3597    _request: proto::JoinChannelChat,
3598    _response: Response<proto::JoinChannelChat>,
3599    _session: MessageContext,
3600) -> Result<()> {
3601    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3602}
3603
3604/// Stop receiving chat updates for a channel
3605async fn leave_channel_chat(
3606    _request: proto::LeaveChannelChat,
3607    _session: MessageContext,
3608) -> Result<()> {
3609    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3610}
3611
3612/// Retrieve the chat history for a channel
3613async fn get_channel_messages(
3614    _request: proto::GetChannelMessages,
3615    _response: Response<proto::GetChannelMessages>,
3616    _session: MessageContext,
3617) -> Result<()> {
3618    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3619}
3620
3621/// Retrieve specific chat messages
3622async fn get_channel_messages_by_id(
3623    _request: proto::GetChannelMessagesById,
3624    _response: Response<proto::GetChannelMessagesById>,
3625    _session: MessageContext,
3626) -> Result<()> {
3627    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3628}
3629
3630/// Retrieve the current users notifications
3631async fn get_notifications(
3632    request: proto::GetNotifications,
3633    response: Response<proto::GetNotifications>,
3634    session: MessageContext,
3635) -> Result<()> {
3636    let notifications = session
3637        .db()
3638        .await
3639        .get_notifications(
3640            session.user_id(),
3641            NOTIFICATION_COUNT_PER_PAGE,
3642            request.before_id.map(db::NotificationId::from_proto),
3643        )
3644        .await?;
3645    response.send(proto::GetNotificationsResponse {
3646        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3647        notifications,
3648    })?;
3649    Ok(())
3650}
3651
3652/// Mark notifications as read
3653async fn mark_notification_as_read(
3654    request: proto::MarkNotificationRead,
3655    response: Response<proto::MarkNotificationRead>,
3656    session: MessageContext,
3657) -> Result<()> {
3658    let database = &session.db().await;
3659    let notifications = database
3660        .mark_notification_as_read_by_id(
3661            session.user_id(),
3662            NotificationId::from_proto(request.notification_id),
3663        )
3664        .await?;
3665    send_notifications(
3666        &*session.connection_pool().await,
3667        &session.peer,
3668        notifications,
3669    );
3670    response.send(proto::Ack {})?;
3671    Ok(())
3672}
3673
3674fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3675    let message = match message {
3676        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3677        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3678        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3679        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3680        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3681            code: frame.code.into(),
3682            reason: frame.reason.as_str().to_owned().into(),
3683        })),
3684        // We should never receive a frame while reading the message, according
3685        // to the `tungstenite` maintainers:
3686        //
3687        // > It cannot occur when you read messages from the WebSocket, but it
3688        // > can be used when you want to send the raw frames (e.g. you want to
3689        // > send the frames to the WebSocket without composing the full message first).
3690        // >
3691        // > — https://github.com/snapview/tungstenite-rs/issues/268
3692        TungsteniteMessage::Frame(_) => {
3693            bail!("received an unexpected frame while reading the message")
3694        }
3695    };
3696
3697    Ok(message)
3698}
3699
3700fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3701    match message {
3702        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3703        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3704        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3705        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3706        AxumMessage::Close(frame) => {
3707            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3708                code: frame.code.into(),
3709                reason: frame.reason.as_ref().into(),
3710            }))
3711        }
3712    }
3713}
3714
3715fn notify_membership_updated(
3716    connection_pool: &mut ConnectionPool,
3717    result: MembershipUpdated,
3718    user_id: UserId,
3719    peer: &Peer,
3720) {
3721    for membership in &result.new_channels.channel_memberships {
3722        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3723    }
3724    for channel_id in &result.removed_channels {
3725        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3726    }
3727
3728    let user_channels_update = proto::UpdateUserChannels {
3729        channel_memberships: result
3730            .new_channels
3731            .channel_memberships
3732            .iter()
3733            .map(|cm| proto::ChannelMembership {
3734                channel_id: cm.channel_id.to_proto(),
3735                role: cm.role.into(),
3736            })
3737            .collect(),
3738        ..Default::default()
3739    };
3740
3741    let mut update = build_channels_update(result.new_channels);
3742    update.delete_channels = result
3743        .removed_channels
3744        .into_iter()
3745        .map(|id| id.to_proto())
3746        .collect();
3747    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3748
3749    for connection_id in connection_pool.user_connection_ids(user_id) {
3750        peer.send(connection_id, user_channels_update.clone())
3751            .trace_err();
3752        peer.send(connection_id, update.clone()).trace_err();
3753    }
3754}
3755
3756fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3757    proto::UpdateUserChannels {
3758        channel_memberships: channels
3759            .channel_memberships
3760            .iter()
3761            .map(|m| proto::ChannelMembership {
3762                channel_id: m.channel_id.to_proto(),
3763                role: m.role.into(),
3764            })
3765            .collect(),
3766        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3767    }
3768}
3769
3770fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3771    let mut update = proto::UpdateChannels::default();
3772
3773    for channel in channels.channels {
3774        update.channels.push(channel.to_proto());
3775    }
3776
3777    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3778
3779    for (channel_id, participants) in channels.channel_participants {
3780        update
3781            .channel_participants
3782            .push(proto::ChannelParticipants {
3783                channel_id: channel_id.to_proto(),
3784                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3785            });
3786    }
3787
3788    for channel in channels.invited_channels {
3789        update.channel_invitations.push(channel.to_proto());
3790    }
3791
3792    update
3793}
3794
3795fn build_initial_contacts_update(
3796    contacts: Vec<db::Contact>,
3797    pool: &ConnectionPool,
3798) -> proto::UpdateContacts {
3799    let mut update = proto::UpdateContacts::default();
3800
3801    for contact in contacts {
3802        match contact {
3803            db::Contact::Accepted { user_id, busy } => {
3804                update.contacts.push(contact_for_user(user_id, busy, pool));
3805            }
3806            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3807            db::Contact::Incoming { user_id } => {
3808                update
3809                    .incoming_requests
3810                    .push(proto::IncomingContactRequest {
3811                        requester_id: user_id.to_proto(),
3812                    })
3813            }
3814        }
3815    }
3816
3817    update
3818}
3819
3820fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3821    proto::Contact {
3822        user_id: user_id.to_proto(),
3823        online: pool.is_user_online(user_id),
3824        busy,
3825    }
3826}
3827
3828fn room_updated(room: &proto::Room, peer: &Peer) {
3829    broadcast(
3830        None,
3831        room.participants
3832            .iter()
3833            .filter_map(|participant| Some(participant.peer_id?.into())),
3834        |peer_id| {
3835            peer.send(
3836                peer_id,
3837                proto::RoomUpdated {
3838                    room: Some(room.clone()),
3839                },
3840            )
3841        },
3842    );
3843}
3844
3845fn channel_updated(
3846    channel: &db::channel::Model,
3847    room: &proto::Room,
3848    peer: &Peer,
3849    pool: &ConnectionPool,
3850) {
3851    let participants = room
3852        .participants
3853        .iter()
3854        .map(|p| p.user_id)
3855        .collect::<Vec<_>>();
3856
3857    broadcast(
3858        None,
3859        pool.channel_connection_ids(channel.root_id())
3860            .filter_map(|(channel_id, role)| {
3861                role.can_see_channel(channel.visibility)
3862                    .then_some(channel_id)
3863            }),
3864        |peer_id| {
3865            peer.send(
3866                peer_id,
3867                proto::UpdateChannels {
3868                    channel_participants: vec![proto::ChannelParticipants {
3869                        channel_id: channel.id.to_proto(),
3870                        participant_user_ids: participants.clone(),
3871                    }],
3872                    ..Default::default()
3873                },
3874            )
3875        },
3876    );
3877}
3878
3879async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3880    let db = session.db().await;
3881
3882    let contacts = db.get_contacts(user_id).await?;
3883    let busy = db.is_user_busy(user_id).await?;
3884
3885    let pool = session.connection_pool().await;
3886    let updated_contact = contact_for_user(user_id, busy, &pool);
3887    for contact in contacts {
3888        if let db::Contact::Accepted {
3889            user_id: contact_user_id,
3890            ..
3891        } = contact
3892        {
3893            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3894                session
3895                    .peer
3896                    .send(
3897                        contact_conn_id,
3898                        proto::UpdateContacts {
3899                            contacts: vec![updated_contact.clone()],
3900                            remove_contacts: Default::default(),
3901                            incoming_requests: Default::default(),
3902                            remove_incoming_requests: Default::default(),
3903                            outgoing_requests: Default::default(),
3904                            remove_outgoing_requests: Default::default(),
3905                        },
3906                    )
3907                    .trace_err();
3908            }
3909        }
3910    }
3911    Ok(())
3912}
3913
/// Removes `connection_id` from the room the database has it in and notifies
/// everyone affected.
///
/// Returns `Ok(())` without doing anything when `leave_room` reports no room.
/// Otherwise it, in order:
/// - tells collaborators of each project left behind (unshare or collaborator
///   removal, via `project_left`), and broadcasts the updated room;
/// - broadcasts the channel's participant list when the room belongs to a channel;
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled;
/// - refreshes the contact entry of the session's user and of those users;
/// - removes the session's user from the LiveKit room, and deletes that room
///   when the database reported it as deleted.
///
/// # Errors
/// Propagates database and contact-update errors. Peer-send and LiveKit
/// failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are assigned inside the `if let` below. As a result, the db guard
    // (a temporary of the `if let` scrutinee) and `left_room` both go out of
    // scope at the end of that branch, before the connection pool is locked and
    // `update_user_contacts` re-acquires the db further down.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the room broadcast by
        // `room_updated` carries a defaulted `livekit_room`.
        // NOTE(review): confirm that clients do not rely on that field here.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Channel rooms also update the channel's participant list for connections
    // that can see the channel.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts` below takes its own.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3987
3988async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3989    let left_channel_buffers = session
3990        .db()
3991        .await
3992        .leave_channel_buffers(session.connection_id)
3993        .await?;
3994
3995    for left_buffer in left_channel_buffers {
3996        channel_buffer_updated(
3997            session.connection_id,
3998            left_buffer.connections,
3999            &proto::UpdateChannelBufferCollaborators {
4000                channel_id: left_buffer.channel_id.to_proto(),
4001                collaborators: left_buffer.collaborators,
4002            },
4003            &session.peer,
4004        );
4005    }
4006
4007    Ok(())
4008}
4009
4010fn project_left(project: &db::LeftProject, session: &Session) {
4011    for connection_id in &project.connection_ids {
4012        if project.should_unshare {
4013            session
4014                .peer
4015                .send(
4016                    *connection_id,
4017                    proto::UnshareProject {
4018                        project_id: project.id.to_proto(),
4019                    },
4020                )
4021                .trace_err();
4022        } else {
4023            session
4024                .peer
4025                .send(
4026                    *connection_id,
4027                    proto::RemoveProjectCollaborator {
4028                        project_id: project.id.to_proto(),
4029                        peer_id: Some(session.connection_id.into()),
4030                    },
4031                )
4032                .trace_err();
4033        }
4034    }
4035}
4036
4037async fn share_agent_thread(
4038    request: proto::ShareAgentThread,
4039    response: Response<proto::ShareAgentThread>,
4040    session: MessageContext,
4041) -> Result<()> {
4042    let user_id = session.user_id();
4043
4044    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4045        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4046
4047    session
4048        .db()
4049        .await
4050        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4051        .await?;
4052
4053    response.send(proto::Ack {})?;
4054
4055    Ok(())
4056}
4057
4058async fn get_shared_agent_thread(
4059    request: proto::GetSharedAgentThread,
4060    response: Response<proto::GetSharedAgentThread>,
4061    session: MessageContext,
4062) -> Result<()> {
4063    let share_id = SharedThreadId::from_proto(request.session_id)
4064        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4065
4066    let result = session.db().await.get_shared_thread(share_id).await?;
4067
4068    match result {
4069        Some((thread, username)) => {
4070            response.send(proto::GetSharedAgentThreadResponse {
4071                title: thread.title,
4072                thread_data: thread.data,
4073                sharer_username: username,
4074                created_at: thread.created_at.and_utc().to_rfc3339(),
4075            })?;
4076        }
4077        None => {
4078            return Err(anyhow!("Shared thread not found").into());
4079        }
4080    }
4081
4082    Ok(())
4083}
4084
4085pub trait ResultExt {
4086    type Ok;
4087
4088    fn trace_err(self) -> Option<Self::Ok>;
4089}
4090
4091impl<T, E> ResultExt for Result<T, E>
4092where
4093    E: std::fmt::Debug,
4094{
4095    type Ok = T;
4096
4097    #[track_caller]
4098    fn trace_err(self) -> Option<T> {
4099        match self {
4100            Ok(value) => Some(value),
4101            Err(error) => {
4102                tracing::error!("{:?}", error);
4103                None
4104            }
4105        }
4106    }
4107}