rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
/// Reconnection grace period.
/// NOTE(review): not referenced in this chunk; presumably how long a
/// disconnected client may take to reconnect — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. Once they're
// gone, we can clean up their old resources; 15s leaves a 5s safety margin.
/// How long `Server::start` waits before cleaning up stale server resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Notification page size.
/// NOTE(review): not referenced in this chunk — confirm meaning at the use sites.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of connections `ConnectionGuard::try_acquire` will admit at once.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Count of admitted connections: incremented by `ConnectionGuard::try_acquire`
/// and decremented when a guard drops. It may briefly overshoot while a rejected
/// attempt is being rolled back.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Field names recorded on tracing spans; every recorded value is in milliseconds.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler stored in `Server::handlers`. Given a boxed envelope, the
/// sender's `Session`, and the message span, it returns the boxed future that
/// handles the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133}
 134
 135impl Principal {
 136    fn update_span(&self, span: &tracing::Span) {
 137        match &self {
 138            Principal::User(user) => {
 139                span.record("user_id", user.id.0);
 140                span.record("login", &user.github_login);
 141            }
 142        }
 143    }
 144}
 145
 146#[derive(Clone)]
 147struct MessageContext {
 148    session: Session,
 149    span: tracing::Span,
 150}
 151
 152impl Deref for MessageContext {
 153    type Target = Session;
 154
 155    fn deref(&self) -> &Self::Target {
 156        &self.session
 157    }
 158}
 159
 160impl MessageContext {
 161    pub fn forward_request<T: RequestMessage>(
 162        &self,
 163        receiver_id: ConnectionId,
 164        request: T,
 165    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 166        let request_start_time = Instant::now();
 167        let span = self.span.clone();
 168        tracing::info!("start forwarding request");
 169        self.peer
 170            .forward_request(self.connection_id, receiver_id, request)
 171            .inspect(move |_| {
 172                span.record(
 173                    HOST_WAITING_MS,
 174                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 175                );
 176            })
 177            .inspect_err(|_| tracing::error!("error forwarding request"))
 178            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 179    }
 180}
 181
/// Per-connection state passed to every message handler (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared pool of live connections; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // NOTE(review): not read in this chunk; presumably the client's system id
    // (cf. the imported `SystemIdHeader`) — confirm.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
 197
 198impl Session {
 199    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 200        #[cfg(feature = "test-support")]
 201        tokio::task::yield_now().await;
 202        let guard = self.db.lock().await;
 203        #[cfg(feature = "test-support")]
 204        tokio::task::yield_now().await;
 205        guard
 206    }
 207
 208    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 209        #[cfg(feature = "test-support")]
 210        tokio::task::yield_now().await;
 211        let guard = self.connection_pool.lock();
 212        ConnectionPoolGuard {
 213            guard,
 214            _not_send: PhantomData,
 215        }
 216    }
 217
 218    #[expect(dead_code)]
 219    fn is_staff(&self) -> bool {
 220        match &self.principal {
 221            Principal::User(user) => user.admin,
 222        }
 223    }
 224
 225    fn user_id(&self) -> UserId {
 226        match &self.principal {
 227            Principal::User(user) => user.id,
 228        }
 229    }
 230}
 231
 232impl Debug for Session {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 234        let mut result = f.debug_struct("Session");
 235        match &self.principal {
 236            Principal::User(user) => {
 237                result.field("user", &user.github_login);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct DbHandle(Arc<Database>);
 245
 246impl Deref for DbHandle {
 247    type Target = Database;
 248
 249    fn deref(&self) -> &Self::Target {
 250        self.0.as_ref()
 251    }
 252}
 253
/// The RPC server. It owns the peer, the shared connection pool, and the
/// message handlers registered in `Server::new`.
pub struct Server {
    /// This server's id. Kept behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by the test-only `reset`).
    teardown: watch::Sender<bool>,
}
 262
/// Lock guard over the shared `ConnectionPool`, returned by `Session::connection_pool`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the whole guard `!Send`: a `Send`
    // future cannot hold it across an `.await`.
    _not_send: PhantomData<Rc<()>>,
}
 267
 268impl Server {
 269    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 270        let mut server = Self {
 271            id: parking_lot::Mutex::new(id),
 272            peer: Peer::new(id.0 as u32),
 273            app_state,
 274            connection_pool: Default::default(),
 275            handlers: Default::default(),
 276            teardown: watch::channel(false).0,
 277        };
 278
 279        server
 280            .add_request_handler(ping)
 281            .add_request_handler(create_room)
 282            .add_request_handler(join_room)
 283            .add_request_handler(rejoin_room)
 284            .add_request_handler(leave_room)
 285            .add_request_handler(set_room_participant_role)
 286            .add_request_handler(call)
 287            .add_request_handler(cancel_call)
 288            .add_message_handler(decline_call)
 289            .add_request_handler(update_participant_location)
 290            .add_request_handler(share_project)
 291            .add_message_handler(unshare_project)
 292            .add_request_handler(join_project)
 293            .add_message_handler(leave_project)
 294            .add_request_handler(update_project)
 295            .add_request_handler(update_worktree)
 296            .add_request_handler(update_repository)
 297            .add_request_handler(remove_repository)
 298            .add_message_handler(start_language_server)
 299            .add_message_handler(update_language_server)
 300            .add_message_handler(update_diagnostic_summary)
 301            .add_message_handler(update_worktree_settings)
 302            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 308            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 313            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 319            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 320            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 321            .add_request_handler(
 322                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 323            )
 324            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 325            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 326            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 327            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 332            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 333            .add_request_handler(
 334                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 335            )
 336            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 344            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 345            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 350            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 355            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 357            .add_request_handler(lsp_query)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 359            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 362            .add_message_handler(create_buffer_for_peer)
 363            .add_message_handler(create_image_for_peer)
 364            .add_request_handler(update_buffer)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 366            .add_message_handler(
 367                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 368            )
 369            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 370            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 371            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 374            .add_message_handler(
 375                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 376            )
 377            .add_request_handler(get_users)
 378            .add_request_handler(fuzzy_search_users)
 379            .add_request_handler(request_contact)
 380            .add_request_handler(remove_contact)
 381            .add_request_handler(respond_to_contact_request)
 382            .add_message_handler(subscribe_to_channels)
 383            .add_request_handler(create_channel)
 384            .add_request_handler(delete_channel)
 385            .add_request_handler(invite_channel_member)
 386            .add_request_handler(remove_channel_member)
 387            .add_request_handler(set_channel_member_role)
 388            .add_request_handler(set_channel_visibility)
 389            .add_request_handler(rename_channel)
 390            .add_request_handler(join_channel_buffer)
 391            .add_request_handler(leave_channel_buffer)
 392            .add_message_handler(update_channel_buffer)
 393            .add_request_handler(rejoin_channel_buffers)
 394            .add_request_handler(get_channel_members)
 395            .add_request_handler(respond_to_channel_invite)
 396            .add_request_handler(join_channel)
 397            .add_request_handler(join_channel_chat)
 398            .add_message_handler(leave_channel_chat)
 399            .add_request_handler(send_channel_message)
 400            .add_request_handler(remove_channel_message)
 401            .add_request_handler(update_channel_message)
 402            .add_request_handler(get_channel_messages)
 403            .add_request_handler(get_channel_messages_by_id)
 404            .add_request_handler(get_notifications)
 405            .add_request_handler(mark_notification_as_read)
 406            .add_request_handler(move_channel)
 407            .add_request_handler(reorder_channel)
 408            .add_request_handler(follow)
 409            .add_message_handler(unfollow)
 410            .add_message_handler(update_followers)
 411            .add_message_handler(acknowledge_channel_message)
 412            .add_message_handler(acknowledge_buffer_version)
 413            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 414            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 415            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 416            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 417            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 418            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 419            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 421            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 422            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 423            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 424            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 425            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 426            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 427            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 430            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 431            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 432            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 433            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 434            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 435            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 436            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 437            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
 439            .add_request_handler(forward_read_only_project_request::<proto::GetCommitData>)
 440            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
 441            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
 442            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
 443            .add_request_handler(forward_mutating_project_request::<proto::GitEditRef>)
 444            .add_request_handler(forward_mutating_project_request::<proto::GitRepairWorktrees>)
 445            .add_request_handler(disallow_guest_request::<proto::GitCreateArchiveCheckpoint>)
 446            .add_request_handler(disallow_guest_request::<proto::GitRestoreArchiveCheckpoint>)
 447            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 448            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 449            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 450            .add_request_handler(share_agent_thread)
 451            .add_request_handler(get_shared_agent_thread)
 452            .add_request_handler(forward_project_search_chunk);
 453
 454        Arc::new(server)
 455    }
 456
    /// Spawns a detached background task that waits `CLEANUP_TIMEOUT` and then
    /// cleans up resources the database reports as stale for this server's
    /// environment (via `stale_server_resource_ids`).
    ///
    /// The task does the following, in order:
    /// 1. Deletes stale channel chat participants.
    /// 2. For each stale channel buffer, clears stale collaborators and sends
    ///    `UpdateChannelBufferCollaborators` to the connections still attached.
    /// 3. For each stale room:
    ///    - clears stale participants and broadcasts the room (and channel) update;
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes `UpdateContacts` to the accepted contacts of every affected user;
    ///    - deletes the LiveKit room once it has no participants left.
    /// 4. Repeats step 1, then clears old worktree entries and deletes stale servers.
    ///
    /// Every fallible step goes through `trace_err` and its error is otherwise
    /// ignored, so one failing step does not stop the remaining ones. This method
    /// returns `Ok(())` right after spawning the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the rest.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Stay empty/false when the room refresh below fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The synchronous pool lock is scoped so that it is never
                        // held across an `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to
                        // every connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): identical to the call made right after the timeout
                // above. It looks redundant; confirm whether it is meant to catch
                // participants that appeared while rooms were being cleaned up.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 627
 628    pub fn teardown(&self) {
 629        self.peer.teardown();
 630        self.connection_pool.lock().reset();
 631        let _ = self.teardown.send(true);
 632    }
 633
 634    #[cfg(feature = "test-support")]
 635    pub fn reset(&self, id: ServerId) {
 636        self.teardown();
 637        *self.id.lock() = id;
 638        self.peer.reset(id.0 as u32);
 639        let _ = self.teardown.send(false);
 640    }
 641
 642    #[cfg(feature = "test-support")]
 643    pub fn id(&self) -> ServerId {
 644        *self.id.lock()
 645    }
 646
 647    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 648    where
 649        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 650        Fut: 'static + Send + Future<Output = Result<()>>,
 651        M: EnvelopedMessage,
 652    {
 653        let prev_handler = self.handlers.insert(
 654            TypeId::of::<M>(),
 655            Box::new(move |envelope, session, span| {
 656                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 657                let received_at = envelope.received_at;
 658                tracing::info!("message received");
 659                let start_time = Instant::now();
 660                let future = (handler)(
 661                    *envelope,
 662                    MessageContext {
 663                        session,
 664                        span: span.clone(),
 665                    },
 666                );
 667                async move {
 668                    let result = future.await;
 669                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 670                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 671                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 672                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 673                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 674                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 675                    match result {
 676                        Err(error) => {
 677                            tracing::error!(?error, "error handling message")
 678                        }
 679                        Ok(()) => tracing::info!("finished handling message"),
 680                    }
 681                }
 682                .boxed()
 683            }),
 684        );
 685        if prev_handler.is_some() {
 686            panic!("registered a handler for the same message twice");
 687        }
 688        self
 689    }
 690
 691    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 692    where
 693        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 694        Fut: 'static + Send + Future<Output = Result<()>>,
 695        M: EnvelopedMessage,
 696    {
 697        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 698        self
 699    }
 700
 701    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 702    where
 703        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 704        Fut: Send + Future<Output = Result<()>>,
 705        M: RequestMessage,
 706    {
 707        let handler = Arc::new(handler);
 708        self.add_handler(move |envelope, session| {
 709            let receipt = envelope.receipt();
 710            let handler = handler.clone();
 711            async move {
 712                let peer = session.peer.clone();
 713                let responded = Arc::new(AtomicBool::default());
 714                let response = Response {
 715                    peer: peer.clone(),
 716                    responded: responded.clone(),
 717                    receipt,
 718                };
 719                match (handler)(envelope.payload, response, session).await {
 720                    Ok(()) => {
 721                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 722                            Ok(())
 723                        } else {
 724                            Err(anyhow!("handler did not send a response"))?
 725                        }
 726                    }
 727                    Err(error) => {
 728                        let proto_err = match &error {
 729                            Error::Internal(err) => err.to_proto(),
 730                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 731                        };
 732                        peer.respond_with_error(receipt, proto_err)?;
 733                        Err(error)
 734                    }
 735                }
 736            }
 737        })
 738    }
 739
 740    pub fn handle_connection(
 741        self: &Arc<Self>,
 742        connection: Connection,
 743        address: String,
 744        principal: Principal,
 745        zed_version: ZedVersion,
 746        release_channel: Option<String>,
 747        user_agent: Option<String>,
 748        geoip_country_code: Option<String>,
 749        system_id: Option<String>,
 750        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 751        executor: Executor,
 752        connection_guard: Option<ConnectionGuard>,
 753    ) -> impl Future<Output = ()> + use<> {
 754        let this = self.clone();
 755        let span = info_span!("handle connection", %address,
 756            connection_id=field::Empty,
 757            user_id=field::Empty,
 758            login=field::Empty,
 759            user_agent=field::Empty,
 760            geoip_country_code=field::Empty,
 761            release_channel=field::Empty,
 762        );
 763        principal.update_span(&span);
 764        if let Some(user_agent) = user_agent {
 765            span.record("user_agent", user_agent);
 766        }
 767        if let Some(release_channel) = release_channel {
 768            span.record("release_channel", release_channel);
 769        }
 770
 771        if let Some(country_code) = geoip_country_code.as_ref() {
 772            span.record("geoip_country_code", country_code);
 773        }
 774
 775        let mut teardown = self.teardown.subscribe();
 776        async move {
 777            if *teardown.borrow() {
 778                tracing::error!("server is tearing down");
 779                return;
 780            }
 781
 782            let (connection_id, handle_io, mut incoming_rx) =
 783                this.peer.add_connection(connection, {
 784                    let executor = executor.clone();
 785                    move |duration| executor.sleep(duration)
 786                });
 787            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 788
 789            tracing::info!("connection opened");
 790
 791            let session = Session {
 792                principal: principal.clone(),
 793                connection_id,
 794                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 795                peer: this.peer.clone(),
 796                connection_pool: this.connection_pool.clone(),
 797                app_state: this.app_state.clone(),
 798                geoip_country_code,
 799                system_id,
 800                _executor: executor.clone(),
 801            };
 802
 803            if let Err(error) = this
 804                .send_initial_client_update(
 805                    connection_id,
 806                    zed_version,
 807                    send_connection_id,
 808                    &session,
 809                )
 810                .await
 811            {
 812                tracing::error!(?error, "failed to send initial client update");
 813                return;
 814            }
 815            drop(connection_guard);
 816
 817            let handle_io = handle_io.fuse();
 818            futures::pin_mut!(handle_io);
 819
 820            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 821            // This prevents deadlocks when e.g., client A performs a request to client B and
 822            // client B performs a request to client A. If both clients stop processing further
 823            // messages until their respective request completes, they won't have a chance to
 824            // respond to the other client's request and cause a deadlock.
 825            //
 826            // This arrangement ensures we will attempt to process earlier messages first, but fall
 827            // back to processing messages arrived later in the spirit of making progress.
 828            const MAX_CONCURRENT_HANDLERS: usize = 256;
 829            let mut foreground_message_handlers = FuturesUnordered::new();
 830            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 831            let get_concurrent_handlers = {
 832                let concurrent_handlers = concurrent_handlers.clone();
 833                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 834            };
 835            loop {
 836                let next_message = async {
 837                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 838                    let message = incoming_rx.next().await;
 839                    // Cache the concurrent_handlers here, so that we know what the
 840                    // queue looks like as each handler starts
 841                    (permit, message, get_concurrent_handlers())
 842                }
 843                .fuse();
 844                futures::pin_mut!(next_message);
 845                futures::select_biased! {
 846                    _ = teardown.changed().fuse() => return,
 847                    result = handle_io => {
 848                        if let Err(error) = result {
 849                            tracing::error!(?error, "error handling I/O");
 850                        }
 851                        break;
 852                    }
 853                    _ = foreground_message_handlers.next() => {}
 854                    next_message = next_message => {
 855                        let (permit, message, concurrent_handlers) = next_message;
 856                        if let Some(message) = message {
 857                            let type_name = message.payload_type_name();
 858                            // note: we copy all the fields from the parent span so we can query them in the logs.
 859                            // (https://github.com/tokio-rs/tracing/issues/2670).
 860                            let span = tracing::info_span!("receive message",
 861                                %connection_id,
 862                                %address,
 863                                type_name,
 864                                concurrent_handlers,
 865                                user_id=field::Empty,
 866                                login=field::Empty,
 867                                lsp_query_request=field::Empty,
 868                                release_channel=field::Empty,
 869                                { TOTAL_DURATION_MS }=field::Empty,
 870                                { PROCESSING_DURATION_MS }=field::Empty,
 871                                { QUEUE_DURATION_MS }=field::Empty,
 872                                { HOST_WAITING_MS }=field::Empty
 873                            );
 874                            principal.update_span(&span);
 875                            let span_enter = span.enter();
 876                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 877                                let is_background = message.is_background();
 878                                let handle_message = (handler)(message, session.clone(), span.clone());
 879                                drop(span_enter);
 880
 881                                let handle_message = async move {
 882                                    handle_message.await;
 883                                    drop(permit);
 884                                }.instrument(span);
 885                                if is_background {
 886                                    executor.spawn_detached(handle_message);
 887                                } else {
 888                                    foreground_message_handlers.push(handle_message);
 889                                }
 890                            } else {
 891                                tracing::error!("no message handler");
 892                            }
 893                        } else {
 894                            tracing::info!("connection closed");
 895                            break;
 896                        }
 897                    }
 898                }
 899            }
 900
 901            drop(foreground_message_handlers);
 902            let concurrent_handlers = get_concurrent_handlers();
 903            tracing::info!(concurrent_handlers, "signing out");
 904            if let Err(error) = connection_lost(session, teardown, executor).await {
 905                tracing::error!(?error, "error signing out");
 906            }
 907        }
 908        .instrument(span)
 909    }
 910
 911    async fn send_initial_client_update(
 912        &self,
 913        connection_id: ConnectionId,
 914        zed_version: ZedVersion,
 915        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 916        session: &Session,
 917    ) -> Result<()> {
 918        self.peer.send(
 919            connection_id,
 920            proto::Hello {
 921                peer_id: Some(connection_id.into()),
 922            },
 923        )?;
 924        tracing::info!("sent hello message");
 925        if let Some(send_connection_id) = send_connection_id.take() {
 926            let _ = send_connection_id.send(connection_id);
 927        }
 928
 929        match &session.principal {
 930            Principal::User(user) => {
 931                if !user.connected_once {
 932                    self.peer.send(connection_id, proto::ShowContacts {})?;
 933                    self.app_state
 934                        .db
 935                        .set_user_connected_once(user.id, true)
 936                        .await?;
 937                }
 938
 939                let contacts = self.app_state.db.get_contacts(user.id).await?;
 940
 941                {
 942                    let mut pool = self.connection_pool.lock();
 943                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 944                    self.peer.send(
 945                        connection_id,
 946                        build_initial_contacts_update(contacts, &pool),
 947                    )?;
 948                }
 949
 950                if should_auto_subscribe_to_channels(&zed_version) {
 951                    subscribe_user_to_channels(user.id, session).await?;
 952                }
 953
 954                if let Some(incoming_call) =
 955                    self.app_state.db.incoming_call_for_user(user.id).await?
 956                {
 957                    self.peer.send(connection_id, incoming_call)?;
 958                }
 959
 960                update_user_contacts(user.id, session).await?;
 961            }
 962        }
 963
 964        Ok(())
 965    }
 966}
 967
 968impl Deref for ConnectionPoolGuard<'_> {
 969    type Target = ConnectionPool;
 970
 971    fn deref(&self) -> &Self::Target {
 972        &self.guard
 973    }
 974}
 975
 976impl DerefMut for ConnectionPoolGuard<'_> {
 977    fn deref_mut(&mut self) -> &mut Self::Target {
 978        &mut self.guard
 979    }
 980}
 981
 982impl Drop for ConnectionPoolGuard<'_> {
 983    fn drop(&mut self) {
 984        #[cfg(feature = "test-support")]
 985        self.check_invariants();
 986    }
 987}
 988
 989fn broadcast<F>(
 990    sender_id: Option<ConnectionId>,
 991    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 992    mut f: F,
 993) where
 994    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 995{
 996    for receiver_id in receiver_ids {
 997        if Some(receiver_id) != sender_id
 998            && let Err(error) = f(receiver_id)
 999        {
1000            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1001        }
1002    }
1003}
1004
1005pub struct ProtocolVersion(u32);
1006
1007impl Header for ProtocolVersion {
1008    fn name() -> &'static HeaderName {
1009        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1010        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1011    }
1012
1013    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1014    where
1015        Self: Sized,
1016        I: Iterator<Item = &'i axum::http::HeaderValue>,
1017    {
1018        let version = values
1019            .next()
1020            .ok_or_else(axum::headers::Error::invalid)?
1021            .to_str()
1022            .map_err(|_| axum::headers::Error::invalid())?
1023            .parse()
1024            .map_err(|_| axum::headers::Error::invalid())?;
1025        Ok(Self(version))
1026    }
1027
1028    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1029        values.extend([self.0.to_string().parse().unwrap()]);
1030    }
1031}
1032
1033pub struct AppVersionHeader(Version);
1034impl Header for AppVersionHeader {
1035    fn name() -> &'static HeaderName {
1036        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1037        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1038    }
1039
1040    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1041    where
1042        Self: Sized,
1043        I: Iterator<Item = &'i axum::http::HeaderValue>,
1044    {
1045        let version = values
1046            .next()
1047            .ok_or_else(axum::headers::Error::invalid)?
1048            .to_str()
1049            .map_err(|_| axum::headers::Error::invalid())?
1050            .parse()
1051            .map_err(|_| axum::headers::Error::invalid())?;
1052        Ok(Self(version))
1053    }
1054
1055    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1056        values.extend([self.0.to_string().parse().unwrap()]);
1057    }
1058}
1059
1060#[derive(Debug)]
1061pub struct ReleaseChannelHeader(String);
1062
1063impl Header for ReleaseChannelHeader {
1064    fn name() -> &'static HeaderName {
1065        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1066        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1067    }
1068
1069    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1070    where
1071        Self: Sized,
1072        I: Iterator<Item = &'i axum::http::HeaderValue>,
1073    {
1074        Ok(Self(
1075            values
1076                .next()
1077                .ok_or_else(axum::headers::Error::invalid)?
1078                .to_str()
1079                .map_err(|_| axum::headers::Error::invalid())?
1080                .to_owned(),
1081        ))
1082    }
1083
1084    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1085        values.extend([self.0.parse().unwrap()]);
1086    }
1087}
1088
/// Builds the router for the RPC endpoints.
///
/// - `/rpc` upgrades to the collaboration WebSocket (`handle_websocket_request`).
/// - `/metrics` serves Prometheus metrics (`handle_metrics`).
///
/// Layer order is significant: axum's `Router::layer` only wraps routes that have already
/// been added. So the `AppState` extension and the `auth::validate_header` middleware
/// apply to `/rpc` only, and `/metrics` is registered after them. The final
/// `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Added after the auth layer above, so it is not wrapped by it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1100
1101pub async fn handle_websocket_request(
1102    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1103    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1104    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1105    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1106    Extension(server): Extension<Arc<Server>>,
1107    Extension(principal): Extension<Principal>,
1108    user_agent: Option<TypedHeader<UserAgent>>,
1109    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1110    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1111    ws: WebSocketUpgrade,
1112) -> axum::response::Response {
1113    if protocol_version != rpc::PROTOCOL_VERSION {
1114        return (
1115            StatusCode::UPGRADE_REQUIRED,
1116            "client must be upgraded".to_string(),
1117        )
1118            .into_response();
1119    }
1120
1121    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1122        return (
1123            StatusCode::UPGRADE_REQUIRED,
1124            "no version header found".to_string(),
1125        )
1126            .into_response();
1127    };
1128
1129    let release_channel = release_channel_header.map(|header| header.0.0);
1130
1131    if !version.can_collaborate() {
1132        return (
1133            StatusCode::UPGRADE_REQUIRED,
1134            "client must be upgraded".to_string(),
1135        )
1136            .into_response();
1137    }
1138
1139    let socket_address = socket_address.to_string();
1140
1141    // Acquire connection guard before WebSocket upgrade
1142    let connection_guard = match ConnectionGuard::try_acquire() {
1143        Ok(guard) => guard,
1144        Err(()) => {
1145            return (
1146                StatusCode::SERVICE_UNAVAILABLE,
1147                "Too many concurrent connections",
1148            )
1149                .into_response();
1150        }
1151    };
1152
1153    ws.on_upgrade(move |socket| {
1154        let socket = socket
1155            .map_ok(to_tungstenite_message)
1156            .err_into()
1157            .with(|message| async move { to_axum_message(message) });
1158        let connection = Connection::new(Box::pin(socket));
1159        async move {
1160            server
1161                .handle_connection(
1162                    connection,
1163                    socket_address,
1164                    principal,
1165                    version,
1166                    release_channel,
1167                    user_agent.map(|header| header.to_string()),
1168                    country_code_header.map(|header| header.to_string()),
1169                    system_id_header.map(|header| header.to_string()),
1170                    None,
1171                    Executor::Production,
1172                    Some(connection_guard),
1173                )
1174                .await;
1175        }
1176    })
1177}
1178
1179pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1180    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1181    let connections_metric = CONNECTIONS_METRIC
1182        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1183
1184    let connections = server
1185        .connection_pool
1186        .lock()
1187        .connections()
1188        .filter(|connection| !connection.admin)
1189        .count();
1190    connections_metric.set(connections as _);
1191
1192    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1193    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1194        register_int_gauge!(
1195            "shared_projects",
1196            "number of open projects with one or more guests"
1197        )
1198        .unwrap()
1199    });
1200
1201    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1202    shared_projects_metric.set(shared_projects as _);
1203
1204    let encoder = prometheus::TextEncoder::new();
1205    let metric_families = prometheus::gather();
1206    let encoded_metrics = encoder
1207        .encode_to_string(&metric_families)
1208        .map_err(|err| anyhow!("{err}"))?;
1209    Ok(encoded_metrics)
1210}
1211
1212#[instrument(err, skip(executor))]
1213async fn connection_lost(
1214    session: Session,
1215    mut teardown: watch::Receiver<bool>,
1216    executor: Executor,
1217) -> Result<()> {
1218    session.peer.disconnect(session.connection_id);
1219    session
1220        .connection_pool()
1221        .await
1222        .remove_connection(session.connection_id)?;
1223
1224    session
1225        .db()
1226        .await
1227        .connection_lost(session.connection_id)
1228        .await
1229        .trace_err();
1230
1231    futures::select_biased! {
1232        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1233
1234            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1235            leave_room_for_session(&session, session.connection_id).await.trace_err();
1236            leave_channel_buffers_for_session(&session)
1237                .await
1238                .trace_err();
1239
1240            if !session
1241                .connection_pool()
1242                .await
1243                .is_user_online(session.user_id())
1244            {
1245                let db = session.db().await;
1246                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1247                    room_updated(&room, &session.peer);
1248                }
1249            }
1250
1251            update_user_contacts(session.user_id(), &session).await?;
1252        },
1253        _ = teardown.changed().fuse() => {}
1254    }
1255
1256    Ok(())
1257}
1258
1259/// Acknowledges a ping from a client, used to keep the connection alive.
1260async fn ping(
1261    _: proto::Ping,
1262    response: Response<proto::Ping>,
1263    _session: MessageContext,
1264) -> Result<()> {
1265    response.send(proto::Ack {})?;
1266    Ok(())
1267}
1268
1269/// Creates a new room for calling (outside of channels)
1270async fn create_room(
1271    _request: proto::CreateRoom,
1272    response: Response<proto::CreateRoom>,
1273    session: MessageContext,
1274) -> Result<()> {
1275    let livekit_room = nanoid::nanoid!(30);
1276
1277    let live_kit_connection_info = util::maybe!(async {
1278        let live_kit = session.app_state.livekit_client.as_ref();
1279        let live_kit = live_kit?;
1280        let user_id = session.user_id().to_string();
1281
1282        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1283
1284        Some(proto::LiveKitConnectionInfo {
1285            server_url: live_kit.url().into(),
1286            token,
1287            can_publish: true,
1288        })
1289    })
1290    .await;
1291
1292    let room = session
1293        .db()
1294        .await
1295        .create_room(session.user_id(), session.connection_id, &livekit_room)
1296        .await?;
1297
1298    response.send(proto::CreateRoomResponse {
1299        room: Some(room.clone()),
1300        live_kit_connection_info,
1301    })?;
1302
1303    update_user_contacts(session.user_id(), &session).await?;
1304    Ok(())
1305}
1306
1307/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1308async fn join_room(
1309    request: proto::JoinRoom,
1310    response: Response<proto::JoinRoom>,
1311    session: MessageContext,
1312) -> Result<()> {
1313    let room_id = RoomId::from_proto(request.id);
1314
1315    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1316
1317    if let Some(channel_id) = channel_id {
1318        return join_channel_internal(channel_id, Box::new(response), session).await;
1319    }
1320
1321    let joined_room = {
1322        let room = session
1323            .db()
1324            .await
1325            .join_room(room_id, session.user_id(), session.connection_id)
1326            .await?;
1327        room_updated(&room.room, &session.peer);
1328        room.into_inner()
1329    };
1330
1331    for connection_id in session
1332        .connection_pool()
1333        .await
1334        .user_connection_ids(session.user_id())
1335    {
1336        session
1337            .peer
1338            .send(
1339                connection_id,
1340                proto::CallCanceled {
1341                    room_id: room_id.to_proto(),
1342                },
1343            )
1344            .trace_err();
1345    }
1346
1347    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1348    {
1349        live_kit
1350            .room_token(
1351                &joined_room.room.livekit_room,
1352                &session.user_id().to_string(),
1353            )
1354            .trace_err()
1355            .map(|token| proto::LiveKitConnectionInfo {
1356                server_url: live_kit.url().into(),
1357                token,
1358                can_publish: true,
1359            })
1360    } else {
1361        None
1362    };
1363
1364    response.send(proto::JoinRoomResponse {
1365        room: Some(joined_room.room),
1366        channel_id: None,
1367        live_kit_connection_info,
1368    })?;
1369
1370    update_user_contacts(session.user_id(), &session).await?;
1371    Ok(())
1372}
1373
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Re-registers this connection in the room through the database.
/// 2. Replies with the room plus any reshared and rejoined projects.
/// 3. Broadcasts the room update.
/// 4. Tells collaborators of those projects that this user's peer id changed
///    (old connection id -> current connection id).
///
/// The value returned by `db.rejoin_room` is only unwrapped with `into_inner()` after all
/// of those sends (presumably it is a guard — NOTE(review): confirm). The channel update
/// and the contacts refresh run after the inner block ends.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Assigned inside the block below and used after it ends.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at this connection's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to every collaborator except this
            // connection, attributed to this connection as sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Same peer-id update for rejoined projects, plus streaming their worktree state
        // back to this connection.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, broadcast the updated channel state as well.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1465
1466fn notify_rejoined_projects(
1467    rejoined_projects: &mut Vec<RejoinedProject>,
1468    session: &Session,
1469) -> Result<()> {
1470    for project in rejoined_projects.iter() {
1471        for collaborator in &project.collaborators {
1472            session
1473                .peer
1474                .send(
1475                    collaborator.connection_id,
1476                    proto::UpdateProjectCollaborator {
1477                        project_id: project.id.to_proto(),
1478                        old_peer_id: Some(project.old_connection_id.into()),
1479                        new_peer_id: Some(session.connection_id.into()),
1480                    },
1481                )
1482                .trace_err();
1483        }
1484    }
1485
1486    for project in rejoined_projects {
1487        for worktree in mem::take(&mut project.worktrees) {
1488            // Stream this worktree's entries.
1489            let message = proto::UpdateWorktree {
1490                project_id: project.id.to_proto(),
1491                worktree_id: worktree.id,
1492                abs_path: worktree.abs_path.clone(),
1493                root_name: worktree.root_name,
1494                root_repo_common_dir: worktree.root_repo_common_dir,
1495                updated_entries: worktree.updated_entries,
1496                removed_entries: worktree.removed_entries,
1497                scan_id: worktree.scan_id,
1498                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1499                updated_repositories: worktree.updated_repositories,
1500                removed_repositories: worktree.removed_repositories,
1501            };
1502            for update in proto::split_worktree_update(message) {
1503                session.peer.send(session.connection_id, update)?;
1504            }
1505
1506            // Stream this worktree's diagnostics.
1507            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1508            if let Some(summary) = worktree_diagnostics.next() {
1509                let message = proto::UpdateDiagnosticSummary {
1510                    project_id: project.id.to_proto(),
1511                    worktree_id: worktree.id,
1512                    summary: Some(summary),
1513                    more_summaries: worktree_diagnostics.collect(),
1514                };
1515                session.peer.send(session.connection_id, message)?;
1516            }
1517
1518            for settings_file in worktree.settings_files {
1519                session.peer.send(
1520                    session.connection_id,
1521                    proto::UpdateWorktreeSettings {
1522                        project_id: project.id.to_proto(),
1523                        worktree_id: worktree.id,
1524                        path: settings_file.path,
1525                        content: Some(settings_file.content),
1526                        kind: Some(settings_file.kind.to_proto().into()),
1527                        outside_worktree: Some(settings_file.outside_worktree),
1528                    },
1529                )?;
1530            }
1531        }
1532
1533        for repository in mem::take(&mut project.updated_repositories) {
1534            for update in split_repository_update(repository) {
1535                session.peer.send(session.connection_id, update)?;
1536            }
1537        }
1538
1539        for id in mem::take(&mut project.removed_repositories) {
1540            session.peer.send(
1541                session.connection_id,
1542                proto::RemoveRepository {
1543                    project_id: project.id.to_proto(),
1544                    id,
1545                },
1546            )?;
1547        }
1548    }
1549
1550    Ok(())
1551}
1552
1553/// leave room disconnects from the room.
1554async fn leave_room(
1555    _: proto::LeaveRoom,
1556    response: Response<proto::LeaveRoom>,
1557    session: MessageContext,
1558) -> Result<()> {
1559    leave_room_for_session(&session, session.connection_id).await?;
1560    response.send(proto::Ack {})?;
1561    Ok(())
1562}
1563
1564/// Updates the permissions of someone else in the room.
1565async fn set_room_participant_role(
1566    request: proto::SetRoomParticipantRole,
1567    response: Response<proto::SetRoomParticipantRole>,
1568    session: MessageContext,
1569) -> Result<()> {
1570    let user_id = UserId::from_proto(request.user_id);
1571    let role = ChannelRole::from(request.role());
1572
1573    let (livekit_room, can_publish) = {
1574        let room = session
1575            .db()
1576            .await
1577            .set_room_participant_role(
1578                session.user_id(),
1579                RoomId::from_proto(request.room_id),
1580                user_id,
1581                role,
1582            )
1583            .await?;
1584
1585        let livekit_room = room.livekit_room.clone();
1586        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1587        room_updated(&room, &session.peer);
1588        (livekit_room, can_publish)
1589    };
1590
1591    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1592        live_kit
1593            .update_participant(
1594                livekit_room.clone(),
1595                request.user_id.to_string(),
1596                livekit_api::proto::ParticipantPermission {
1597                    can_subscribe: true,
1598                    can_publish,
1599                    can_publish_data: can_publish,
1600                    hidden: false,
1601                    recorder: false,
1602                },
1603            )
1604            .await
1605            .trace_err();
1606    }
1607
1608    response.send(proto::Ack {})?;
1609    Ok(())
1610}
1611
1612/// Call someone else into the current room
1613async fn call(
1614    request: proto::Call,
1615    response: Response<proto::Call>,
1616    session: MessageContext,
1617) -> Result<()> {
1618    let room_id = RoomId::from_proto(request.room_id);
1619    let calling_user_id = session.user_id();
1620    let calling_connection_id = session.connection_id;
1621    let called_user_id = UserId::from_proto(request.called_user_id);
1622    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1623    if !session
1624        .db()
1625        .await
1626        .has_contact(calling_user_id, called_user_id)
1627        .await?
1628    {
1629        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1630    }
1631
1632    let incoming_call = {
1633        let (room, incoming_call) = &mut *session
1634            .db()
1635            .await
1636            .call(
1637                room_id,
1638                calling_user_id,
1639                calling_connection_id,
1640                called_user_id,
1641                initial_project_id,
1642            )
1643            .await?;
1644        room_updated(room, &session.peer);
1645        mem::take(incoming_call)
1646    };
1647    update_user_contacts(called_user_id, &session).await?;
1648
1649    let mut calls = session
1650        .connection_pool()
1651        .await
1652        .user_connection_ids(called_user_id)
1653        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1654        .collect::<FuturesUnordered<_>>();
1655
1656    while let Some(call_response) = calls.next().await {
1657        match call_response.as_ref() {
1658            Ok(_) => {
1659                response.send(proto::Ack {})?;
1660                return Ok(());
1661            }
1662            Err(_) => {
1663                call_response.trace_err();
1664            }
1665        }
1666    }
1667
1668    {
1669        let room = session
1670            .db()
1671            .await
1672            .call_failed(room_id, called_user_id)
1673            .await?;
1674        room_updated(&room, &session.peer);
1675    }
1676    update_user_contacts(called_user_id, &session).await?;
1677
1678    Err(anyhow!("failed to ring user"))?
1679}
1680
1681/// Cancel an outgoing call.
1682async fn cancel_call(
1683    request: proto::CancelCall,
1684    response: Response<proto::CancelCall>,
1685    session: MessageContext,
1686) -> Result<()> {
1687    let called_user_id = UserId::from_proto(request.called_user_id);
1688    let room_id = RoomId::from_proto(request.room_id);
1689    {
1690        let room = session
1691            .db()
1692            .await
1693            .cancel_call(room_id, session.connection_id, called_user_id)
1694            .await?;
1695        room_updated(&room, &session.peer);
1696    }
1697
1698    for connection_id in session
1699        .connection_pool()
1700        .await
1701        .user_connection_ids(called_user_id)
1702    {
1703        session
1704            .peer
1705            .send(
1706                connection_id,
1707                proto::CallCanceled {
1708                    room_id: room_id.to_proto(),
1709                },
1710            )
1711            .trace_err();
1712    }
1713    response.send(proto::Ack {})?;
1714
1715    update_user_contacts(called_user_id, &session).await?;
1716    Ok(())
1717}
1718
1719/// Decline an incoming call.
1720async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1721    let room_id = RoomId::from_proto(message.room_id);
1722    {
1723        let room = session
1724            .db()
1725            .await
1726            .decline_call(Some(room_id), session.user_id())
1727            .await?
1728            .context("declining call")?;
1729        room_updated(&room, &session.peer);
1730    }
1731
1732    for connection_id in session
1733        .connection_pool()
1734        .await
1735        .user_connection_ids(session.user_id())
1736    {
1737        session
1738            .peer
1739            .send(
1740                connection_id,
1741                proto::CallCanceled {
1742                    room_id: room_id.to_proto(),
1743                },
1744            )
1745            .trace_err();
1746    }
1747    update_user_contacts(session.user_id(), &session).await?;
1748    Ok(())
1749}
1750
1751/// Updates other participants in the room with your current location.
1752async fn update_participant_location(
1753    request: proto::UpdateParticipantLocation,
1754    response: Response<proto::UpdateParticipantLocation>,
1755    session: MessageContext,
1756) -> Result<()> {
1757    let room_id = RoomId::from_proto(request.room_id);
1758    let location = request.location.context("invalid location")?;
1759
1760    let db = session.db().await;
1761    let room = db
1762        .update_room_participant_location(room_id, session.connection_id, location)
1763        .await?;
1764
1765    room_updated(&room, &session.peer);
1766    response.send(proto::Ack {})?;
1767    Ok(())
1768}
1769
1770/// Share a project into the room.
1771async fn share_project(
1772    request: proto::ShareProject,
1773    response: Response<proto::ShareProject>,
1774    session: MessageContext,
1775) -> Result<()> {
1776    let (project_id, room) = &*session
1777        .db()
1778        .await
1779        .share_project(
1780            RoomId::from_proto(request.room_id),
1781            session.connection_id,
1782            &request.worktrees,
1783            request.is_ssh_project,
1784            request.windows_paths.unwrap_or(false),
1785            &request.features,
1786        )
1787        .await?;
1788    response.send(proto::ShareProjectResponse {
1789        project_id: project_id.to_proto(),
1790    })?;
1791    room_updated(room, &session.peer);
1792
1793    Ok(())
1794}
1795
1796/// Unshare a project from the room.
1797async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1798    let project_id = ProjectId::from_proto(message.project_id);
1799    unshare_project_internal(project_id, session.connection_id, &session).await
1800}
1801
1802async fn unshare_project_internal(
1803    project_id: ProjectId,
1804    connection_id: ConnectionId,
1805    session: &Session,
1806) -> Result<()> {
1807    let delete = {
1808        let room_guard = session
1809            .db()
1810            .await
1811            .unshare_project(project_id, connection_id)
1812            .await?;
1813
1814        let (delete, room, guest_connection_ids) = &*room_guard;
1815
1816        let message = proto::UnshareProject {
1817            project_id: project_id.to_proto(),
1818        };
1819
1820        broadcast(
1821            Some(connection_id),
1822            guest_connection_ids.iter().copied(),
1823            |conn_id| session.peer.send(conn_id, message.clone()),
1824        );
1825        if let Some(room) = room {
1826            room_updated(room, &session.peer);
1827        }
1828
1829        *delete
1830    };
1831
1832    if delete {
1833        let db = session.db().await;
1834        db.delete_project(project_id).await?;
1835    }
1836
1837    Ok(())
1838}
1839
1840/// Join someone elses shared project.
1841async fn join_project(
1842    request: proto::JoinProject,
1843    response: Response<proto::JoinProject>,
1844    session: MessageContext,
1845) -> Result<()> {
1846    let project_id = ProjectId::from_proto(request.project_id);
1847
1848    tracing::info!(%project_id, "join project");
1849
1850    let db = session.db().await;
1851    let project_model = db.get_project(project_id).await?;
1852    let host_features: Vec<String> =
1853        serde_json::from_str(&project_model.features).unwrap_or_default();
1854    let guest_features: HashSet<_> = request.features.iter().collect();
1855    let host_features_set: HashSet<_> = host_features.iter().collect();
1856    if guest_features != host_features_set {
1857        let host_connection_id = project_model.host_connection()?;
1858        let mut pool = session.connection_pool().await;
1859        let host_version = pool
1860            .connection(host_connection_id)
1861            .map(|c| c.zed_version.to_string());
1862        let guest_version = pool
1863            .connection(session.connection_id)
1864            .map(|c| c.zed_version.to_string());
1865        drop(pool);
1866        Err(anyhow!(
1867            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1868            host_version.as_deref().unwrap_or("unknown"),
1869            guest_version.as_deref().unwrap_or("unknown"),
1870        ))?;
1871    }
1872
1873    let (project, replica_id) = &mut *db
1874        .join_project(
1875            project_id,
1876            session.connection_id,
1877            session.user_id(),
1878            request.committer_name.clone(),
1879            request.committer_email.clone(),
1880        )
1881        .await?;
1882    drop(db);
1883
1884    tracing::info!(%project_id, "join remote project");
1885    let collaborators = project
1886        .collaborators
1887        .iter()
1888        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1889        .map(|collaborator| collaborator.to_proto())
1890        .collect::<Vec<_>>();
1891    let project_id = project.id;
1892    let guest_user_id = session.user_id();
1893
1894    let worktrees = project
1895        .worktrees
1896        .iter()
1897        .map(|(id, worktree)| proto::WorktreeMetadata {
1898            id: *id,
1899            root_name: worktree.root_name.clone(),
1900            visible: worktree.visible,
1901            abs_path: worktree.abs_path.clone(),
1902            root_repo_common_dir: None,
1903        })
1904        .collect::<Vec<_>>();
1905
1906    let add_project_collaborator = proto::AddProjectCollaborator {
1907        project_id: project_id.to_proto(),
1908        collaborator: Some(proto::Collaborator {
1909            peer_id: Some(session.connection_id.into()),
1910            replica_id: replica_id.0 as u32,
1911            user_id: guest_user_id.to_proto(),
1912            is_host: false,
1913            committer_name: request.committer_name.clone(),
1914            committer_email: request.committer_email.clone(),
1915        }),
1916    };
1917
1918    for collaborator in &collaborators {
1919        session
1920            .peer
1921            .send(
1922                collaborator.peer_id.unwrap().into(),
1923                add_project_collaborator.clone(),
1924            )
1925            .trace_err();
1926    }
1927
1928    // First, we send the metadata associated with each worktree.
1929    let (language_servers, language_server_capabilities) = project
1930        .language_servers
1931        .clone()
1932        .into_iter()
1933        .map(|server| (server.server, server.capabilities))
1934        .unzip();
1935    response.send(proto::JoinProjectResponse {
1936        project_id: project.id.0 as u64,
1937        worktrees,
1938        replica_id: replica_id.0 as u32,
1939        collaborators,
1940        language_servers,
1941        language_server_capabilities,
1942        role: project.role.into(),
1943        windows_paths: project.path_style == PathStyle::Windows,
1944        features: project.features.clone(),
1945    })?;
1946
1947    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1948        // Stream this worktree's entries.
1949        let message = proto::UpdateWorktree {
1950            project_id: project_id.to_proto(),
1951            worktree_id,
1952            abs_path: worktree.abs_path.clone(),
1953            root_name: worktree.root_name,
1954            root_repo_common_dir: worktree.root_repo_common_dir,
1955            updated_entries: worktree.entries,
1956            removed_entries: Default::default(),
1957            scan_id: worktree.scan_id,
1958            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1959            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1960            removed_repositories: Default::default(),
1961        };
1962        for update in proto::split_worktree_update(message) {
1963            session.peer.send(session.connection_id, update.clone())?;
1964        }
1965
1966        // Stream this worktree's diagnostics.
1967        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1968        if let Some(summary) = worktree_diagnostics.next() {
1969            let message = proto::UpdateDiagnosticSummary {
1970                project_id: project.id.to_proto(),
1971                worktree_id: worktree.id,
1972                summary: Some(summary),
1973                more_summaries: worktree_diagnostics.collect(),
1974            };
1975            session.peer.send(session.connection_id, message)?;
1976        }
1977
1978        for settings_file in worktree.settings_files {
1979            session.peer.send(
1980                session.connection_id,
1981                proto::UpdateWorktreeSettings {
1982                    project_id: project_id.to_proto(),
1983                    worktree_id: worktree.id,
1984                    path: settings_file.path,
1985                    content: Some(settings_file.content),
1986                    kind: Some(settings_file.kind.to_proto() as i32),
1987                    outside_worktree: Some(settings_file.outside_worktree),
1988                },
1989            )?;
1990        }
1991    }
1992
1993    for repository in mem::take(&mut project.repositories) {
1994        for update in split_repository_update(repository) {
1995            session.peer.send(session.connection_id, update)?;
1996        }
1997    }
1998
1999    for language_server in &project.language_servers {
2000        session.peer.send(
2001            session.connection_id,
2002            proto::UpdateLanguageServer {
2003                project_id: project_id.to_proto(),
2004                server_name: Some(language_server.server.name.clone()),
2005                language_server_id: language_server.server.id,
2006                variant: Some(
2007                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2008                        proto::LspDiskBasedDiagnosticsUpdated {},
2009                    ),
2010                ),
2011            },
2012        )?;
2013    }
2014
2015    Ok(())
2016}
2017
2018/// Leave someone elses shared project.
2019async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2020    let sender_id = session.connection_id;
2021    let project_id = ProjectId::from_proto(request.project_id);
2022    let db = session.db().await;
2023
2024    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2025    tracing::info!(
2026        %project_id,
2027        "leave project"
2028    );
2029
2030    project_left(project, &session);
2031    if let Some(room) = room {
2032        room_updated(room, &session.peer);
2033    }
2034
2035    Ok(())
2036}
2037
2038/// Updates other participants with changes to the project
2039async fn update_project(
2040    request: proto::UpdateProject,
2041    response: Response<proto::UpdateProject>,
2042    session: MessageContext,
2043) -> Result<()> {
2044    let project_id = ProjectId::from_proto(request.project_id);
2045    let (room, guest_connection_ids) = &*session
2046        .db()
2047        .await
2048        .update_project(project_id, session.connection_id, &request.worktrees)
2049        .await?;
2050    broadcast(
2051        Some(session.connection_id),
2052        guest_connection_ids.iter().copied(),
2053        |connection_id| {
2054            session
2055                .peer
2056                .forward_send(session.connection_id, connection_id, request.clone())
2057        },
2058    );
2059    if let Some(room) = room {
2060        room_updated(room, &session.peer);
2061    }
2062    response.send(proto::Ack {})?;
2063
2064    Ok(())
2065}
2066
2067/// Updates other participants with changes to the worktree
2068async fn update_worktree(
2069    request: proto::UpdateWorktree,
2070    response: Response<proto::UpdateWorktree>,
2071    session: MessageContext,
2072) -> Result<()> {
2073    let guest_connection_ids = session
2074        .db()
2075        .await
2076        .update_worktree(&request, session.connection_id)
2077        .await?;
2078
2079    broadcast(
2080        Some(session.connection_id),
2081        guest_connection_ids.iter().copied(),
2082        |connection_id| {
2083            session
2084                .peer
2085                .forward_send(session.connection_id, connection_id, request.clone())
2086        },
2087    );
2088    response.send(proto::Ack {})?;
2089    Ok(())
2090}
2091
2092async fn update_repository(
2093    request: proto::UpdateRepository,
2094    response: Response<proto::UpdateRepository>,
2095    session: MessageContext,
2096) -> Result<()> {
2097    let guest_connection_ids = session
2098        .db()
2099        .await
2100        .update_repository(&request, session.connection_id)
2101        .await?;
2102
2103    broadcast(
2104        Some(session.connection_id),
2105        guest_connection_ids.iter().copied(),
2106        |connection_id| {
2107            session
2108                .peer
2109                .forward_send(session.connection_id, connection_id, request.clone())
2110        },
2111    );
2112    response.send(proto::Ack {})?;
2113    Ok(())
2114}
2115
2116async fn remove_repository(
2117    request: proto::RemoveRepository,
2118    response: Response<proto::RemoveRepository>,
2119    session: MessageContext,
2120) -> Result<()> {
2121    let guest_connection_ids = session
2122        .db()
2123        .await
2124        .remove_repository(&request, session.connection_id)
2125        .await?;
2126
2127    broadcast(
2128        Some(session.connection_id),
2129        guest_connection_ids.iter().copied(),
2130        |connection_id| {
2131            session
2132                .peer
2133                .forward_send(session.connection_id, connection_id, request.clone())
2134        },
2135    );
2136    response.send(proto::Ack {})?;
2137    Ok(())
2138}
2139
2140/// Updates other participants with changes to the diagnostics
2141async fn update_diagnostic_summary(
2142    message: proto::UpdateDiagnosticSummary,
2143    session: MessageContext,
2144) -> Result<()> {
2145    let guest_connection_ids = session
2146        .db()
2147        .await
2148        .update_diagnostic_summary(&message, session.connection_id)
2149        .await?;
2150
2151    broadcast(
2152        Some(session.connection_id),
2153        guest_connection_ids.iter().copied(),
2154        |connection_id| {
2155            session
2156                .peer
2157                .forward_send(session.connection_id, connection_id, message.clone())
2158        },
2159    );
2160
2161    Ok(())
2162}
2163
2164/// Updates other participants with changes to the worktree settings
2165async fn update_worktree_settings(
2166    message: proto::UpdateWorktreeSettings,
2167    session: MessageContext,
2168) -> Result<()> {
2169    let guest_connection_ids = session
2170        .db()
2171        .await
2172        .update_worktree_settings(&message, session.connection_id)
2173        .await?;
2174
2175    broadcast(
2176        Some(session.connection_id),
2177        guest_connection_ids.iter().copied(),
2178        |connection_id| {
2179            session
2180                .peer
2181                .forward_send(session.connection_id, connection_id, message.clone())
2182        },
2183    );
2184
2185    Ok(())
2186}
2187
2188/// Notify other participants that a language server has started.
2189async fn start_language_server(
2190    request: proto::StartLanguageServer,
2191    session: MessageContext,
2192) -> Result<()> {
2193    let guest_connection_ids = session
2194        .db()
2195        .await
2196        .start_language_server(&request, session.connection_id)
2197        .await?;
2198
2199    broadcast(
2200        Some(session.connection_id),
2201        guest_connection_ids.iter().copied(),
2202        |connection_id| {
2203            session
2204                .peer
2205                .forward_send(session.connection_id, connection_id, request.clone())
2206        },
2207    );
2208    Ok(())
2209}
2210
2211/// Notify other participants that a language server has changed.
2212async fn update_language_server(
2213    request: proto::UpdateLanguageServer,
2214    session: MessageContext,
2215) -> Result<()> {
2216    let project_id = ProjectId::from_proto(request.project_id);
2217    let db = session.db().await;
2218
2219    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2220        && let Some(capabilities) = update.capabilities.clone()
2221    {
2222        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2223            .await?;
2224    }
2225
2226    let project_connection_ids = db
2227        .project_connection_ids(project_id, session.connection_id, true)
2228        .await?;
2229    broadcast(
2230        Some(session.connection_id),
2231        project_connection_ids.iter().copied(),
2232        |connection_id| {
2233            session
2234                .peer
2235                .forward_send(session.connection_id, connection_id, request.clone())
2236        },
2237    );
2238    Ok(())
2239}
2240
2241/// forward a project request to the host. These requests should be read only
2242/// as guests are allowed to send them.
2243async fn forward_read_only_project_request<T>(
2244    request: T,
2245    response: Response<T>,
2246    session: MessageContext,
2247) -> Result<()>
2248where
2249    T: EntityMessage + RequestMessage,
2250{
2251    let project_id = ProjectId::from_proto(request.remote_entity_id());
2252    let host_connection_id = session
2253        .db()
2254        .await
2255        .host_for_read_only_project_request(project_id, session.connection_id)
2256        .await?;
2257    let payload = session.forward_request(host_connection_id, request).await?;
2258    response.send(payload)?;
2259    Ok(())
2260}
2261
2262/// forward a project request to the host. These requests are disallowed
2263/// for guests.
2264async fn forward_mutating_project_request<T>(
2265    request: T,
2266    response: Response<T>,
2267    session: MessageContext,
2268) -> Result<()>
2269where
2270    T: EntityMessage + RequestMessage,
2271{
2272    let project_id = ProjectId::from_proto(request.remote_entity_id());
2273
2274    let host_connection_id = session
2275        .db()
2276        .await
2277        .host_for_mutating_project_request(project_id, session.connection_id)
2278        .await?;
2279    let payload = session.forward_request(host_connection_id, request).await?;
2280    response.send(payload)?;
2281    Ok(())
2282}
2283
2284async fn disallow_guest_request<T>(
2285    _request: T,
2286    response: Response<T>,
2287    _session: MessageContext,
2288) -> Result<()>
2289where
2290    T: RequestMessage,
2291{
2292    response.peer.respond_with_error(
2293        response.receipt,
2294        ErrorCode::Forbidden
2295            .message("request is not allowed for guests".to_string())
2296            .to_proto(),
2297    )?;
2298    response.responded.store(true, SeqCst);
2299    Ok(())
2300}
2301
2302async fn lsp_query(
2303    request: proto::LspQuery,
2304    response: Response<proto::LspQuery>,
2305    session: MessageContext,
2306) -> Result<()> {
2307    let (name, should_write) = request.query_name_and_write_permissions();
2308    tracing::Span::current().record("lsp_query_request", name);
2309    tracing::info!("lsp_query message received");
2310    if should_write {
2311        forward_mutating_project_request(request, response, session).await
2312    } else {
2313        forward_read_only_project_request(request, response, session).await
2314    }
2315}
2316
2317/// Notify other participants that a new buffer has been created
2318async fn create_buffer_for_peer(
2319    request: proto::CreateBufferForPeer,
2320    session: MessageContext,
2321) -> Result<()> {
2322    session
2323        .db()
2324        .await
2325        .check_user_is_project_host(
2326            ProjectId::from_proto(request.project_id),
2327            session.connection_id,
2328        )
2329        .await?;
2330    let peer_id = request.peer_id.context("invalid peer id")?;
2331    session
2332        .peer
2333        .forward_send(session.connection_id, peer_id.into(), request)?;
2334    Ok(())
2335}
2336
2337/// Notify other participants that a new image has been created
2338async fn create_image_for_peer(
2339    request: proto::CreateImageForPeer,
2340    session: MessageContext,
2341) -> Result<()> {
2342    session
2343        .db()
2344        .await
2345        .check_user_is_project_host(
2346            ProjectId::from_proto(request.project_id),
2347            session.connection_id,
2348        )
2349        .await?;
2350    let peer_id = request.peer_id.context("invalid peer id")?;
2351    session
2352        .peer
2353        .forward_send(session.connection_id, peer_id.into(), request)?;
2354    Ok(())
2355}
2356
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The update is forwarded (via `broadcast`) to the `guests` connections returned by
/// `connections_for_buffer_update`. Unless the sender is itself the host, it is also sent
/// to the host as a request whose reply is awaited. The sender is acked only after that.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start from read-only access. Escalate to read-write as soon as any operation is
    // something other than a selection update; operations with no variant are ignored.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // `guard` lives only inside this block, so it is dropped before the host request
    // below is awaited.
    let host = {
        // NOTE(review): the required `capability` is passed here, presumably so the
        // database can reject senders lacking write access. Confirm in `db`.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // When the sender is the host, there is nobody to forward the request to.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2403
2404async fn forward_project_search_chunk(
2405    message: proto::FindSearchCandidatesChunk,
2406    response: Response<proto::FindSearchCandidatesChunk>,
2407    session: MessageContext,
2408) -> Result<()> {
2409    let peer_id = message.peer_id.context("missing peer_id")?;
2410    let payload = session
2411        .peer
2412        .forward_request(session.connection_id, peer_id.into(), message)
2413        .await?;
2414    response.send(payload)?;
2415    Ok(())
2416}
2417
2418/// Notify other participants that a project has been updated.
2419async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2420    request: T,
2421    session: MessageContext,
2422) -> Result<()> {
2423    let project_id = ProjectId::from_proto(request.remote_entity_id());
2424    let project_connection_ids = session
2425        .db()
2426        .await
2427        .project_connection_ids(project_id, session.connection_id, false)
2428        .await?;
2429
2430    broadcast(
2431        Some(session.connection_id),
2432        project_connection_ids.iter().copied(),
2433        |connection_id| {
2434            session
2435                .peer
2436                .forward_send(session.connection_id, connection_id, request.clone())
2437        },
2438    );
2439    Ok(())
2440}
2441
2442/// Start following another user in a call.
2443async fn follow(
2444    request: proto::Follow,
2445    response: Response<proto::Follow>,
2446    session: MessageContext,
2447) -> Result<()> {
2448    let room_id = RoomId::from_proto(request.room_id);
2449    let project_id = request.project_id.map(ProjectId::from_proto);
2450    let leader_id = request.leader_id.context("invalid leader id")?.into();
2451    let follower_id = session.connection_id;
2452
2453    session
2454        .db()
2455        .await
2456        .check_room_participants(room_id, leader_id, session.connection_id)
2457        .await?;
2458
2459    let response_payload = session.forward_request(leader_id, request).await?;
2460    response.send(response_payload)?;
2461
2462    if let Some(project_id) = project_id {
2463        let room = session
2464            .db()
2465            .await
2466            .follow(room_id, project_id, leader_id, follower_id)
2467            .await?;
2468        room_updated(&room, &session.peer);
2469    }
2470
2471    Ok(())
2472}
2473
2474/// Stop following another user in a call.
2475async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2476    let room_id = RoomId::from_proto(request.room_id);
2477    let project_id = request.project_id.map(ProjectId::from_proto);
2478    let leader_id = request.leader_id.context("invalid leader id")?.into();
2479    let follower_id = session.connection_id;
2480
2481    session
2482        .db()
2483        .await
2484        .check_room_participants(room_id, leader_id, session.connection_id)
2485        .await?;
2486
2487    session
2488        .peer
2489        .forward_send(session.connection_id, leader_id, request)?;
2490
2491    if let Some(project_id) = project_id {
2492        let room = session
2493            .db()
2494            .await
2495            .unfollow(room_id, project_id, leader_id, follower_id)
2496            .await?;
2497        room_updated(&room, &session.peer);
2498    }
2499
2500    Ok(())
2501}
2502
/// Notify everyone following you of your current location.
///
/// Recipients are the project's connections when `request.project_id` is set, and the
/// room's connections otherwise. The sender never receives its own update. Neither does
/// the leader named in an `UpdateView` variant.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly instead of going through
    // `session.db()` like the other handlers. The guard is held until the function returns,
    // including while forwarding below. Confirm this is intentional.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    for connection_id in connection_ids.iter().cloned() {
        // A failed send returns early via `?`, so the remaining recipients are skipped.
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2534
2535/// Get public data about users.
2536async fn get_users(
2537    request: proto::GetUsers,
2538    response: Response<proto::GetUsers>,
2539    session: MessageContext,
2540) -> Result<()> {
2541    let user_ids = request
2542        .user_ids
2543        .into_iter()
2544        .map(UserId::from_proto)
2545        .collect();
2546    let users = session
2547        .db()
2548        .await
2549        .get_users_by_ids(user_ids)
2550        .await?
2551        .into_iter()
2552        .map(|user| proto::User {
2553            id: user.id.to_proto(),
2554            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2555            github_login: user.github_login,
2556            name: user.name,
2557        })
2558        .collect();
2559    response.send(proto::UsersResponse { users })?;
2560    Ok(())
2561}
2562
2563/// Search for users (to invite) buy Github login
2564async fn fuzzy_search_users(
2565    request: proto::FuzzySearchUsers,
2566    response: Response<proto::FuzzySearchUsers>,
2567    session: MessageContext,
2568) -> Result<()> {
2569    let query = request.query;
2570    let users = match query.len() {
2571        0 => vec![],
2572        1 | 2 => session
2573            .db()
2574            .await
2575            .get_user_by_github_login(&query)
2576            .await?
2577            .into_iter()
2578            .collect(),
2579        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2580    };
2581    let users = users
2582        .into_iter()
2583        .filter(|user| user.id != session.user_id())
2584        .map(|user| proto::User {
2585            id: user.id.to_proto(),
2586            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2587            github_login: user.github_login,
2588            name: user.name,
2589        })
2590        .collect();
2591    response.send(proto::UsersResponse { users })?;
2592    Ok(())
2593}
2594
2595/// Send a contact request to another user.
2596async fn request_contact(
2597    request: proto::RequestContact,
2598    response: Response<proto::RequestContact>,
2599    session: MessageContext,
2600) -> Result<()> {
2601    let requester_id = session.user_id();
2602    let responder_id = UserId::from_proto(request.responder_id);
2603    if requester_id == responder_id {
2604        return Err(anyhow!("cannot add yourself as a contact"))?;
2605    }
2606
2607    let notifications = session
2608        .db()
2609        .await
2610        .send_contact_request(requester_id, responder_id)
2611        .await?;
2612
2613    // Update outgoing contact requests of requester
2614    let mut update = proto::UpdateContacts::default();
2615    update.outgoing_requests.push(responder_id.to_proto());
2616    for connection_id in session
2617        .connection_pool()
2618        .await
2619        .user_connection_ids(requester_id)
2620    {
2621        session.peer.send(connection_id, update.clone())?;
2622    }
2623
2624    // Update incoming contact requests of responder
2625    let mut update = proto::UpdateContacts::default();
2626    update
2627        .incoming_requests
2628        .push(proto::IncomingContactRequest {
2629            requester_id: requester_id.to_proto(),
2630        });
2631    let connection_pool = session.connection_pool().await;
2632    for connection_id in connection_pool.user_connection_ids(responder_id) {
2633        session.peer.send(connection_id, update.clone())?;
2634    }
2635
2636    send_notifications(&connection_pool, &session.peer, notifications);
2637
2638    response.send(proto::Ack {})?;
2639    Ok(())
2640}
2641
/// Accept, decline, or dismiss a contact request sent to the current user.
///
/// Dismissing only calls `dismiss_contact_notification`; no contact updates are sent.
/// Accepting or declining records the response in the database. Both users' connections
/// are then told to drop the pending request and, when accepted, to add each other as
/// contacts. Finally, any notifications produced by the database are delivered.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // Held for the rest of the handler, including while the updates below are sent.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any response other than `Dismiss` or `Accept` is treated as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Passed to `contact_for_user` when building the contact entries below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2699
2700/// Remove a contact.
2701async fn remove_contact(
2702    request: proto::RemoveContact,
2703    response: Response<proto::RemoveContact>,
2704    session: MessageContext,
2705) -> Result<()> {
2706    let requester_id = session.user_id();
2707    let responder_id = UserId::from_proto(request.user_id);
2708    let db = session.db().await;
2709    let (contact_accepted, deleted_notification_id) =
2710        db.remove_contact(requester_id, responder_id).await?;
2711
2712    let pool = session.connection_pool().await;
2713    // Update outgoing contact requests of requester
2714    let mut update = proto::UpdateContacts::default();
2715    if contact_accepted {
2716        update.remove_contacts.push(responder_id.to_proto());
2717    } else {
2718        update
2719            .remove_outgoing_requests
2720            .push(responder_id.to_proto());
2721    }
2722    for connection_id in pool.user_connection_ids(requester_id) {
2723        session.peer.send(connection_id, update.clone())?;
2724    }
2725
2726    // Update incoming contact requests of responder
2727    let mut update = proto::UpdateContacts::default();
2728    if contact_accepted {
2729        update.remove_contacts.push(requester_id.to_proto());
2730    } else {
2731        update
2732            .remove_incoming_requests
2733            .push(requester_id.to_proto());
2734    }
2735    for connection_id in pool.user_connection_ids(responder_id) {
2736        session.peer.send(connection_id, update.clone())?;
2737        if let Some(notification_id) = deleted_notification_id {
2738            session.peer.send(
2739                connection_id,
2740                proto::DeleteNotification {
2741                    notification_id: notification_id.to_proto(),
2742                },
2743            )?;
2744        }
2745    }
2746
2747    response.send(proto::Ack {})?;
2748    Ok(())
2749}
2750
2751fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2752    version.0.minor < 139
2753}
2754
2755async fn subscribe_to_channels(
2756    _: proto::SubscribeToChannels,
2757    session: MessageContext,
2758) -> Result<()> {
2759    subscribe_user_to_channels(session.user_id(), &session).await?;
2760    Ok(())
2761}
2762
2763async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2764    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2765    let mut pool = session.connection_pool().await;
2766    for membership in &channels_for_user.channel_memberships {
2767        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2768    }
2769    session.peer.send(
2770        session.connection_id,
2771        build_update_user_channels(&channels_for_user),
2772    )?;
2773    session.peer.send(
2774        session.connection_id,
2775        build_channels_update(channels_for_user),
2776    )?;
2777    Ok(())
2778}
2779
2780/// Creates a new channel.
2781async fn create_channel(
2782    request: proto::CreateChannel,
2783    response: Response<proto::CreateChannel>,
2784    session: MessageContext,
2785) -> Result<()> {
2786    let db = session.db().await;
2787
2788    let parent_id = request.parent_id.map(ChannelId::from_proto);
2789    let (channel, membership) = db
2790        .create_channel(&request.name, parent_id, session.user_id())
2791        .await?;
2792
2793    let root_id = channel.root_id();
2794    let channel = Channel::from_model(channel);
2795
2796    response.send(proto::CreateChannelResponse {
2797        channel: Some(channel.to_proto()),
2798        parent_id: request.parent_id,
2799    })?;
2800
2801    let mut connection_pool = session.connection_pool().await;
2802    if let Some(membership) = membership {
2803        connection_pool.subscribe_to_channel(
2804            membership.user_id,
2805            membership.channel_id,
2806            membership.role,
2807        );
2808        let update = proto::UpdateUserChannels {
2809            channel_memberships: vec![proto::ChannelMembership {
2810                channel_id: membership.channel_id.to_proto(),
2811                role: membership.role.into(),
2812            }],
2813            ..Default::default()
2814        };
2815        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2816            session.peer.send(connection_id, update.clone())?;
2817        }
2818    }
2819
2820    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2821        if !role.can_see_channel(channel.visibility) {
2822            continue;
2823        }
2824
2825        let update = proto::UpdateChannels {
2826            channels: vec![channel.to_proto()],
2827            ..Default::default()
2828        };
2829        session.peer.send(connection_id, update.clone())?;
2830    }
2831
2832    Ok(())
2833}
2834
2835/// Delete a channel
2836async fn delete_channel(
2837    request: proto::DeleteChannel,
2838    response: Response<proto::DeleteChannel>,
2839    session: MessageContext,
2840) -> Result<()> {
2841    let db = session.db().await;
2842
2843    let channel_id = request.channel_id;
2844    let (root_channel, removed_channels) = db
2845        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2846        .await?;
2847    response.send(proto::Ack {})?;
2848
2849    // Notify members of removed channels
2850    let mut update = proto::UpdateChannels::default();
2851    update
2852        .delete_channels
2853        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2854
2855    let connection_pool = session.connection_pool().await;
2856    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2857        session.peer.send(connection_id, update.clone())?;
2858    }
2859
2860    Ok(())
2861}
2862
2863/// Invite someone to join a channel.
2864async fn invite_channel_member(
2865    request: proto::InviteChannelMember,
2866    response: Response<proto::InviteChannelMember>,
2867    session: MessageContext,
2868) -> Result<()> {
2869    let db = session.db().await;
2870    let channel_id = ChannelId::from_proto(request.channel_id);
2871    let invitee_id = UserId::from_proto(request.user_id);
2872    let InviteMemberResult {
2873        channel,
2874        notifications,
2875    } = db
2876        .invite_channel_member(
2877            channel_id,
2878            invitee_id,
2879            session.user_id(),
2880            request.role().into(),
2881        )
2882        .await?;
2883
2884    let update = proto::UpdateChannels {
2885        channel_invitations: vec![channel.to_proto()],
2886        ..Default::default()
2887    };
2888
2889    let connection_pool = session.connection_pool().await;
2890    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2891        session.peer.send(connection_id, update.clone())?;
2892    }
2893
2894    send_notifications(&connection_pool, &session.peer, notifications);
2895
2896    response.send(proto::Ack {})?;
2897    Ok(())
2898}
2899
2900/// remove someone from a channel
2901async fn remove_channel_member(
2902    request: proto::RemoveChannelMember,
2903    response: Response<proto::RemoveChannelMember>,
2904    session: MessageContext,
2905) -> Result<()> {
2906    let db = session.db().await;
2907    let channel_id = ChannelId::from_proto(request.channel_id);
2908    let member_id = UserId::from_proto(request.user_id);
2909
2910    let RemoveChannelMemberResult {
2911        membership_update,
2912        notification_id,
2913    } = db
2914        .remove_channel_member(channel_id, member_id, session.user_id())
2915        .await?;
2916
2917    let mut connection_pool = session.connection_pool().await;
2918    notify_membership_updated(
2919        &mut connection_pool,
2920        membership_update,
2921        member_id,
2922        &session.peer,
2923    );
2924    for connection_id in connection_pool.user_connection_ids(member_id) {
2925        if let Some(notification_id) = notification_id {
2926            session
2927                .peer
2928                .send(
2929                    connection_id,
2930                    proto::DeleteNotification {
2931                        notification_id: notification_id.to_proto(),
2932                    },
2933                )
2934                .trace_err();
2935        }
2936    }
2937
2938    response.send(proto::Ack {})?;
2939    Ok(())
2940}
2941
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy); presumably this is
/// enforced by `db.set_channel_visibility`.
///
/// After the database change, every user subscribed to the channel's root is re-evaluated.
/// Users whose role can still see the channel are subscribed to it and sent its new state.
/// All others are unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front: the loop body mutates `connection_pool`
    // (subscribe/unsubscribe), which it could not do while still iterating a borrow of it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2988
/// Alter the role for a user in the channel.
///
/// The database reports which case applied:
/// - `MembershipUpdated`: the change is applied via `notify_membership_updated`.
/// - `InviteUpdated`: the user only holds a pending invitation, so their connections are
///   sent the updated invitation instead.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            // The connection-pool guard is a temporary held only for the duration of
            // this loop.
            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3036
3037/// Change the name of a channel
3038async fn rename_channel(
3039    request: proto::RenameChannel,
3040    response: Response<proto::RenameChannel>,
3041    session: MessageContext,
3042) -> Result<()> {
3043    let db = session.db().await;
3044    let channel_id = ChannelId::from_proto(request.channel_id);
3045    let channel_model = db
3046        .rename_channel(channel_id, session.user_id(), &request.name)
3047        .await?;
3048    let root_id = channel_model.root_id();
3049    let channel = Channel::from_model(channel_model);
3050
3051    response.send(proto::RenameChannelResponse {
3052        channel: Some(channel.to_proto()),
3053    })?;
3054
3055    let connection_pool = session.connection_pool().await;
3056    let update = proto::UpdateChannels {
3057        channels: vec![channel.to_proto()],
3058        ..Default::default()
3059    };
3060    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3061        if role.can_see_channel(channel.visibility) {
3062            session.peer.send(connection_id, update.clone())?;
3063        }
3064    }
3065
3066    Ok(())
3067}
3068
3069/// Move a channel to a new parent.
3070async fn move_channel(
3071    request: proto::MoveChannel,
3072    response: Response<proto::MoveChannel>,
3073    session: MessageContext,
3074) -> Result<()> {
3075    let channel_id = ChannelId::from_proto(request.channel_id);
3076    let to = ChannelId::from_proto(request.to);
3077
3078    let (root_id, channels) = session
3079        .db()
3080        .await
3081        .move_channel(channel_id, to, session.user_id())
3082        .await?;
3083
3084    let connection_pool = session.connection_pool().await;
3085    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3086        let channels = channels
3087            .iter()
3088            .filter_map(|channel| {
3089                if role.can_see_channel(channel.visibility) {
3090                    Some(channel.to_proto())
3091                } else {
3092                    None
3093                }
3094            })
3095            .collect::<Vec<_>>();
3096        if channels.is_empty() {
3097            continue;
3098        }
3099
3100        let update = proto::UpdateChannels {
3101            channels,
3102            ..Default::default()
3103        };
3104
3105        session.peer.send(connection_id, update.clone())?;
3106    }
3107
3108    response.send(Ack {})?;
3109    Ok(())
3110}
3111
3112async fn reorder_channel(
3113    request: proto::ReorderChannel,
3114    response: Response<proto::ReorderChannel>,
3115    session: MessageContext,
3116) -> Result<()> {
3117    let channel_id = ChannelId::from_proto(request.channel_id);
3118    let direction = request.direction();
3119
3120    let updated_channels = session
3121        .db()
3122        .await
3123        .reorder_channel(channel_id, direction, session.user_id())
3124        .await?;
3125
3126    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3127        let connection_pool = session.connection_pool().await;
3128        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3129            let channels = updated_channels
3130                .iter()
3131                .filter_map(|channel| {
3132                    if role.can_see_channel(channel.visibility) {
3133                        Some(channel.to_proto())
3134                    } else {
3135                        None
3136                    }
3137                })
3138                .collect::<Vec<_>>();
3139
3140            if channels.is_empty() {
3141                continue;
3142            }
3143
3144            let update = proto::UpdateChannels {
3145                channels,
3146                ..Default::default()
3147            };
3148
3149            session.peer.send(connection_id, update.clone())?;
3150        }
3151    }
3152
3153    response.send(Ack {})?;
3154    Ok(())
3155}
3156
3157/// Get the list of channel members
3158async fn get_channel_members(
3159    request: proto::GetChannelMembers,
3160    response: Response<proto::GetChannelMembers>,
3161    session: MessageContext,
3162) -> Result<()> {
3163    let db = session.db().await;
3164    let channel_id = ChannelId::from_proto(request.channel_id);
3165    let limit = if request.limit == 0 {
3166        u16::MAX as u64
3167    } else {
3168        request.limit
3169    };
3170    let (members, users) = db
3171        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3172        .await?;
3173    response.send(proto::GetChannelMembersResponse { members, users })?;
3174    Ok(())
3175}
3176
3177/// Accept or decline a channel invitation.
3178async fn respond_to_channel_invite(
3179    request: proto::RespondToChannelInvite,
3180    response: Response<proto::RespondToChannelInvite>,
3181    session: MessageContext,
3182) -> Result<()> {
3183    let db = session.db().await;
3184    let channel_id = ChannelId::from_proto(request.channel_id);
3185    let RespondToChannelInvite {
3186        membership_update,
3187        notifications,
3188    } = db
3189        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3190        .await?;
3191
3192    let mut connection_pool = session.connection_pool().await;
3193    if let Some(membership_update) = membership_update {
3194        notify_membership_updated(
3195            &mut connection_pool,
3196            membership_update,
3197            session.user_id(),
3198            &session.peer,
3199        );
3200    } else {
3201        let update = proto::UpdateChannels {
3202            remove_channel_invitations: vec![channel_id.to_proto()],
3203            ..Default::default()
3204        };
3205
3206        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3207            session.peer.send(connection_id, update.clone())?;
3208        }
3209    };
3210
3211    send_notifications(&connection_pool, &session.peer, notifications);
3212
3213    response.send(proto::Ack {})?;
3214
3215    Ok(())
3216}
3217
3218/// Join the channels' room
3219async fn join_channel(
3220    request: proto::JoinChannel,
3221    response: Response<proto::JoinChannel>,
3222    session: MessageContext,
3223) -> Result<()> {
3224    let channel_id = ChannelId::from_proto(request.channel_id);
3225    join_channel_internal(channel_id, Box::new(response), session).await
3226}
3227
/// Abstracts over the response handles that can complete a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests are answered with a
/// `proto::JoinRoomResponse`, so `join_channel_internal` can serve either one through
/// this trait.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Completes a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Resolves to the inherent `Response::send` (inherent items take precedence over
        // trait methods in path resolution), so this does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Completes a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Inherent `Response::send`, as above; not a recursive call.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3241
/// Join the room of `channel_id` on behalf of the session's user and reply via `response`.
///
/// Steps, in order:
/// 1. Clean up any stale room connection left behind by the same user.
/// 2. Reply to the joiner. The reply carries LiveKit connection info only when a LiveKit
///    client is configured and a token could be issued; guests get a non-publishing
///    guest token.
/// 3. Apply any membership change and pass the room to `room_updated`.
/// 4. Call `channel_updated` and `update_user_contacts`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The `db` and `connection_pool` guards are confined to this block, so both are
    // released before `channel_updated` / `update_user_contacts` run below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation fails (the
        // error goes to `trace_err`, whose `None` short-circuits the closure via `?`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every other role gets
                    // a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joiner before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3335
3336/// Start editing the channel notes
3337async fn join_channel_buffer(
3338    request: proto::JoinChannelBuffer,
3339    response: Response<proto::JoinChannelBuffer>,
3340    session: MessageContext,
3341) -> Result<()> {
3342    let db = session.db().await;
3343    let channel_id = ChannelId::from_proto(request.channel_id);
3344
3345    let open_response = db
3346        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3347        .await?;
3348
3349    let collaborators = open_response.collaborators.clone();
3350    response.send(open_response)?;
3351
3352    let update = UpdateChannelBufferCollaborators {
3353        channel_id: channel_id.to_proto(),
3354        collaborators: collaborators.clone(),
3355    };
3356    channel_buffer_updated(
3357        session.connection_id,
3358        collaborators
3359            .iter()
3360            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3361        &update,
3362        &session.peer,
3363    );
3364
3365    Ok(())
3366}
3367
3368/// Edit the channel notes
3369async fn update_channel_buffer(
3370    request: proto::UpdateChannelBuffer,
3371    session: MessageContext,
3372) -> Result<()> {
3373    let db = session.db().await;
3374    let channel_id = ChannelId::from_proto(request.channel_id);
3375
3376    let (collaborators, epoch, version) = db
3377        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3378        .await?;
3379
3380    channel_buffer_updated(
3381        session.connection_id,
3382        collaborators.clone(),
3383        &proto::UpdateChannelBuffer {
3384            channel_id: channel_id.to_proto(),
3385            operations: request.operations,
3386        },
3387        &session.peer,
3388    );
3389
3390    let pool = &*session.connection_pool().await;
3391
3392    let non_collaborators =
3393        pool.channel_connection_ids(channel_id)
3394            .filter_map(|(connection_id, _)| {
3395                if collaborators.contains(&connection_id) {
3396                    None
3397                } else {
3398                    Some(connection_id)
3399                }
3400            });
3401
3402    broadcast(None, non_collaborators, |peer_id| {
3403        session.peer.send(
3404            peer_id,
3405            proto::UpdateChannels {
3406                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3407                    channel_id: channel_id.to_proto(),
3408                    epoch: epoch as u64,
3409                    version: version.clone(),
3410                }],
3411                ..Default::default()
3412            },
3413        )
3414    });
3415
3416    Ok(())
3417}
3418
3419/// Rejoin the channel notes after a connection blip
3420async fn rejoin_channel_buffers(
3421    request: proto::RejoinChannelBuffers,
3422    response: Response<proto::RejoinChannelBuffers>,
3423    session: MessageContext,
3424) -> Result<()> {
3425    let db = session.db().await;
3426    let buffers = db
3427        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3428        .await?;
3429
3430    for rejoined_buffer in &buffers {
3431        let collaborators_to_notify = rejoined_buffer
3432            .buffer
3433            .collaborators
3434            .iter()
3435            .filter_map(|c| Some(c.peer_id?.into()));
3436        channel_buffer_updated(
3437            session.connection_id,
3438            collaborators_to_notify,
3439            &proto::UpdateChannelBufferCollaborators {
3440                channel_id: rejoined_buffer.buffer.channel_id,
3441                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3442            },
3443            &session.peer,
3444        );
3445    }
3446
3447    response.send(proto::RejoinChannelBuffersResponse {
3448        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3449    })?;
3450
3451    Ok(())
3452}
3453
3454/// Stop editing the channel notes
3455async fn leave_channel_buffer(
3456    request: proto::LeaveChannelBuffer,
3457    response: Response<proto::LeaveChannelBuffer>,
3458    session: MessageContext,
3459) -> Result<()> {
3460    let db = session.db().await;
3461    let channel_id = ChannelId::from_proto(request.channel_id);
3462
3463    let left_buffer = db
3464        .leave_channel_buffer(channel_id, session.connection_id)
3465        .await?;
3466
3467    response.send(Ack {})?;
3468
3469    channel_buffer_updated(
3470        session.connection_id,
3471        left_buffer.connections,
3472        &proto::UpdateChannelBufferCollaborators {
3473            channel_id: channel_id.to_proto(),
3474            collaborators: left_buffer.collaborators,
3475        },
3476        &session.peer,
3477    );
3478
3479    Ok(())
3480}
3481
3482fn channel_buffer_updated<T: EnvelopedMessage>(
3483    sender_id: ConnectionId,
3484    collaborators: impl IntoIterator<Item = ConnectionId>,
3485    message: &T,
3486    peer: &Peer,
3487) {
3488    broadcast(Some(sender_id), collaborators, |peer_id| {
3489        peer.send(peer_id, message.clone())
3490    });
3491}
3492
3493fn send_notifications(
3494    connection_pool: &ConnectionPool,
3495    peer: &Peer,
3496    notifications: db::NotificationBatch,
3497) {
3498    for (user_id, notification) in notifications {
3499        for connection_id in connection_pool.user_connection_ids(user_id) {
3500            if let Err(error) = peer.send(
3501                connection_id,
3502                proto::AddNotification {
3503                    notification: Some(notification.clone()),
3504                },
3505            ) {
3506                tracing::error!(
3507                    "failed to send notification to {:?} {}",
3508                    connection_id,
3509                    error
3510                );
3511            }
3512        }
3513    }
3514}
3515
3516/// Send a message to the channel
3517async fn send_channel_message(
3518    _request: proto::SendChannelMessage,
3519    _response: Response<proto::SendChannelMessage>,
3520    _session: MessageContext,
3521) -> Result<()> {
3522    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3523}
3524
3525/// Delete a channel message
3526async fn remove_channel_message(
3527    _request: proto::RemoveChannelMessage,
3528    _response: Response<proto::RemoveChannelMessage>,
3529    _session: MessageContext,
3530) -> Result<()> {
3531    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3532}
3533
3534async fn update_channel_message(
3535    _request: proto::UpdateChannelMessage,
3536    _response: Response<proto::UpdateChannelMessage>,
3537    _session: MessageContext,
3538) -> Result<()> {
3539    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3540}
3541
3542/// Mark a channel message as read
3543async fn acknowledge_channel_message(
3544    _request: proto::AckChannelMessage,
3545    _session: MessageContext,
3546) -> Result<()> {
3547    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3548}
3549
3550/// Mark a buffer version as synced
3551async fn acknowledge_buffer_version(
3552    request: proto::AckBufferOperation,
3553    session: MessageContext,
3554) -> Result<()> {
3555    let buffer_id = BufferId::from_proto(request.buffer_id);
3556    session
3557        .db()
3558        .await
3559        .observe_buffer_version(
3560            buffer_id,
3561            session.user_id(),
3562            request.epoch as i32,
3563            &request.version,
3564        )
3565        .await?;
3566    Ok(())
3567}
3568
3569/// Start receiving chat updates for a channel
3570async fn join_channel_chat(
3571    _request: proto::JoinChannelChat,
3572    _response: Response<proto::JoinChannelChat>,
3573    _session: MessageContext,
3574) -> Result<()> {
3575    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3576}
3577
3578/// Stop receiving chat updates for a channel
3579async fn leave_channel_chat(
3580    _request: proto::LeaveChannelChat,
3581    _session: MessageContext,
3582) -> Result<()> {
3583    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3584}
3585
3586/// Retrieve the chat history for a channel
3587async fn get_channel_messages(
3588    _request: proto::GetChannelMessages,
3589    _response: Response<proto::GetChannelMessages>,
3590    _session: MessageContext,
3591) -> Result<()> {
3592    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3593}
3594
3595/// Retrieve specific chat messages
3596async fn get_channel_messages_by_id(
3597    _request: proto::GetChannelMessagesById,
3598    _response: Response<proto::GetChannelMessagesById>,
3599    _session: MessageContext,
3600) -> Result<()> {
3601    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3602}
3603
3604/// Retrieve the current users notifications
3605async fn get_notifications(
3606    request: proto::GetNotifications,
3607    response: Response<proto::GetNotifications>,
3608    session: MessageContext,
3609) -> Result<()> {
3610    let notifications = session
3611        .db()
3612        .await
3613        .get_notifications(
3614            session.user_id(),
3615            NOTIFICATION_COUNT_PER_PAGE,
3616            request.before_id.map(db::NotificationId::from_proto),
3617        )
3618        .await?;
3619    response.send(proto::GetNotificationsResponse {
3620        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3621        notifications,
3622    })?;
3623    Ok(())
3624}
3625
3626/// Mark notifications as read
3627async fn mark_notification_as_read(
3628    request: proto::MarkNotificationRead,
3629    response: Response<proto::MarkNotificationRead>,
3630    session: MessageContext,
3631) -> Result<()> {
3632    let database = &session.db().await;
3633    let notifications = database
3634        .mark_notification_as_read_by_id(
3635            session.user_id(),
3636            NotificationId::from_proto(request.notification_id),
3637        )
3638        .await?;
3639    send_notifications(
3640        &*session.connection_pool().await,
3641        &session.peer,
3642        notifications,
3643    );
3644    response.send(proto::Ack {})?;
3645    Ok(())
3646}
3647
3648fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3649    let message = match message {
3650        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3651        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3652        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3653        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3654        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3655            code: frame.code.into(),
3656            reason: frame.reason.as_str().to_owned().into(),
3657        })),
3658        // We should never receive a frame while reading the message, according
3659        // to the `tungstenite` maintainers:
3660        //
3661        // > It cannot occur when you read messages from the WebSocket, but it
3662        // > can be used when you want to send the raw frames (e.g. you want to
3663        // > send the frames to the WebSocket without composing the full message first).
3664        // >
3665        // > — https://github.com/snapview/tungstenite-rs/issues/268
3666        TungsteniteMessage::Frame(_) => {
3667            bail!("received an unexpected frame while reading the message")
3668        }
3669    };
3670
3671    Ok(message)
3672}
3673
3674fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3675    match message {
3676        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3677        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3678        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3679        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3680        AxumMessage::Close(frame) => {
3681            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3682                code: frame.code.into(),
3683                reason: frame.reason.as_ref().into(),
3684            }))
3685        }
3686    }
3687}
3688
3689fn notify_membership_updated(
3690    connection_pool: &mut ConnectionPool,
3691    result: MembershipUpdated,
3692    user_id: UserId,
3693    peer: &Peer,
3694) {
3695    for membership in &result.new_channels.channel_memberships {
3696        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3697    }
3698    for channel_id in &result.removed_channels {
3699        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3700    }
3701
3702    let user_channels_update = proto::UpdateUserChannels {
3703        channel_memberships: result
3704            .new_channels
3705            .channel_memberships
3706            .iter()
3707            .map(|cm| proto::ChannelMembership {
3708                channel_id: cm.channel_id.to_proto(),
3709                role: cm.role.into(),
3710            })
3711            .collect(),
3712        ..Default::default()
3713    };
3714
3715    let mut update = build_channels_update(result.new_channels);
3716    update.delete_channels = result
3717        .removed_channels
3718        .into_iter()
3719        .map(|id| id.to_proto())
3720        .collect();
3721    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3722
3723    for connection_id in connection_pool.user_connection_ids(user_id) {
3724        peer.send(connection_id, user_channels_update.clone())
3725            .trace_err();
3726        peer.send(connection_id, update.clone()).trace_err();
3727    }
3728}
3729
3730fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3731    proto::UpdateUserChannels {
3732        channel_memberships: channels
3733            .channel_memberships
3734            .iter()
3735            .map(|m| proto::ChannelMembership {
3736                channel_id: m.channel_id.to_proto(),
3737                role: m.role.into(),
3738            })
3739            .collect(),
3740        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3741    }
3742}
3743
3744fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3745    let mut update = proto::UpdateChannels::default();
3746
3747    for channel in channels.channels {
3748        update.channels.push(channel.to_proto());
3749    }
3750
3751    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3752
3753    for (channel_id, participants) in channels.channel_participants {
3754        update
3755            .channel_participants
3756            .push(proto::ChannelParticipants {
3757                channel_id: channel_id.to_proto(),
3758                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3759            });
3760    }
3761
3762    for channel in channels.invited_channels {
3763        update.channel_invitations.push(channel.to_proto());
3764    }
3765
3766    update
3767}
3768
3769fn build_initial_contacts_update(
3770    contacts: Vec<db::Contact>,
3771    pool: &ConnectionPool,
3772) -> proto::UpdateContacts {
3773    let mut update = proto::UpdateContacts::default();
3774
3775    for contact in contacts {
3776        match contact {
3777            db::Contact::Accepted { user_id, busy } => {
3778                update.contacts.push(contact_for_user(user_id, busy, pool));
3779            }
3780            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3781            db::Contact::Incoming { user_id } => {
3782                update
3783                    .incoming_requests
3784                    .push(proto::IncomingContactRequest {
3785                        requester_id: user_id.to_proto(),
3786                    })
3787            }
3788        }
3789    }
3790
3791    update
3792}
3793
3794fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3795    proto::Contact {
3796        user_id: user_id.to_proto(),
3797        online: pool.is_user_online(user_id),
3798        busy,
3799    }
3800}
3801
3802fn room_updated(room: &proto::Room, peer: &Peer) {
3803    broadcast(
3804        None,
3805        room.participants
3806            .iter()
3807            .filter_map(|participant| Some(participant.peer_id?.into())),
3808        |peer_id| {
3809            peer.send(
3810                peer_id,
3811                proto::RoomUpdated {
3812                    room: Some(room.clone()),
3813                },
3814            )
3815        },
3816    );
3817}
3818
3819fn channel_updated(
3820    channel: &db::channel::Model,
3821    room: &proto::Room,
3822    peer: &Peer,
3823    pool: &ConnectionPool,
3824) {
3825    let participants = room
3826        .participants
3827        .iter()
3828        .map(|p| p.user_id)
3829        .collect::<Vec<_>>();
3830
3831    broadcast(
3832        None,
3833        pool.channel_connection_ids(channel.root_id())
3834            .filter_map(|(channel_id, role)| {
3835                role.can_see_channel(channel.visibility)
3836                    .then_some(channel_id)
3837            }),
3838        |peer_id| {
3839            peer.send(
3840                peer_id,
3841                proto::UpdateChannels {
3842                    channel_participants: vec![proto::ChannelParticipants {
3843                        channel_id: channel.id.to_proto(),
3844                        participant_user_ids: participants.clone(),
3845                    }],
3846                    ..Default::default()
3847                },
3848            )
3849        },
3850    );
3851}
3852
3853async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3854    let db = session.db().await;
3855
3856    let contacts = db.get_contacts(user_id).await?;
3857    let busy = db.is_user_busy(user_id).await?;
3858
3859    let pool = session.connection_pool().await;
3860    let updated_contact = contact_for_user(user_id, busy, &pool);
3861    for contact in contacts {
3862        if let db::Contact::Accepted {
3863            user_id: contact_user_id,
3864            ..
3865        } = contact
3866        {
3867            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3868                session
3869                    .peer
3870                    .send(
3871                        contact_conn_id,
3872                        proto::UpdateContacts {
3873                            contacts: vec![updated_contact.clone()],
3874                            remove_contacts: Default::default(),
3875                            incoming_requests: Default::default(),
3876                            remove_incoming_requests: Default::default(),
3877                            outgoing_requests: Default::default(),
3878                            remove_outgoing_requests: Default::default(),
3879                        },
3880                    )
3881                    .trace_err();
3882            }
3883        }
3884    }
3885    Ok(())
3886}
3887
3888async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
3889    let mut contacts_to_update = HashSet::default();
3890
3891    let room_id;
3892    let canceled_calls_to_user_ids;
3893    let livekit_room;
3894    let delete_livekit_room;
3895    let room;
3896    let channel;
3897
3898    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
3899        contacts_to_update.insert(session.user_id());
3900
3901        for project in left_room.left_projects.values() {
3902            project_left(project, session);
3903        }
3904
3905        room_id = RoomId::from_proto(left_room.room.id);
3906        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3907        livekit_room = mem::take(&mut left_room.room.livekit_room);
3908        delete_livekit_room = left_room.deleted;
3909        room = mem::take(&mut left_room.room);
3910        channel = mem::take(&mut left_room.channel);
3911
3912        room_updated(&room, &session.peer);
3913    } else {
3914        return Ok(());
3915    }
3916
3917    if let Some(channel) = channel {
3918        channel_updated(
3919            &channel,
3920            &room,
3921            &session.peer,
3922            &*session.connection_pool().await,
3923        );
3924    }
3925
3926    {
3927        let pool = session.connection_pool().await;
3928        for canceled_user_id in canceled_calls_to_user_ids {
3929            for connection_id in pool.user_connection_ids(canceled_user_id) {
3930                session
3931                    .peer
3932                    .send(
3933                        connection_id,
3934                        proto::CallCanceled {
3935                            room_id: room_id.to_proto(),
3936                        },
3937                    )
3938                    .trace_err();
3939            }
3940            contacts_to_update.insert(canceled_user_id);
3941        }
3942    }
3943
3944    for contact_user_id in contacts_to_update {
3945        update_user_contacts(contact_user_id, session).await?;
3946    }
3947
3948    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
3949        live_kit
3950            .remove_participant(livekit_room.clone(), session.user_id().to_string())
3951            .await
3952            .trace_err();
3953
3954        if delete_livekit_room {
3955            live_kit.delete_room(livekit_room).await.trace_err();
3956        }
3957    }
3958
3959    Ok(())
3960}
3961
3962async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3963    let left_channel_buffers = session
3964        .db()
3965        .await
3966        .leave_channel_buffers(session.connection_id)
3967        .await?;
3968
3969    for left_buffer in left_channel_buffers {
3970        channel_buffer_updated(
3971            session.connection_id,
3972            left_buffer.connections,
3973            &proto::UpdateChannelBufferCollaborators {
3974                channel_id: left_buffer.channel_id.to_proto(),
3975                collaborators: left_buffer.collaborators,
3976            },
3977            &session.peer,
3978        );
3979    }
3980
3981    Ok(())
3982}
3983
3984fn project_left(project: &db::LeftProject, session: &Session) {
3985    for connection_id in &project.connection_ids {
3986        if project.should_unshare {
3987            session
3988                .peer
3989                .send(
3990                    *connection_id,
3991                    proto::UnshareProject {
3992                        project_id: project.id.to_proto(),
3993                    },
3994                )
3995                .trace_err();
3996        } else {
3997            session
3998                .peer
3999                .send(
4000                    *connection_id,
4001                    proto::RemoveProjectCollaborator {
4002                        project_id: project.id.to_proto(),
4003                        peer_id: Some(session.connection_id.into()),
4004                    },
4005                )
4006                .trace_err();
4007        }
4008    }
4009}
4010
4011async fn share_agent_thread(
4012    request: proto::ShareAgentThread,
4013    response: Response<proto::ShareAgentThread>,
4014    session: MessageContext,
4015) -> Result<()> {
4016    let user_id = session.user_id();
4017
4018    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4019        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4020
4021    session
4022        .db()
4023        .await
4024        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4025        .await?;
4026
4027    response.send(proto::Ack {})?;
4028
4029    Ok(())
4030}
4031
4032async fn get_shared_agent_thread(
4033    request: proto::GetSharedAgentThread,
4034    response: Response<proto::GetSharedAgentThread>,
4035    session: MessageContext,
4036) -> Result<()> {
4037    let share_id = SharedThreadId::from_proto(request.session_id)
4038        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4039
4040    let result = session.db().await.get_shared_thread(share_id).await?;
4041
4042    match result {
4043        Some((thread, username)) => {
4044            response.send(proto::GetSharedAgentThreadResponse {
4045                title: thread.title,
4046                thread_data: thread.data,
4047                sharer_username: username,
4048                created_at: thread.created_at.and_utc().to_rfc3339(),
4049            })?;
4050        }
4051        None => {
4052            return Err(anyhow!("Shared thread not found").into());
4053        }
4054    }
4055
4056    Ok(())
4057}
4058
/// Extension for results whose errors should be logged and discarded rather than propagated.
pub trait ResultExt {
    /// The success type carried by the extended result.
    type Ok;

    /// Returns the success value, or logs the error and returns `None`.
    /// (The `Result` implementation in this file logs with `tracing::error!`.)
    fn trace_err(self) -> Option<Self::Ok>;
}
4064
4065impl<T, E> ResultExt for Result<T, E>
4066where
4067    E: std::fmt::Debug,
4068{
4069    type Ok = T;
4070
4071    #[track_caller]
4072    fn trace_err(self) -> Option<T> {
4073        match self {
4074            Ok(value) => Some(value),
4075            Err(error) => {
4076                tracing::error!("{:?}", error);
4077                None
4078            }
4079        }
4080    }
4081}