rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::{
  27    Extension, Router, TypedHeader,
  28    body::Body,
  29    extract::{
  30        ConnectInfo, WebSocketUpgrade,
  31        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  32    },
  33    headers::{Header, HeaderName},
  34    http::StatusCode,
  35    middleware,
  36    response::IntoResponse,
  37    routing::get,
  38};
  39use chrono::Utc;
  40use collections::{HashMap, HashSet};
  41pub use connection_pool::{ConnectionPool, ZedVersion};
  42use core::fmt::{self, Debug, Formatter};
  43use reqwest_client::ReqwestClient;
  44use rpc::proto::split_repository_update;
  45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  46
  47use futures::{
  48    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  49    stream::FuturesUnordered,
  50};
  51use prometheus::{IntGauge, register_int_gauge};
  52use rpc::{
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54    proto::{
  55        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  56        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  57    },
  58};
  59use semantic_version::SemanticVersion;
  60use serde::{Serialize, Serializer};
  61use std::{
  62    any::TypeId,
  63    future::Future,
  64    marker::PhantomData,
  65    mem,
  66    net::SocketAddr,
  67    ops::{Deref, DerefMut},
  68    rc::Rc,
  69    sync::{
  70        Arc, OnceLock,
  71        atomic::{AtomicBool, Ordering::SeqCst},
  72    },
  73    time::{Duration, Instant},
  74};
  75use time::OffsetDateTime;
  76use tokio::sync::{Semaphore, watch};
  77use tower::ServiceBuilder;
  78use tracing::{
  79    Instrument,
  80    field::{self},
  81    info_span, instrument,
  82};
  83
  84pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  85
  86// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  87pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  88
  89const MESSAGE_COUNT_PER_PAGE: usize = 100;
  90const MAX_MESSAGE_LEN: usize = 1024;
  91const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  92
  93type MessageHandler =
  94    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  95
  96struct Response<R> {
  97    peer: Arc<Peer>,
  98    receipt: Receipt<R>,
  99    responded: Arc<AtomicBool>,
 100}
 101
 102impl<R: RequestMessage> Response<R> {
 103    fn send(self, payload: R::Response) -> Result<()> {
 104        self.responded.store(true, SeqCst);
 105        self.peer.respond(self.receipt, payload)?;
 106        Ok(())
 107    }
 108}
 109
 110#[derive(Clone, Debug)]
 111pub enum Principal {
 112    User(User),
 113    Impersonated { user: User, admin: User },
 114}
 115
 116impl Principal {
 117    fn user(&self) -> &User {
 118        match self {
 119            Principal::User(user) => user,
 120            Principal::Impersonated { user, .. } => user,
 121        }
 122    }
 123
 124    fn update_span(&self, span: &tracing::Span) {
 125        match &self {
 126            Principal::User(user) => {
 127                span.record("user_id", user.id.0);
 128                span.record("login", &user.github_login);
 129            }
 130            Principal::Impersonated { user, admin } => {
 131                span.record("user_id", user.id.0);
 132                span.record("login", &user.github_login);
 133                span.record("impersonator", &admin.github_login);
 134            }
 135        }
 136    }
 137}
 138
 139#[derive(Clone)]
 140struct Session {
 141    principal: Principal,
 142    connection_id: ConnectionId,
 143    db: Arc<tokio::sync::Mutex<DbHandle>>,
 144    peer: Arc<Peer>,
 145    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 146    app_state: Arc<AppState>,
 147    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 148    /// The GeoIP country code for the user.
 149    #[allow(unused)]
 150    geoip_country_code: Option<String>,
 151    system_id: Option<String>,
 152    _executor: Executor,
 153}
 154
 155impl Session {
 156    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 157        #[cfg(test)]
 158        tokio::task::yield_now().await;
 159        let guard = self.db.lock().await;
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        guard
 163    }
 164
 165    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 166        #[cfg(test)]
 167        tokio::task::yield_now().await;
 168        let guard = self.connection_pool.lock();
 169        ConnectionPoolGuard {
 170            guard,
 171            _not_send: PhantomData,
 172        }
 173    }
 174
 175    fn is_staff(&self) -> bool {
 176        match &self.principal {
 177            Principal::User(user) => user.admin,
 178            Principal::Impersonated { .. } => true,
 179        }
 180    }
 181
 182    fn user_id(&self) -> UserId {
 183        match &self.principal {
 184            Principal::User(user) => user.id,
 185            Principal::Impersonated { user, .. } => user.id,
 186        }
 187    }
 188
 189    pub fn email(&self) -> Option<String> {
 190        match &self.principal {
 191            Principal::User(user) => user.email_address.clone(),
 192            Principal::Impersonated { user, .. } => user.email_address.clone(),
 193        }
 194    }
 195}
 196
 197impl Debug for Session {
 198    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 199        let mut result = f.debug_struct("Session");
 200        match &self.principal {
 201            Principal::User(user) => {
 202                result.field("user", &user.github_login);
 203            }
 204            Principal::Impersonated { user, admin } => {
 205                result.field("user", &user.github_login);
 206                result.field("impersonator", &admin.github_login);
 207            }
 208        }
 209        result.field("connection_id", &self.connection_id).finish()
 210    }
 211}
 212
 213struct DbHandle(Arc<Database>);
 214
 215impl Deref for DbHandle {
 216    type Target = Database;
 217
 218    fn deref(&self) -> &Self::Target {
 219        self.0.as_ref()
 220    }
 221}
 222
 223pub struct Server {
 224    id: parking_lot::Mutex<ServerId>,
 225    peer: Arc<Peer>,
 226    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 227    app_state: Arc<AppState>,
 228    handlers: HashMap<TypeId, MessageHandler>,
 229    teardown: watch::Sender<bool>,
 230}
 231
 232pub(crate) struct ConnectionPoolGuard<'a> {
 233    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 234    _not_send: PhantomData<Rc<()>>,
 235}
 236
 237#[derive(Serialize)]
 238pub struct ServerSnapshot<'a> {
 239    peer: &'a Peer,
 240    #[serde(serialize_with = "serialize_deref")]
 241    connection_pool: ConnectionPoolGuard<'a>,
 242}
 243
 244pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 245where
 246    S: Serializer,
 247    T: Deref<Target = U>,
 248    U: Serialize,
 249{
 250    Serialize::serialize(value.deref(), serializer)
 251}
 252
 253impl Server {
 254    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 255        let mut server = Self {
 256            id: parking_lot::Mutex::new(id),
 257            peer: Peer::new(id.0 as u32),
 258            app_state: app_state.clone(),
 259            connection_pool: Default::default(),
 260            handlers: Default::default(),
 261            teardown: watch::channel(false).0,
 262        };
 263
 264        server
 265            .add_request_handler(ping)
 266            .add_request_handler(create_room)
 267            .add_request_handler(join_room)
 268            .add_request_handler(rejoin_room)
 269            .add_request_handler(leave_room)
 270            .add_request_handler(set_room_participant_role)
 271            .add_request_handler(call)
 272            .add_request_handler(cancel_call)
 273            .add_message_handler(decline_call)
 274            .add_request_handler(update_participant_location)
 275            .add_request_handler(share_project)
 276            .add_message_handler(unshare_project)
 277            .add_request_handler(join_project)
 278            .add_message_handler(leave_project)
 279            .add_request_handler(update_project)
 280            .add_request_handler(update_worktree)
 281            .add_request_handler(update_repository)
 282            .add_request_handler(remove_repository)
 283            .add_message_handler(start_language_server)
 284            .add_message_handler(update_language_server)
 285            .add_message_handler(update_diagnostic_summary)
 286            .add_message_handler(update_worktree_settings)
 287            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 288            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 289            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 290            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 291            .add_request_handler(forward_find_search_candidates_request)
 292            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 293            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 294            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 295            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 296            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 297            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 298            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 299            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 300            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 301            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 302            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 303            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 304            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 305            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 306            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 307            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 308            .add_request_handler(
 309                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 310            )
 311            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 312            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 313            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 314            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 315            .add_request_handler(
 316                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 317            )
 318            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 323            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 328            .add_request_handler(
 329                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 330            )
 331            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 333            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 334            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 335            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 336            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 337            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 338            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 339            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 340            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 341            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 342            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 343            .add_request_handler(
 344                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 345            )
 346            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 347            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 348            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 349            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 350            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 351            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 352            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 353            .add_message_handler(create_buffer_for_peer)
 354            .add_request_handler(update_buffer)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 356            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 357            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 358            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 359            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 360            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 361            .add_message_handler(
 362                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 363            )
 364            .add_request_handler(get_users)
 365            .add_request_handler(fuzzy_search_users)
 366            .add_request_handler(request_contact)
 367            .add_request_handler(remove_contact)
 368            .add_request_handler(respond_to_contact_request)
 369            .add_message_handler(subscribe_to_channels)
 370            .add_request_handler(create_channel)
 371            .add_request_handler(delete_channel)
 372            .add_request_handler(invite_channel_member)
 373            .add_request_handler(remove_channel_member)
 374            .add_request_handler(set_channel_member_role)
 375            .add_request_handler(set_channel_visibility)
 376            .add_request_handler(rename_channel)
 377            .add_request_handler(join_channel_buffer)
 378            .add_request_handler(leave_channel_buffer)
 379            .add_message_handler(update_channel_buffer)
 380            .add_request_handler(rejoin_channel_buffers)
 381            .add_request_handler(get_channel_members)
 382            .add_request_handler(respond_to_channel_invite)
 383            .add_request_handler(join_channel)
 384            .add_request_handler(join_channel_chat)
 385            .add_message_handler(leave_channel_chat)
 386            .add_request_handler(send_channel_message)
 387            .add_request_handler(remove_channel_message)
 388            .add_request_handler(update_channel_message)
 389            .add_request_handler(get_channel_messages)
 390            .add_request_handler(get_channel_messages_by_id)
 391            .add_request_handler(get_notifications)
 392            .add_request_handler(mark_notification_as_read)
 393            .add_request_handler(move_channel)
 394            .add_request_handler(reorder_channel)
 395            .add_request_handler(follow)
 396            .add_message_handler(unfollow)
 397            .add_message_handler(update_followers)
 398            .add_request_handler(get_private_user_info)
 399            .add_request_handler(get_llm_api_token)
 400            .add_request_handler(accept_terms_of_service)
 401            .add_message_handler(acknowledge_channel_message)
 402            .add_message_handler(acknowledge_buffer_version)
 403            .add_request_handler(get_supermaven_api_key)
 404            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 405            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 406            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 407            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 408            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 409            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 410            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 411            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 412            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 413            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 414            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 415            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 416            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 417            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 418            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 419            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 420            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 421            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 422            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 423            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 424            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 425            .add_message_handler(update_context);
 426
 427        Arc::new(server)
 428    }
 429
 430    pub async fn start(&self) -> Result<()> {
 431        let server_id = *self.id.lock();
 432        let app_state = self.app_state.clone();
 433        let peer = self.peer.clone();
 434        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 435        let pool = self.connection_pool.clone();
 436        let livekit_client = self.app_state.livekit_client.clone();
 437
 438        let span = info_span!("start server");
 439        self.app_state.executor.spawn_detached(
 440            async move {
 441                tracing::info!("waiting for cleanup timeout");
 442                timeout.await;
 443                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 444
 445                app_state
 446                    .db
 447                    .delete_stale_channel_chat_participants(
 448                        &app_state.config.zed_environment,
 449                        server_id,
 450                    )
 451                    .await
 452                    .trace_err();
 453
 454                if let Some((room_ids, channel_ids)) = app_state
 455                    .db
 456                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 457                    .await
 458                    .trace_err()
 459                {
 460                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 461                    tracing::info!(
 462                        stale_channel_buffer_count = channel_ids.len(),
 463                        "retrieved stale channel buffers"
 464                    );
 465
 466                    for channel_id in channel_ids {
 467                        if let Some(refreshed_channel_buffer) = app_state
 468                            .db
 469                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 470                            .await
 471                            .trace_err()
 472                        {
 473                            for connection_id in refreshed_channel_buffer.connection_ids {
 474                                peer.send(
 475                                    connection_id,
 476                                    proto::UpdateChannelBufferCollaborators {
 477                                        channel_id: channel_id.to_proto(),
 478                                        collaborators: refreshed_channel_buffer
 479                                            .collaborators
 480                                            .clone(),
 481                                    },
 482                                )
 483                                .trace_err();
 484                            }
 485                        }
 486                    }
 487
 488                    for room_id in room_ids {
 489                        let mut contacts_to_update = HashSet::default();
 490                        let mut canceled_calls_to_user_ids = Vec::new();
 491                        let mut livekit_room = String::new();
 492                        let mut delete_livekit_room = false;
 493
 494                        if let Some(mut refreshed_room) = app_state
 495                            .db
 496                            .clear_stale_room_participants(room_id, server_id)
 497                            .await
 498                            .trace_err()
 499                        {
 500                            tracing::info!(
 501                                room_id = room_id.0,
 502                                new_participant_count = refreshed_room.room.participants.len(),
 503                                "refreshed room"
 504                            );
 505                            room_updated(&refreshed_room.room, &peer);
 506                            if let Some(channel) = refreshed_room.channel.as_ref() {
 507                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 508                            }
 509                            contacts_to_update
 510                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 511                            contacts_to_update
 512                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 513                            canceled_calls_to_user_ids =
 514                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 515                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 516                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 517                        }
 518
 519                        {
 520                            let pool = pool.lock();
 521                            for canceled_user_id in canceled_calls_to_user_ids {
 522                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 523                                    peer.send(
 524                                        connection_id,
 525                                        proto::CallCanceled {
 526                                            room_id: room_id.to_proto(),
 527                                        },
 528                                    )
 529                                    .trace_err();
 530                                }
 531                            }
 532                        }
 533
 534                        for user_id in contacts_to_update {
 535                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 536                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 537                            if let Some((busy, contacts)) = busy.zip(contacts) {
 538                                let pool = pool.lock();
 539                                let updated_contact = contact_for_user(user_id, busy, &pool);
 540                                for contact in contacts {
 541                                    if let db::Contact::Accepted {
 542                                        user_id: contact_user_id,
 543                                        ..
 544                                    } = contact
 545                                    {
 546                                        for contact_conn_id in
 547                                            pool.user_connection_ids(contact_user_id)
 548                                        {
 549                                            peer.send(
 550                                                contact_conn_id,
 551                                                proto::UpdateContacts {
 552                                                    contacts: vec![updated_contact.clone()],
 553                                                    remove_contacts: Default::default(),
 554                                                    incoming_requests: Default::default(),
 555                                                    remove_incoming_requests: Default::default(),
 556                                                    outgoing_requests: Default::default(),
 557                                                    remove_outgoing_requests: Default::default(),
 558                                                },
 559                                            )
 560                                            .trace_err();
 561                                        }
 562                                    }
 563                                }
 564                            }
 565                        }
 566
 567                        if let Some(live_kit) = livekit_client.as_ref() {
 568                            if delete_livekit_room {
 569                                live_kit.delete_room(livekit_room).await.trace_err();
 570                            }
 571                        }
 572                    }
 573                }
 574
 575                app_state
 576                    .db
 577                    .delete_stale_channel_chat_participants(
 578                        &app_state.config.zed_environment,
 579                        server_id,
 580                    )
 581                    .await
 582                    .trace_err();
 583
 584                app_state
 585                    .db
 586                    .clear_old_worktree_entries(server_id)
 587                    .await
 588                    .trace_err();
 589
 590                app_state
 591                    .db
 592                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 593                    .await
 594                    .trace_err();
 595            }
 596            .instrument(span),
 597        );
 598        Ok(())
 599    }
 600
 601    pub fn teardown(&self) {
 602        self.peer.teardown();
 603        self.connection_pool.lock().reset();
 604        let _ = self.teardown.send(true);
 605    }
 606
 607    #[cfg(test)]
 608    pub fn reset(&self, id: ServerId) {
 609        self.teardown();
 610        *self.id.lock() = id;
 611        self.peer.reset(id.0 as u32);
 612        let _ = self.teardown.send(false);
 613    }
 614
 615    #[cfg(test)]
 616    pub fn id(&self) -> ServerId {
 617        *self.id.lock()
 618    }
 619
 620    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 621    where
 622        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 623        Fut: 'static + Send + Future<Output = Result<()>>,
 624        M: EnvelopedMessage,
 625    {
 626        let prev_handler = self.handlers.insert(
 627            TypeId::of::<M>(),
 628            Box::new(move |envelope, session| {
 629                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 630                let received_at = envelope.received_at;
 631                tracing::info!("message received");
 632                let start_time = Instant::now();
 633                let future = (handler)(*envelope, session);
 634                async move {
 635                    let result = future.await;
 636                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 637                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 638                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 639                    let payload_type = M::NAME;
 640
 641                    match result {
 642                        Err(error) => {
 643                            tracing::error!(
 644                                ?error,
 645                                total_duration_ms,
 646                                processing_duration_ms,
 647                                queue_duration_ms,
 648                                payload_type,
 649                                "error handling message"
 650                            )
 651                        }
 652                        Ok(()) => tracing::info!(
 653                            total_duration_ms,
 654                            processing_duration_ms,
 655                            queue_duration_ms,
 656                            "finished handling message"
 657                        ),
 658                    }
 659                }
 660                .boxed()
 661            }),
 662        );
 663        if prev_handler.is_some() {
 664            panic!("registered a handler for the same message twice");
 665        }
 666        self
 667    }
 668
 669    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 670    where
 671        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 672        Fut: 'static + Send + Future<Output = Result<()>>,
 673        M: EnvelopedMessage,
 674    {
 675        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 676        self
 677    }
 678
 679    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 680    where
 681        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 682        Fut: Send + Future<Output = Result<()>>,
 683        M: RequestMessage,
 684    {
 685        let handler = Arc::new(handler);
 686        self.add_handler(move |envelope, session| {
 687            let receipt = envelope.receipt();
 688            let handler = handler.clone();
 689            async move {
 690                let peer = session.peer.clone();
 691                let responded = Arc::new(AtomicBool::default());
 692                let response = Response {
 693                    peer: peer.clone(),
 694                    responded: responded.clone(),
 695                    receipt,
 696                };
 697                match (handler)(envelope.payload, response, session).await {
 698                    Ok(()) => {
 699                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 700                            Ok(())
 701                        } else {
 702                            Err(anyhow!("handler did not send a response"))?
 703                        }
 704                    }
 705                    Err(error) => {
 706                        let proto_err = match &error {
 707                            Error::Internal(err) => err.to_proto(),
 708                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 709                        };
 710                        peer.respond_with_error(receipt, proto_err)?;
 711                        Err(error)
 712                    }
 713                }
 714            }
 715        })
 716    }
 717
    /// Builds the future that drives a single client connection from handshake to sign-out.
    ///
    /// The returned future:
    /// 1. returns immediately if the teardown flag is already `true`;
    /// 2. registers `connection` with the peer and records the new connection id on the span;
    /// 3. builds the per-connection [`Session`], including a Supermaven admin client when an
    ///    admin API key is configured;
    /// 4. sends the initial client update via `send_initial_client_update`, which also reports
    ///    the connection id through `send_connection_id` when one is provided;
    /// 5. dispatches incoming messages to the registered handlers. It stops when the I/O task
    ///    finishes, the incoming stream ends, or the teardown channel changes;
    /// 6. calls `connection_lost` when the loop exits via `break` (I/O finished or the stream
    ///    ended). A teardown change returns without calling it.
    ///
    /// NOTE(review): the early returns after `add_connection` do not call `connection_lost` or
    /// `peer.disconnect`. These are HTTP-client creation failure and
    /// `send_initial_client_update` failure. The latter may already have added the
    /// connection to the pool. Confirm this cannot leave stale connection state behind.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future runs so a teardown that happens in between is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when `supermaven_admin_api_key` is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) hold a permit at once; a new
            // message is only pulled from `incoming_rx` after a permit is acquired.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O completion, then progress on in-flight
                // foreground handlers, and only then a new incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 865
 866    async fn send_initial_client_update(
 867        &self,
 868        connection_id: ConnectionId,
 869        zed_version: ZedVersion,
 870        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 871        session: &Session,
 872    ) -> Result<()> {
 873        self.peer.send(
 874            connection_id,
 875            proto::Hello {
 876                peer_id: Some(connection_id.into()),
 877            },
 878        )?;
 879        tracing::info!("sent hello message");
 880        if let Some(send_connection_id) = send_connection_id.take() {
 881            let _ = send_connection_id.send(connection_id);
 882        }
 883
 884        match &session.principal {
 885            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 886                if !user.connected_once {
 887                    self.peer.send(connection_id, proto::ShowContacts {})?;
 888                    self.app_state
 889                        .db
 890                        .set_user_connected_once(user.id, true)
 891                        .await?;
 892                }
 893
 894                update_user_plan(session).await?;
 895
 896                let contacts = self.app_state.db.get_contacts(user.id).await?;
 897
 898                {
 899                    let mut pool = self.connection_pool.lock();
 900                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 901                    self.peer.send(
 902                        connection_id,
 903                        build_initial_contacts_update(contacts, &pool),
 904                    )?;
 905                }
 906
 907                if should_auto_subscribe_to_channels(zed_version) {
 908                    subscribe_user_to_channels(user.id, session).await?;
 909                }
 910
 911                if let Some(incoming_call) =
 912                    self.app_state.db.incoming_call_for_user(user.id).await?
 913                {
 914                    self.peer.send(connection_id, incoming_call)?;
 915                }
 916
 917                update_user_contacts(user.id, session).await?;
 918            }
 919        }
 920
 921        Ok(())
 922    }
 923
 924    pub async fn invite_code_redeemed(
 925        self: &Arc<Self>,
 926        inviter_id: UserId,
 927        invitee_id: UserId,
 928    ) -> Result<()> {
 929        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 930            if let Some(code) = &user.invite_code {
 931                let pool = self.connection_pool.lock();
 932                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 933                for connection_id in pool.user_connection_ids(inviter_id) {
 934                    self.peer.send(
 935                        connection_id,
 936                        proto::UpdateContacts {
 937                            contacts: vec![invitee_contact.clone()],
 938                            ..Default::default()
 939                        },
 940                    )?;
 941                    self.peer.send(
 942                        connection_id,
 943                        proto::UpdateInviteInfo {
 944                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 945                            count: user.invite_count as u32,
 946                        },
 947                    )?;
 948                }
 949            }
 950        }
 951        Ok(())
 952    }
 953
 954    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 955        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 956            if let Some(invite_code) = &user.invite_code {
 957                let pool = self.connection_pool.lock();
 958                for connection_id in pool.user_connection_ids(user_id) {
 959                    self.peer.send(
 960                        connection_id,
 961                        proto::UpdateInviteInfo {
 962                            url: format!(
 963                                "{}{}",
 964                                self.app_state.config.invite_link_prefix, invite_code
 965                            ),
 966                            count: user.invite_count as u32,
 967                        },
 968                    )?;
 969                }
 970            }
 971        }
 972        Ok(())
 973    }
 974
 975    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 976        let user = self
 977            .app_state
 978            .db
 979            .get_user_by_id(user_id)
 980            .await?
 981            .context("user not found")?;
 982
 983        let update_user_plan = make_update_user_plan_message(
 984            &user,
 985            user.admin,
 986            &self.app_state.db,
 987            self.app_state.llm_db.clone(),
 988        )
 989        .await?;
 990
 991        let pool = self.connection_pool.lock();
 992        for connection_id in pool.user_connection_ids(user_id) {
 993            self.peer
 994                .send(connection_id, update_user_plan.clone())
 995                .trace_err();
 996        }
 997
 998        Ok(())
 999    }
1000
1001    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1002        let pool = self.connection_pool.lock();
1003        for connection_id in pool.user_connection_ids(user_id) {
1004            self.peer
1005                .send(connection_id, proto::RefreshLlmToken {})
1006                .trace_err();
1007        }
1008    }
1009
1010    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1011        ServerSnapshot {
1012            connection_pool: ConnectionPoolGuard {
1013                guard: self.connection_pool.lock(),
1014                _not_send: PhantomData,
1015            },
1016            peer: &self.peer,
1017        }
1018    }
1019}
1020
1021impl Deref for ConnectionPoolGuard<'_> {
1022    type Target = ConnectionPool;
1023
1024    fn deref(&self) -> &Self::Target {
1025        &self.guard
1026    }
1027}
1028
1029impl DerefMut for ConnectionPoolGuard<'_> {
1030    fn deref_mut(&mut self) -> &mut Self::Target {
1031        &mut self.guard
1032    }
1033}
1034
1035impl Drop for ConnectionPoolGuard<'_> {
1036    fn drop(&mut self) {
1037        #[cfg(test)]
1038        self.check_invariants();
1039    }
1040}
1041
1042fn broadcast<F>(
1043    sender_id: Option<ConnectionId>,
1044    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1045    mut f: F,
1046) where
1047    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1048{
1049    for receiver_id in receiver_ids {
1050        if Some(receiver_id) != sender_id {
1051            if let Err(error) = f(receiver_id) {
1052                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1053            }
1054        }
1055    }
1056}
1057
1058pub struct ProtocolVersion(u32);
1059
1060impl Header for ProtocolVersion {
1061    fn name() -> &'static HeaderName {
1062        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1063        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1064    }
1065
1066    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1067    where
1068        Self: Sized,
1069        I: Iterator<Item = &'i axum::http::HeaderValue>,
1070    {
1071        let version = values
1072            .next()
1073            .ok_or_else(axum::headers::Error::invalid)?
1074            .to_str()
1075            .map_err(|_| axum::headers::Error::invalid())?
1076            .parse()
1077            .map_err(|_| axum::headers::Error::invalid())?;
1078        Ok(Self(version))
1079    }
1080
1081    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1082        values.extend([self.0.to_string().parse().unwrap()]);
1083    }
1084}
1085
1086pub struct AppVersionHeader(SemanticVersion);
1087impl Header for AppVersionHeader {
1088    fn name() -> &'static HeaderName {
1089        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1090        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1091    }
1092
1093    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1094    where
1095        Self: Sized,
1096        I: Iterator<Item = &'i axum::http::HeaderValue>,
1097    {
1098        let version = values
1099            .next()
1100            .ok_or_else(axum::headers::Error::invalid)?
1101            .to_str()
1102            .map_err(|_| axum::headers::Error::invalid())?
1103            .parse()
1104            .map_err(|_| axum::headers::Error::invalid())?;
1105        Ok(Self(version))
1106    }
1107
1108    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1109        values.extend([self.0.to_string().parse().unwrap()]);
1110    }
1111}
1112
1113pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1114    Router::new()
1115        .route("/rpc", get(handle_websocket_request))
1116        .layer(
1117            ServiceBuilder::new()
1118                .layer(Extension(server.app_state.clone()))
1119                .layer(middleware::from_fn(auth::validate_header)),
1120        )
1121        .route("/metrics", get(handle_metrics))
1122        .layer(Extension(server))
1123}
1124
1125pub async fn handle_websocket_request(
1126    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1127    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1128    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1129    Extension(server): Extension<Arc<Server>>,
1130    Extension(principal): Extension<Principal>,
1131    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1132    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1133    ws: WebSocketUpgrade,
1134) -> axum::response::Response {
1135    if protocol_version != rpc::PROTOCOL_VERSION {
1136        return (
1137            StatusCode::UPGRADE_REQUIRED,
1138            "client must be upgraded".to_string(),
1139        )
1140            .into_response();
1141    }
1142
1143    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1144        return (
1145            StatusCode::UPGRADE_REQUIRED,
1146            "no version header found".to_string(),
1147        )
1148            .into_response();
1149    };
1150
1151    if !version.can_collaborate() {
1152        return (
1153            StatusCode::UPGRADE_REQUIRED,
1154            "client must be upgraded".to_string(),
1155        )
1156            .into_response();
1157    }
1158
1159    let socket_address = socket_address.to_string();
1160    ws.on_upgrade(move |socket| {
1161        let socket = socket
1162            .map_ok(to_tungstenite_message)
1163            .err_into()
1164            .with(|message| async move { to_axum_message(message) });
1165        let connection = Connection::new(Box::pin(socket));
1166        async move {
1167            server
1168                .handle_connection(
1169                    connection,
1170                    socket_address,
1171                    principal,
1172                    version,
1173                    country_code_header.map(|header| header.to_string()),
1174                    system_id_header.map(|header| header.to_string()),
1175                    None,
1176                    Executor::Production,
1177                )
1178                .await;
1179        }
1180    })
1181}
1182
1183pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1184    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1185    let connections_metric = CONNECTIONS_METRIC
1186        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1187
1188    let connections = server
1189        .connection_pool
1190        .lock()
1191        .connections()
1192        .filter(|connection| !connection.admin)
1193        .count();
1194    connections_metric.set(connections as _);
1195
1196    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1197    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1198        register_int_gauge!(
1199            "shared_projects",
1200            "number of open projects with one or more guests"
1201        )
1202        .unwrap()
1203    });
1204
1205    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1206    shared_projects_metric.set(shared_projects as _);
1207
1208    let encoder = prometheus::TextEncoder::new();
1209    let metric_families = prometheus::gather();
1210    let encoded_metrics = encoder
1211        .encode_to_string(&metric_families)
1212        .map_err(|err| anyhow!("{err}"))?;
1213    Ok(encoded_metrics)
1214}
1215
1216#[instrument(err, skip(executor))]
1217async fn connection_lost(
1218    session: Session,
1219    mut teardown: watch::Receiver<bool>,
1220    executor: Executor,
1221) -> Result<()> {
1222    session.peer.disconnect(session.connection_id);
1223    session
1224        .connection_pool()
1225        .await
1226        .remove_connection(session.connection_id)?;
1227
1228    session
1229        .db()
1230        .await
1231        .connection_lost(session.connection_id)
1232        .await
1233        .trace_err();
1234
1235    futures::select_biased! {
1236        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1237
1238            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1239            leave_room_for_session(&session, session.connection_id).await.trace_err();
1240            leave_channel_buffers_for_session(&session)
1241                .await
1242                .trace_err();
1243
1244            if !session
1245                .connection_pool()
1246                .await
1247                .is_user_online(session.user_id())
1248            {
1249                let db = session.db().await;
1250                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1251                    room_updated(&room, &session.peer);
1252                }
1253            }
1254
1255            update_user_contacts(session.user_id(), &session).await?;
1256        },
1257        _ = teardown.changed().fuse() => {}
1258    }
1259
1260    Ok(())
1261}
1262
1263/// Acknowledges a ping from a client, used to keep the connection alive.
1264async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1265    response.send(proto::Ack {})?;
1266    Ok(())
1267}
1268
1269/// Creates a new room for calling (outside of channels)
1270async fn create_room(
1271    _request: proto::CreateRoom,
1272    response: Response<proto::CreateRoom>,
1273    session: Session,
1274) -> Result<()> {
1275    let livekit_room = nanoid::nanoid!(30);
1276
1277    let live_kit_connection_info = util::maybe!(async {
1278        let live_kit = session.app_state.livekit_client.as_ref();
1279        let live_kit = live_kit?;
1280        let user_id = session.user_id().to_string();
1281
1282        let token = live_kit
1283            .room_token(&livekit_room, &user_id.to_string())
1284            .trace_err()?;
1285
1286        Some(proto::LiveKitConnectionInfo {
1287            server_url: live_kit.url().into(),
1288            token,
1289            can_publish: true,
1290        })
1291    })
1292    .await;
1293
1294    let room = session
1295        .db()
1296        .await
1297        .create_room(session.user_id(), session.connection_id, &livekit_room)
1298        .await?;
1299
1300    response.send(proto::CreateRoomResponse {
1301        room: Some(room.clone()),
1302        live_kit_connection_info,
1303    })?;
1304
1305    update_user_contacts(session.user_id(), &session).await?;
1306    Ok(())
1307}
1308
1309/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1310async fn join_room(
1311    request: proto::JoinRoom,
1312    response: Response<proto::JoinRoom>,
1313    session: Session,
1314) -> Result<()> {
1315    let room_id = RoomId::from_proto(request.id);
1316
1317    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1318
1319    if let Some(channel_id) = channel_id {
1320        return join_channel_internal(channel_id, Box::new(response), session).await;
1321    }
1322
1323    let joined_room = {
1324        let room = session
1325            .db()
1326            .await
1327            .join_room(room_id, session.user_id(), session.connection_id)
1328            .await?;
1329        room_updated(&room.room, &session.peer);
1330        room.into_inner()
1331    };
1332
1333    for connection_id in session
1334        .connection_pool()
1335        .await
1336        .user_connection_ids(session.user_id())
1337    {
1338        session
1339            .peer
1340            .send(
1341                connection_id,
1342                proto::CallCanceled {
1343                    room_id: room_id.to_proto(),
1344                },
1345            )
1346            .trace_err();
1347    }
1348
1349    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1350    {
1351        live_kit
1352            .room_token(
1353                &joined_room.room.livekit_room,
1354                &session.user_id().to_string(),
1355            )
1356            .trace_err()
1357            .map(|token| proto::LiveKitConnectionInfo {
1358                server_url: live_kit.url().into(),
1359                token,
1360                can_publish: true,
1361            })
1362    } else {
1363        None
1364    };
1365
1366    response.send(proto::JoinRoomResponse {
1367        room: Some(joined_room.room),
1368        channel_id: None,
1369        live_kit_connection_info,
1370    })?;
1371
1372    update_user_contacts(session.user_id(), &session).await?;
1373    Ok(())
1374}
1375
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the session's room membership in the database and replies with the room, the
/// reshared projects (with their collaborators) and the rejoined projects. It then:
/// - broadcasts the updated room;
/// - tells each reshared project's collaborators that the host's peer id changed and
///   forwards the project's worktrees to them;
/// - streams rejoined-project state via `notify_rejoined_projects`;
/// - when the room belongs to a channel, broadcasts the channel update;
/// - refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below so only the plain values outlive the value returned by
    // `rejoin_room` (unwrapped via `into_inner`). NOTE(review): presumably that value is a
    // guard holding a room lock that must be released before awaiting the pool — confirm.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators at this connection instead of the host's old one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1467
1468fn notify_rejoined_projects(
1469    rejoined_projects: &mut Vec<RejoinedProject>,
1470    session: &Session,
1471) -> Result<()> {
1472    for project in rejoined_projects.iter() {
1473        for collaborator in &project.collaborators {
1474            session
1475                .peer
1476                .send(
1477                    collaborator.connection_id,
1478                    proto::UpdateProjectCollaborator {
1479                        project_id: project.id.to_proto(),
1480                        old_peer_id: Some(project.old_connection_id.into()),
1481                        new_peer_id: Some(session.connection_id.into()),
1482                    },
1483                )
1484                .trace_err();
1485        }
1486    }
1487
1488    for project in rejoined_projects {
1489        for worktree in mem::take(&mut project.worktrees) {
1490            // Stream this worktree's entries.
1491            let message = proto::UpdateWorktree {
1492                project_id: project.id.to_proto(),
1493                worktree_id: worktree.id,
1494                abs_path: worktree.abs_path.clone(),
1495                root_name: worktree.root_name,
1496                updated_entries: worktree.updated_entries,
1497                removed_entries: worktree.removed_entries,
1498                scan_id: worktree.scan_id,
1499                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1500                updated_repositories: worktree.updated_repositories,
1501                removed_repositories: worktree.removed_repositories,
1502            };
1503            for update in proto::split_worktree_update(message) {
1504                session.peer.send(session.connection_id, update)?;
1505            }
1506
1507            // Stream this worktree's diagnostics.
1508            for summary in worktree.diagnostic_summaries {
1509                session.peer.send(
1510                    session.connection_id,
1511                    proto::UpdateDiagnosticSummary {
1512                        project_id: project.id.to_proto(),
1513                        worktree_id: worktree.id,
1514                        summary: Some(summary),
1515                    },
1516                )?;
1517            }
1518
1519            for settings_file in worktree.settings_files {
1520                session.peer.send(
1521                    session.connection_id,
1522                    proto::UpdateWorktreeSettings {
1523                        project_id: project.id.to_proto(),
1524                        worktree_id: worktree.id,
1525                        path: settings_file.path,
1526                        content: Some(settings_file.content),
1527                        kind: Some(settings_file.kind.to_proto().into()),
1528                    },
1529                )?;
1530            }
1531        }
1532
1533        for repository in mem::take(&mut project.updated_repositories) {
1534            for update in split_repository_update(repository) {
1535                session.peer.send(session.connection_id, update)?;
1536            }
1537        }
1538
1539        for id in mem::take(&mut project.removed_repositories) {
1540            session.peer.send(
1541                session.connection_id,
1542                proto::RemoveRepository {
1543                    project_id: project.id.to_proto(),
1544                    id,
1545                },
1546            )?;
1547        }
1548    }
1549
1550    Ok(())
1551}
1552
1553/// leave room disconnects from the room.
1554async fn leave_room(
1555    _: proto::LeaveRoom,
1556    response: Response<proto::LeaveRoom>,
1557    session: Session,
1558) -> Result<()> {
1559    leave_room_for_session(&session, session.connection_id).await?;
1560    response.send(proto::Ack {})?;
1561    Ok(())
1562}
1563
1564/// Updates the permissions of someone else in the room.
1565async fn set_room_participant_role(
1566    request: proto::SetRoomParticipantRole,
1567    response: Response<proto::SetRoomParticipantRole>,
1568    session: Session,
1569) -> Result<()> {
1570    let user_id = UserId::from_proto(request.user_id);
1571    let role = ChannelRole::from(request.role());
1572
1573    let (livekit_room, can_publish) = {
1574        let room = session
1575            .db()
1576            .await
1577            .set_room_participant_role(
1578                session.user_id(),
1579                RoomId::from_proto(request.room_id),
1580                user_id,
1581                role,
1582            )
1583            .await?;
1584
1585        let livekit_room = room.livekit_room.clone();
1586        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1587        room_updated(&room, &session.peer);
1588        (livekit_room, can_publish)
1589    };
1590
1591    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1592        live_kit
1593            .update_participant(
1594                livekit_room.clone(),
1595                request.user_id.to_string(),
1596                livekit_api::proto::ParticipantPermission {
1597                    can_subscribe: true,
1598                    can_publish,
1599                    can_publish_data: can_publish,
1600                    hidden: false,
1601                    recorder: false,
1602                },
1603            )
1604            .await
1605            .trace_err();
1606    }
1607
1608    response.send(proto::Ack {})?;
1609    Ok(())
1610}
1611
1612/// Call someone else into the current room
1613async fn call(
1614    request: proto::Call,
1615    response: Response<proto::Call>,
1616    session: Session,
1617) -> Result<()> {
1618    let room_id = RoomId::from_proto(request.room_id);
1619    let calling_user_id = session.user_id();
1620    let calling_connection_id = session.connection_id;
1621    let called_user_id = UserId::from_proto(request.called_user_id);
1622    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1623    if !session
1624        .db()
1625        .await
1626        .has_contact(calling_user_id, called_user_id)
1627        .await?
1628    {
1629        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1630    }
1631
1632    let incoming_call = {
1633        let (room, incoming_call) = &mut *session
1634            .db()
1635            .await
1636            .call(
1637                room_id,
1638                calling_user_id,
1639                calling_connection_id,
1640                called_user_id,
1641                initial_project_id,
1642            )
1643            .await?;
1644        room_updated(room, &session.peer);
1645        mem::take(incoming_call)
1646    };
1647    update_user_contacts(called_user_id, &session).await?;
1648
1649    let mut calls = session
1650        .connection_pool()
1651        .await
1652        .user_connection_ids(called_user_id)
1653        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1654        .collect::<FuturesUnordered<_>>();
1655
1656    while let Some(call_response) = calls.next().await {
1657        match call_response.as_ref() {
1658            Ok(_) => {
1659                response.send(proto::Ack {})?;
1660                return Ok(());
1661            }
1662            Err(_) => {
1663                call_response.trace_err();
1664            }
1665        }
1666    }
1667
1668    {
1669        let room = session
1670            .db()
1671            .await
1672            .call_failed(room_id, called_user_id)
1673            .await?;
1674        room_updated(&room, &session.peer);
1675    }
1676    update_user_contacts(called_user_id, &session).await?;
1677
1678    Err(anyhow!("failed to ring user"))?
1679}
1680
1681/// Cancel an outgoing call.
1682async fn cancel_call(
1683    request: proto::CancelCall,
1684    response: Response<proto::CancelCall>,
1685    session: Session,
1686) -> Result<()> {
1687    let called_user_id = UserId::from_proto(request.called_user_id);
1688    let room_id = RoomId::from_proto(request.room_id);
1689    {
1690        let room = session
1691            .db()
1692            .await
1693            .cancel_call(room_id, session.connection_id, called_user_id)
1694            .await?;
1695        room_updated(&room, &session.peer);
1696    }
1697
1698    for connection_id in session
1699        .connection_pool()
1700        .await
1701        .user_connection_ids(called_user_id)
1702    {
1703        session
1704            .peer
1705            .send(
1706                connection_id,
1707                proto::CallCanceled {
1708                    room_id: room_id.to_proto(),
1709                },
1710            )
1711            .trace_err();
1712    }
1713    response.send(proto::Ack {})?;
1714
1715    update_user_contacts(called_user_id, &session).await?;
1716    Ok(())
1717}
1718
1719/// Decline an incoming call.
1720async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1721    let room_id = RoomId::from_proto(message.room_id);
1722    {
1723        let room = session
1724            .db()
1725            .await
1726            .decline_call(Some(room_id), session.user_id())
1727            .await?
1728            .context("declining call")?;
1729        room_updated(&room, &session.peer);
1730    }
1731
1732    for connection_id in session
1733        .connection_pool()
1734        .await
1735        .user_connection_ids(session.user_id())
1736    {
1737        session
1738            .peer
1739            .send(
1740                connection_id,
1741                proto::CallCanceled {
1742                    room_id: room_id.to_proto(),
1743                },
1744            )
1745            .trace_err();
1746    }
1747    update_user_contacts(session.user_id(), &session).await?;
1748    Ok(())
1749}
1750
1751/// Updates other participants in the room with your current location.
1752async fn update_participant_location(
1753    request: proto::UpdateParticipantLocation,
1754    response: Response<proto::UpdateParticipantLocation>,
1755    session: Session,
1756) -> Result<()> {
1757    let room_id = RoomId::from_proto(request.room_id);
1758    let location = request.location.context("invalid location")?;
1759
1760    let db = session.db().await;
1761    let room = db
1762        .update_room_participant_location(room_id, session.connection_id, location)
1763        .await?;
1764
1765    room_updated(&room, &session.peer);
1766    response.send(proto::Ack {})?;
1767    Ok(())
1768}
1769
1770/// Share a project into the room.
1771async fn share_project(
1772    request: proto::ShareProject,
1773    response: Response<proto::ShareProject>,
1774    session: Session,
1775) -> Result<()> {
1776    let (project_id, room) = &*session
1777        .db()
1778        .await
1779        .share_project(
1780            RoomId::from_proto(request.room_id),
1781            session.connection_id,
1782            &request.worktrees,
1783            request.is_ssh_project,
1784        )
1785        .await?;
1786    response.send(proto::ShareProjectResponse {
1787        project_id: project_id.to_proto(),
1788    })?;
1789    room_updated(room, &session.peer);
1790
1791    Ok(())
1792}
1793
1794/// Unshare a project from the room.
1795async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1796    let project_id = ProjectId::from_proto(message.project_id);
1797    unshare_project_internal(project_id, session.connection_id, &session).await
1798}
1799
1800async fn unshare_project_internal(
1801    project_id: ProjectId,
1802    connection_id: ConnectionId,
1803    session: &Session,
1804) -> Result<()> {
1805    let delete = {
1806        let room_guard = session
1807            .db()
1808            .await
1809            .unshare_project(project_id, connection_id)
1810            .await?;
1811
1812        let (delete, room, guest_connection_ids) = &*room_guard;
1813
1814        let message = proto::UnshareProject {
1815            project_id: project_id.to_proto(),
1816        };
1817
1818        broadcast(
1819            Some(connection_id),
1820            guest_connection_ids.iter().copied(),
1821            |conn_id| session.peer.send(conn_id, message.clone()),
1822        );
1823        if let Some(room) = room {
1824            room_updated(room, &session.peer);
1825        }
1826
1827        *delete
1828    };
1829
1830    if delete {
1831        let db = session.db().await;
1832        db.delete_project(project_id).await?;
1833    }
1834
1835    Ok(())
1836}
1837
1838/// Join someone elses shared project.
1839async fn join_project(
1840    request: proto::JoinProject,
1841    response: Response<proto::JoinProject>,
1842    session: Session,
1843) -> Result<()> {
1844    let project_id = ProjectId::from_proto(request.project_id);
1845
1846    tracing::info!(%project_id, "join project");
1847
1848    let db = session.db().await;
1849    let (project, replica_id) = &mut *db
1850        .join_project(project_id, session.connection_id, session.user_id())
1851        .await?;
1852    drop(db);
1853    tracing::info!(%project_id, "join remote project");
1854    join_project_internal(response, session, project, replica_id)
1855}
1856
/// The reply channel used by [`join_project_internal`] to deliver the
/// `JoinProjectResponse`, so that function is written against this trait rather than
/// a concrete response type.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to `Response`'s own `send`. In this `Type::method` path form, inherent
// methods take precedence over trait methods, so this is not a recursive call.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1865
1866fn join_project_internal(
1867    response: impl JoinProjectInternalResponse,
1868    session: Session,
1869    project: &mut Project,
1870    replica_id: &ReplicaId,
1871) -> Result<()> {
1872    let collaborators = project
1873        .collaborators
1874        .iter()
1875        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1876        .map(|collaborator| collaborator.to_proto())
1877        .collect::<Vec<_>>();
1878    let project_id = project.id;
1879    let guest_user_id = session.user_id();
1880
1881    let worktrees = project
1882        .worktrees
1883        .iter()
1884        .map(|(id, worktree)| proto::WorktreeMetadata {
1885            id: *id,
1886            root_name: worktree.root_name.clone(),
1887            visible: worktree.visible,
1888            abs_path: worktree.abs_path.clone(),
1889        })
1890        .collect::<Vec<_>>();
1891
1892    let add_project_collaborator = proto::AddProjectCollaborator {
1893        project_id: project_id.to_proto(),
1894        collaborator: Some(proto::Collaborator {
1895            peer_id: Some(session.connection_id.into()),
1896            replica_id: replica_id.0 as u32,
1897            user_id: guest_user_id.to_proto(),
1898            is_host: false,
1899        }),
1900    };
1901
1902    for collaborator in &collaborators {
1903        session
1904            .peer
1905            .send(
1906                collaborator.peer_id.unwrap().into(),
1907                add_project_collaborator.clone(),
1908            )
1909            .trace_err();
1910    }
1911
1912    // First, we send the metadata associated with each worktree.
1913    response.send(proto::JoinProjectResponse {
1914        project_id: project.id.0 as u64,
1915        worktrees: worktrees.clone(),
1916        replica_id: replica_id.0 as u32,
1917        collaborators: collaborators.clone(),
1918        language_servers: project.language_servers.clone(),
1919        role: project.role.into(),
1920    })?;
1921
1922    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1923        // Stream this worktree's entries.
1924        let message = proto::UpdateWorktree {
1925            project_id: project_id.to_proto(),
1926            worktree_id,
1927            abs_path: worktree.abs_path.clone(),
1928            root_name: worktree.root_name,
1929            updated_entries: worktree.entries,
1930            removed_entries: Default::default(),
1931            scan_id: worktree.scan_id,
1932            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1933            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1934            removed_repositories: Default::default(),
1935        };
1936        for update in proto::split_worktree_update(message) {
1937            session.peer.send(session.connection_id, update.clone())?;
1938        }
1939
1940        // Stream this worktree's diagnostics.
1941        for summary in worktree.diagnostic_summaries {
1942            session.peer.send(
1943                session.connection_id,
1944                proto::UpdateDiagnosticSummary {
1945                    project_id: project_id.to_proto(),
1946                    worktree_id: worktree.id,
1947                    summary: Some(summary),
1948                },
1949            )?;
1950        }
1951
1952        for settings_file in worktree.settings_files {
1953            session.peer.send(
1954                session.connection_id,
1955                proto::UpdateWorktreeSettings {
1956                    project_id: project_id.to_proto(),
1957                    worktree_id: worktree.id,
1958                    path: settings_file.path,
1959                    content: Some(settings_file.content),
1960                    kind: Some(settings_file.kind.to_proto() as i32),
1961                },
1962            )?;
1963        }
1964    }
1965
1966    for repository in mem::take(&mut project.repositories) {
1967        for update in split_repository_update(repository) {
1968            session.peer.send(session.connection_id, update)?;
1969        }
1970    }
1971
1972    for language_server in &project.language_servers {
1973        session.peer.send(
1974            session.connection_id,
1975            proto::UpdateLanguageServer {
1976                project_id: project_id.to_proto(),
1977                language_server_id: language_server.id,
1978                variant: Some(
1979                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1980                        proto::LspDiskBasedDiagnosticsUpdated {},
1981                    ),
1982                ),
1983            },
1984        )?;
1985    }
1986
1987    Ok(())
1988}
1989
1990/// Leave someone elses shared project.
1991async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1992    let sender_id = session.connection_id;
1993    let project_id = ProjectId::from_proto(request.project_id);
1994    let db = session.db().await;
1995
1996    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1997    tracing::info!(
1998        %project_id,
1999        "leave project"
2000    );
2001
2002    project_left(project, &session);
2003    if let Some(room) = room {
2004        room_updated(room, &session.peer);
2005    }
2006
2007    Ok(())
2008}
2009
2010/// Updates other participants with changes to the project
2011async fn update_project(
2012    request: proto::UpdateProject,
2013    response: Response<proto::UpdateProject>,
2014    session: Session,
2015) -> Result<()> {
2016    let project_id = ProjectId::from_proto(request.project_id);
2017    let (room, guest_connection_ids) = &*session
2018        .db()
2019        .await
2020        .update_project(project_id, session.connection_id, &request.worktrees)
2021        .await?;
2022    broadcast(
2023        Some(session.connection_id),
2024        guest_connection_ids.iter().copied(),
2025        |connection_id| {
2026            session
2027                .peer
2028                .forward_send(session.connection_id, connection_id, request.clone())
2029        },
2030    );
2031    if let Some(room) = room {
2032        room_updated(room, &session.peer);
2033    }
2034    response.send(proto::Ack {})?;
2035
2036    Ok(())
2037}
2038
2039/// Updates other participants with changes to the worktree
2040async fn update_worktree(
2041    request: proto::UpdateWorktree,
2042    response: Response<proto::UpdateWorktree>,
2043    session: Session,
2044) -> Result<()> {
2045    let guest_connection_ids = session
2046        .db()
2047        .await
2048        .update_worktree(&request, session.connection_id)
2049        .await?;
2050
2051    broadcast(
2052        Some(session.connection_id),
2053        guest_connection_ids.iter().copied(),
2054        |connection_id| {
2055            session
2056                .peer
2057                .forward_send(session.connection_id, connection_id, request.clone())
2058        },
2059    );
2060    response.send(proto::Ack {})?;
2061    Ok(())
2062}
2063
2064async fn update_repository(
2065    request: proto::UpdateRepository,
2066    response: Response<proto::UpdateRepository>,
2067    session: Session,
2068) -> Result<()> {
2069    let guest_connection_ids = session
2070        .db()
2071        .await
2072        .update_repository(&request, session.connection_id)
2073        .await?;
2074
2075    broadcast(
2076        Some(session.connection_id),
2077        guest_connection_ids.iter().copied(),
2078        |connection_id| {
2079            session
2080                .peer
2081                .forward_send(session.connection_id, connection_id, request.clone())
2082        },
2083    );
2084    response.send(proto::Ack {})?;
2085    Ok(())
2086}
2087
2088async fn remove_repository(
2089    request: proto::RemoveRepository,
2090    response: Response<proto::RemoveRepository>,
2091    session: Session,
2092) -> Result<()> {
2093    let guest_connection_ids = session
2094        .db()
2095        .await
2096        .remove_repository(&request, session.connection_id)
2097        .await?;
2098
2099    broadcast(
2100        Some(session.connection_id),
2101        guest_connection_ids.iter().copied(),
2102        |connection_id| {
2103            session
2104                .peer
2105                .forward_send(session.connection_id, connection_id, request.clone())
2106        },
2107    );
2108    response.send(proto::Ack {})?;
2109    Ok(())
2110}
2111
2112/// Updates other participants with changes to the diagnostics
2113async fn update_diagnostic_summary(
2114    message: proto::UpdateDiagnosticSummary,
2115    session: Session,
2116) -> Result<()> {
2117    let guest_connection_ids = session
2118        .db()
2119        .await
2120        .update_diagnostic_summary(&message, session.connection_id)
2121        .await?;
2122
2123    broadcast(
2124        Some(session.connection_id),
2125        guest_connection_ids.iter().copied(),
2126        |connection_id| {
2127            session
2128                .peer
2129                .forward_send(session.connection_id, connection_id, message.clone())
2130        },
2131    );
2132
2133    Ok(())
2134}
2135
2136/// Updates other participants with changes to the worktree settings
2137async fn update_worktree_settings(
2138    message: proto::UpdateWorktreeSettings,
2139    session: Session,
2140) -> Result<()> {
2141    let guest_connection_ids = session
2142        .db()
2143        .await
2144        .update_worktree_settings(&message, session.connection_id)
2145        .await?;
2146
2147    broadcast(
2148        Some(session.connection_id),
2149        guest_connection_ids.iter().copied(),
2150        |connection_id| {
2151            session
2152                .peer
2153                .forward_send(session.connection_id, connection_id, message.clone())
2154        },
2155    );
2156
2157    Ok(())
2158}
2159
2160/// Notify other participants that a language server has started.
2161async fn start_language_server(
2162    request: proto::StartLanguageServer,
2163    session: Session,
2164) -> Result<()> {
2165    let guest_connection_ids = session
2166        .db()
2167        .await
2168        .start_language_server(&request, session.connection_id)
2169        .await?;
2170
2171    broadcast(
2172        Some(session.connection_id),
2173        guest_connection_ids.iter().copied(),
2174        |connection_id| {
2175            session
2176                .peer
2177                .forward_send(session.connection_id, connection_id, request.clone())
2178        },
2179    );
2180    Ok(())
2181}
2182
2183/// Notify other participants that a language server has changed.
2184async fn update_language_server(
2185    request: proto::UpdateLanguageServer,
2186    session: Session,
2187) -> Result<()> {
2188    let project_id = ProjectId::from_proto(request.project_id);
2189    let project_connection_ids = session
2190        .db()
2191        .await
2192        .project_connection_ids(project_id, session.connection_id, true)
2193        .await?;
2194    broadcast(
2195        Some(session.connection_id),
2196        project_connection_ids.iter().copied(),
2197        |connection_id| {
2198            session
2199                .peer
2200                .forward_send(session.connection_id, connection_id, request.clone())
2201        },
2202    );
2203    Ok(())
2204}
2205
2206/// forward a project request to the host. These requests should be read only
2207/// as guests are allowed to send them.
2208async fn forward_read_only_project_request<T>(
2209    request: T,
2210    response: Response<T>,
2211    session: Session,
2212) -> Result<()>
2213where
2214    T: EntityMessage + RequestMessage,
2215{
2216    let project_id = ProjectId::from_proto(request.remote_entity_id());
2217    let host_connection_id = session
2218        .db()
2219        .await
2220        .host_for_read_only_project_request(project_id, session.connection_id)
2221        .await?;
2222    let payload = session
2223        .peer
2224        .forward_request(session.connection_id, host_connection_id, request)
2225        .await?;
2226    response.send(payload)?;
2227    Ok(())
2228}
2229
2230async fn forward_find_search_candidates_request(
2231    request: proto::FindSearchCandidates,
2232    response: Response<proto::FindSearchCandidates>,
2233    session: Session,
2234) -> Result<()> {
2235    let project_id = ProjectId::from_proto(request.remote_entity_id());
2236    let host_connection_id = session
2237        .db()
2238        .await
2239        .host_for_read_only_project_request(project_id, session.connection_id)
2240        .await?;
2241    let payload = session
2242        .peer
2243        .forward_request(session.connection_id, host_connection_id, request)
2244        .await?;
2245    response.send(payload)?;
2246    Ok(())
2247}
2248
2249/// forward a project request to the host. These requests are disallowed
2250/// for guests.
2251async fn forward_mutating_project_request<T>(
2252    request: T,
2253    response: Response<T>,
2254    session: Session,
2255) -> Result<()>
2256where
2257    T: EntityMessage + RequestMessage,
2258{
2259    let project_id = ProjectId::from_proto(request.remote_entity_id());
2260
2261    let host_connection_id = session
2262        .db()
2263        .await
2264        .host_for_mutating_project_request(project_id, session.connection_id)
2265        .await?;
2266    let payload = session
2267        .peer
2268        .forward_request(session.connection_id, host_connection_id, request)
2269        .await?;
2270    response.send(payload)?;
2271    Ok(())
2272}
2273
2274/// Notify other participants that a new buffer has been created
2275async fn create_buffer_for_peer(
2276    request: proto::CreateBufferForPeer,
2277    session: Session,
2278) -> Result<()> {
2279    session
2280        .db()
2281        .await
2282        .check_user_is_project_host(
2283            ProjectId::from_proto(request.project_id),
2284            session.connection_id,
2285        )
2286        .await?;
2287    let peer_id = request.peer_id.context("invalid peer id")?;
2288    session
2289        .peer
2290        .forward_send(session.connection_id, peer_id.into(), request)?;
2291    Ok(())
2292}
2293
2294/// Notify other participants that a buffer has been updated. This is
2295/// allowed for guests as long as the update is limited to selections.
2296async fn update_buffer(
2297    request: proto::UpdateBuffer,
2298    response: Response<proto::UpdateBuffer>,
2299    session: Session,
2300) -> Result<()> {
2301    let project_id = ProjectId::from_proto(request.project_id);
2302    let mut capability = Capability::ReadOnly;
2303
2304    for op in request.operations.iter() {
2305        match op.variant {
2306            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2307            Some(_) => capability = Capability::ReadWrite,
2308        }
2309    }
2310
2311    let host = {
2312        let guard = session
2313            .db()
2314            .await
2315            .connections_for_buffer_update(project_id, session.connection_id, capability)
2316            .await?;
2317
2318        let (host, guests) = &*guard;
2319
2320        broadcast(
2321            Some(session.connection_id),
2322            guests.clone(),
2323            |connection_id| {
2324                session
2325                    .peer
2326                    .forward_send(session.connection_id, connection_id, request.clone())
2327            },
2328        );
2329
2330        *host
2331    };
2332
2333    if host != session.connection_id {
2334        session
2335            .peer
2336            .forward_request(session.connection_id, host, request.clone())
2337            .await?;
2338    }
2339
2340    response.send(proto::Ack {})?;
2341    Ok(())
2342}
2343
2344async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2345    let project_id = ProjectId::from_proto(message.project_id);
2346
2347    let operation = message.operation.as_ref().context("invalid operation")?;
2348    let capability = match operation.variant.as_ref() {
2349        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2350            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2351                match buffer_op.variant {
2352                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2353                        Capability::ReadOnly
2354                    }
2355                    _ => Capability::ReadWrite,
2356                }
2357            } else {
2358                Capability::ReadWrite
2359            }
2360        }
2361        Some(_) => Capability::ReadWrite,
2362        None => Capability::ReadOnly,
2363    };
2364
2365    let guard = session
2366        .db()
2367        .await
2368        .connections_for_buffer_update(project_id, session.connection_id, capability)
2369        .await?;
2370
2371    let (host, guests) = &*guard;
2372
2373    broadcast(
2374        Some(session.connection_id),
2375        guests.iter().chain([host]).copied(),
2376        |connection_id| {
2377            session
2378                .peer
2379                .forward_send(session.connection_id, connection_id, message.clone())
2380        },
2381    );
2382
2383    Ok(())
2384}
2385
2386/// Notify other participants that a project has been updated.
2387async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2388    request: T,
2389    session: Session,
2390) -> Result<()> {
2391    let project_id = ProjectId::from_proto(request.remote_entity_id());
2392    let project_connection_ids = session
2393        .db()
2394        .await
2395        .project_connection_ids(project_id, session.connection_id, false)
2396        .await?;
2397
2398    broadcast(
2399        Some(session.connection_id),
2400        project_connection_ids.iter().copied(),
2401        |connection_id| {
2402            session
2403                .peer
2404                .forward_send(session.connection_id, connection_id, request.clone())
2405        },
2406    );
2407    Ok(())
2408}
2409
2410/// Start following another user in a call.
2411async fn follow(
2412    request: proto::Follow,
2413    response: Response<proto::Follow>,
2414    session: Session,
2415) -> Result<()> {
2416    let room_id = RoomId::from_proto(request.room_id);
2417    let project_id = request.project_id.map(ProjectId::from_proto);
2418    let leader_id = request.leader_id.context("invalid leader id")?.into();
2419    let follower_id = session.connection_id;
2420
2421    session
2422        .db()
2423        .await
2424        .check_room_participants(room_id, leader_id, session.connection_id)
2425        .await?;
2426
2427    let response_payload = session
2428        .peer
2429        .forward_request(session.connection_id, leader_id, request)
2430        .await?;
2431    response.send(response_payload)?;
2432
2433    if let Some(project_id) = project_id {
2434        let room = session
2435            .db()
2436            .await
2437            .follow(room_id, project_id, leader_id, follower_id)
2438            .await?;
2439        room_updated(&room, &session.peer);
2440    }
2441
2442    Ok(())
2443}
2444
2445/// Stop following another user in a call.
2446async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2447    let room_id = RoomId::from_proto(request.room_id);
2448    let project_id = request.project_id.map(ProjectId::from_proto);
2449    let leader_id = request.leader_id.context("invalid leader id")?.into();
2450    let follower_id = session.connection_id;
2451
2452    session
2453        .db()
2454        .await
2455        .check_room_participants(room_id, leader_id, session.connection_id)
2456        .await?;
2457
2458    session
2459        .peer
2460        .forward_send(session.connection_id, leader_id, request)?;
2461
2462    if let Some(project_id) = project_id {
2463        let room = session
2464            .db()
2465            .await
2466            .unfollow(room_id, project_id, leader_id, follower_id)
2467            .await?;
2468        room_updated(&room, &session.peer);
2469    }
2470
2471    Ok(())
2472}
2473
2474/// Notify everyone following you of your current location.
2475async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2476    let room_id = RoomId::from_proto(request.room_id);
2477    let database = session.db.lock().await;
2478
2479    let connection_ids = if let Some(project_id) = request.project_id {
2480        let project_id = ProjectId::from_proto(project_id);
2481        database
2482            .project_connection_ids(project_id, session.connection_id, true)
2483            .await?
2484    } else {
2485        database
2486            .room_connection_ids(room_id, session.connection_id)
2487            .await?
2488    };
2489
2490    // For now, don't send view update messages back to that view's current leader.
2491    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2492        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2493        _ => None,
2494    });
2495
2496    for connection_id in connection_ids.iter().cloned() {
2497        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2498            session
2499                .peer
2500                .forward_send(session.connection_id, connection_id, request.clone())?;
2501        }
2502    }
2503    Ok(())
2504}
2505
2506/// Get public data about users.
2507async fn get_users(
2508    request: proto::GetUsers,
2509    response: Response<proto::GetUsers>,
2510    session: Session,
2511) -> Result<()> {
2512    let user_ids = request
2513        .user_ids
2514        .into_iter()
2515        .map(UserId::from_proto)
2516        .collect();
2517    let users = session
2518        .db()
2519        .await
2520        .get_users_by_ids(user_ids)
2521        .await?
2522        .into_iter()
2523        .map(|user| proto::User {
2524            id: user.id.to_proto(),
2525            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2526            github_login: user.github_login,
2527            email: user.email_address,
2528            name: user.name,
2529        })
2530        .collect();
2531    response.send(proto::UsersResponse { users })?;
2532    Ok(())
2533}
2534
2535/// Search for users (to invite) buy Github login
2536async fn fuzzy_search_users(
2537    request: proto::FuzzySearchUsers,
2538    response: Response<proto::FuzzySearchUsers>,
2539    session: Session,
2540) -> Result<()> {
2541    let query = request.query;
2542    let users = match query.len() {
2543        0 => vec![],
2544        1 | 2 => session
2545            .db()
2546            .await
2547            .get_user_by_github_login(&query)
2548            .await?
2549            .into_iter()
2550            .collect(),
2551        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2552    };
2553    let users = users
2554        .into_iter()
2555        .filter(|user| user.id != session.user_id())
2556        .map(|user| proto::User {
2557            id: user.id.to_proto(),
2558            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2559            github_login: user.github_login,
2560            name: user.name,
2561            email: user.email_address,
2562        })
2563        .collect();
2564    response.send(proto::UsersResponse { users })?;
2565    Ok(())
2566}
2567
2568/// Send a contact request to another user.
2569async fn request_contact(
2570    request: proto::RequestContact,
2571    response: Response<proto::RequestContact>,
2572    session: Session,
2573) -> Result<()> {
2574    let requester_id = session.user_id();
2575    let responder_id = UserId::from_proto(request.responder_id);
2576    if requester_id == responder_id {
2577        return Err(anyhow!("cannot add yourself as a contact"))?;
2578    }
2579
2580    let notifications = session
2581        .db()
2582        .await
2583        .send_contact_request(requester_id, responder_id)
2584        .await?;
2585
2586    // Update outgoing contact requests of requester
2587    let mut update = proto::UpdateContacts::default();
2588    update.outgoing_requests.push(responder_id.to_proto());
2589    for connection_id in session
2590        .connection_pool()
2591        .await
2592        .user_connection_ids(requester_id)
2593    {
2594        session.peer.send(connection_id, update.clone())?;
2595    }
2596
2597    // Update incoming contact requests of responder
2598    let mut update = proto::UpdateContacts::default();
2599    update
2600        .incoming_requests
2601        .push(proto::IncomingContactRequest {
2602            requester_id: requester_id.to_proto(),
2603        });
2604    let connection_pool = session.connection_pool().await;
2605    for connection_id in connection_pool.user_connection_ids(responder_id) {
2606        session.peer.send(connection_id, update.clone())?;
2607    }
2608
2609    send_notifications(&connection_pool, &session.peer, notifications);
2610
2611    response.send(proto::Ack {})?;
2612    Ok(())
2613}
2614
2615/// Accept or decline a contact request
2616async fn respond_to_contact_request(
2617    request: proto::RespondToContactRequest,
2618    response: Response<proto::RespondToContactRequest>,
2619    session: Session,
2620) -> Result<()> {
2621    let responder_id = session.user_id();
2622    let requester_id = UserId::from_proto(request.requester_id);
2623    let db = session.db().await;
2624    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2625        db.dismiss_contact_notification(responder_id, requester_id)
2626            .await?;
2627    } else {
2628        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2629
2630        let notifications = db
2631            .respond_to_contact_request(responder_id, requester_id, accept)
2632            .await?;
2633        let requester_busy = db.is_user_busy(requester_id).await?;
2634        let responder_busy = db.is_user_busy(responder_id).await?;
2635
2636        let pool = session.connection_pool().await;
2637        // Update responder with new contact
2638        let mut update = proto::UpdateContacts::default();
2639        if accept {
2640            update
2641                .contacts
2642                .push(contact_for_user(requester_id, requester_busy, &pool));
2643        }
2644        update
2645            .remove_incoming_requests
2646            .push(requester_id.to_proto());
2647        for connection_id in pool.user_connection_ids(responder_id) {
2648            session.peer.send(connection_id, update.clone())?;
2649        }
2650
2651        // Update requester with new contact
2652        let mut update = proto::UpdateContacts::default();
2653        if accept {
2654            update
2655                .contacts
2656                .push(contact_for_user(responder_id, responder_busy, &pool));
2657        }
2658        update
2659            .remove_outgoing_requests
2660            .push(responder_id.to_proto());
2661
2662        for connection_id in pool.user_connection_ids(requester_id) {
2663            session.peer.send(connection_id, update.clone())?;
2664        }
2665
2666        send_notifications(&pool, &session.peer, notifications);
2667    }
2668
2669    response.send(proto::Ack {})?;
2670    Ok(())
2671}
2672
2673/// Remove a contact.
2674async fn remove_contact(
2675    request: proto::RemoveContact,
2676    response: Response<proto::RemoveContact>,
2677    session: Session,
2678) -> Result<()> {
2679    let requester_id = session.user_id();
2680    let responder_id = UserId::from_proto(request.user_id);
2681    let db = session.db().await;
2682    let (contact_accepted, deleted_notification_id) =
2683        db.remove_contact(requester_id, responder_id).await?;
2684
2685    let pool = session.connection_pool().await;
2686    // Update outgoing contact requests of requester
2687    let mut update = proto::UpdateContacts::default();
2688    if contact_accepted {
2689        update.remove_contacts.push(responder_id.to_proto());
2690    } else {
2691        update
2692            .remove_outgoing_requests
2693            .push(responder_id.to_proto());
2694    }
2695    for connection_id in pool.user_connection_ids(requester_id) {
2696        session.peer.send(connection_id, update.clone())?;
2697    }
2698
2699    // Update incoming contact requests of responder
2700    let mut update = proto::UpdateContacts::default();
2701    if contact_accepted {
2702        update.remove_contacts.push(requester_id.to_proto());
2703    } else {
2704        update
2705            .remove_incoming_requests
2706            .push(requester_id.to_proto());
2707    }
2708    for connection_id in pool.user_connection_ids(responder_id) {
2709        session.peer.send(connection_id, update.clone())?;
2710        if let Some(notification_id) = deleted_notification_id {
2711            session.peer.send(
2712                connection_id,
2713                proto::DeleteNotification {
2714                    notification_id: notification_id.to_proto(),
2715                },
2716            )?;
2717        }
2718    }
2719
2720    response.send(proto::Ack {})?;
2721    Ok(())
2722}
2723
2724fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2725    version.0.minor() < 139
2726}
2727
2728async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2729    if is_staff {
2730        return Ok(proto::Plan::ZedPro);
2731    }
2732
2733    let subscription = db.get_active_billing_subscription(user_id).await?;
2734    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2735
2736    let plan = if let Some(subscription_kind) = subscription_kind {
2737        match subscription_kind {
2738            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2739            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2740            SubscriptionKind::ZedFree => proto::Plan::Free,
2741        }
2742    } else {
2743        proto::Plan::Free
2744    };
2745
2746    Ok(plan)
2747}
2748
2749async fn make_update_user_plan_message(
2750    user: &User,
2751    is_staff: bool,
2752    db: &Arc<Database>,
2753    llm_db: Option<Arc<LlmDatabase>>,
2754) -> Result<proto::UpdateUserPlan> {
2755    let feature_flags = db.get_user_flags(user.id).await?;
2756    let plan = current_plan(db, user.id, is_staff).await?;
2757    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2758    let billing_preferences = db.get_billing_preferences(user.id).await?;
2759
2760    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2761        let subscription = db.get_active_billing_subscription(user.id).await?;
2762
2763        let subscription_period =
2764            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2765
2766        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2767            llm_db
2768                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2769                .await?
2770        } else {
2771            None
2772        };
2773
2774        (subscription_period, usage)
2775    } else {
2776        (None, None)
2777    };
2778
2779    let bypass_account_age_check = feature_flags
2780        .iter()
2781        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2782    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2783        && !bypass_account_age_check
2784        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2785
2786    Ok(proto::UpdateUserPlan {
2787        plan: plan.into(),
2788        trial_started_at: billing_customer
2789            .as_ref()
2790            .and_then(|billing_customer| billing_customer.trial_started_at)
2791            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2792        is_usage_based_billing_enabled: if is_staff {
2793            Some(true)
2794        } else {
2795            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2796        },
2797        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2798            proto::SubscriptionPeriod {
2799                started_at: started_at.timestamp() as u64,
2800                ended_at: ended_at.timestamp() as u64,
2801            }
2802        }),
2803        account_too_young: Some(account_too_young),
2804        has_overdue_invoices: billing_customer
2805            .map(|billing_customer| billing_customer.has_overdue_invoices),
2806        usage: usage.map(|usage| {
2807            let plan = match plan {
2808                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2809                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2810                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2811            };
2812
2813            let model_requests_limit = match plan.model_requests_limit() {
2814                zed_llm_client::UsageLimit::Limited(limit) => {
2815                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
2816                        && feature_flags
2817                            .iter()
2818                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2819                    {
2820                        1_000
2821                    } else {
2822                        limit
2823                    };
2824
2825                    zed_llm_client::UsageLimit::Limited(limit)
2826                }
2827                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2828            };
2829
2830            proto::SubscriptionUsage {
2831                model_requests_usage_amount: usage.model_requests as u32,
2832                model_requests_usage_limit: Some(proto::UsageLimit {
2833                    variant: Some(match model_requests_limit {
2834                        zed_llm_client::UsageLimit::Limited(limit) => {
2835                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2836                                limit: limit as u32,
2837                            })
2838                        }
2839                        zed_llm_client::UsageLimit::Unlimited => {
2840                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2841                        }
2842                    }),
2843                }),
2844                edit_predictions_usage_amount: usage.edit_predictions as u32,
2845                edit_predictions_usage_limit: Some(proto::UsageLimit {
2846                    variant: Some(match plan.edit_predictions_limit() {
2847                        zed_llm_client::UsageLimit::Limited(limit) => {
2848                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2849                                limit: limit as u32,
2850                            })
2851                        }
2852                        zed_llm_client::UsageLimit::Unlimited => {
2853                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2854                        }
2855                    }),
2856                }),
2857            }
2858        }),
2859    })
2860}
2861
2862async fn update_user_plan(session: &Session) -> Result<()> {
2863    let db = session.db().await;
2864
2865    let update_user_plan = make_update_user_plan_message(
2866        session.principal.user(),
2867        session.is_staff(),
2868        &db.0,
2869        session.app_state.llm_db.clone(),
2870    )
2871    .await?;
2872
2873    session
2874        .peer
2875        .send(session.connection_id, update_user_plan)
2876        .trace_err();
2877
2878    Ok(())
2879}
2880
2881async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2882    subscribe_user_to_channels(session.user_id(), &session).await?;
2883    Ok(())
2884}
2885
2886async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2887    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2888    let mut pool = session.connection_pool().await;
2889    for membership in &channels_for_user.channel_memberships {
2890        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2891    }
2892    session.peer.send(
2893        session.connection_id,
2894        build_update_user_channels(&channels_for_user),
2895    )?;
2896    session.peer.send(
2897        session.connection_id,
2898        build_channels_update(channels_for_user),
2899    )?;
2900    Ok(())
2901}
2902
2903/// Creates a new channel.
2904async fn create_channel(
2905    request: proto::CreateChannel,
2906    response: Response<proto::CreateChannel>,
2907    session: Session,
2908) -> Result<()> {
2909    let db = session.db().await;
2910
2911    let parent_id = request.parent_id.map(ChannelId::from_proto);
2912    let (channel, membership) = db
2913        .create_channel(&request.name, parent_id, session.user_id())
2914        .await?;
2915
2916    let root_id = channel.root_id();
2917    let channel = Channel::from_model(channel);
2918
2919    response.send(proto::CreateChannelResponse {
2920        channel: Some(channel.to_proto()),
2921        parent_id: request.parent_id,
2922    })?;
2923
2924    let mut connection_pool = session.connection_pool().await;
2925    if let Some(membership) = membership {
2926        connection_pool.subscribe_to_channel(
2927            membership.user_id,
2928            membership.channel_id,
2929            membership.role,
2930        );
2931        let update = proto::UpdateUserChannels {
2932            channel_memberships: vec![proto::ChannelMembership {
2933                channel_id: membership.channel_id.to_proto(),
2934                role: membership.role.into(),
2935            }],
2936            ..Default::default()
2937        };
2938        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2939            session.peer.send(connection_id, update.clone())?;
2940        }
2941    }
2942
2943    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2944        if !role.can_see_channel(channel.visibility) {
2945            continue;
2946        }
2947
2948        let update = proto::UpdateChannels {
2949            channels: vec![channel.to_proto()],
2950            ..Default::default()
2951        };
2952        session.peer.send(connection_id, update.clone())?;
2953    }
2954
2955    Ok(())
2956}
2957
2958/// Delete a channel
2959async fn delete_channel(
2960    request: proto::DeleteChannel,
2961    response: Response<proto::DeleteChannel>,
2962    session: Session,
2963) -> Result<()> {
2964    let db = session.db().await;
2965
2966    let channel_id = request.channel_id;
2967    let (root_channel, removed_channels) = db
2968        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2969        .await?;
2970    response.send(proto::Ack {})?;
2971
2972    // Notify members of removed channels
2973    let mut update = proto::UpdateChannels::default();
2974    update
2975        .delete_channels
2976        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2977
2978    let connection_pool = session.connection_pool().await;
2979    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2980        session.peer.send(connection_id, update.clone())?;
2981    }
2982
2983    Ok(())
2984}
2985
2986/// Invite someone to join a channel.
2987async fn invite_channel_member(
2988    request: proto::InviteChannelMember,
2989    response: Response<proto::InviteChannelMember>,
2990    session: Session,
2991) -> Result<()> {
2992    let db = session.db().await;
2993    let channel_id = ChannelId::from_proto(request.channel_id);
2994    let invitee_id = UserId::from_proto(request.user_id);
2995    let InviteMemberResult {
2996        channel,
2997        notifications,
2998    } = db
2999        .invite_channel_member(
3000            channel_id,
3001            invitee_id,
3002            session.user_id(),
3003            request.role().into(),
3004        )
3005        .await?;
3006
3007    let update = proto::UpdateChannels {
3008        channel_invitations: vec![channel.to_proto()],
3009        ..Default::default()
3010    };
3011
3012    let connection_pool = session.connection_pool().await;
3013    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3014        session.peer.send(connection_id, update.clone())?;
3015    }
3016
3017    send_notifications(&connection_pool, &session.peer, notifications);
3018
3019    response.send(proto::Ack {})?;
3020    Ok(())
3021}
3022
3023/// remove someone from a channel
3024async fn remove_channel_member(
3025    request: proto::RemoveChannelMember,
3026    response: Response<proto::RemoveChannelMember>,
3027    session: Session,
3028) -> Result<()> {
3029    let db = session.db().await;
3030    let channel_id = ChannelId::from_proto(request.channel_id);
3031    let member_id = UserId::from_proto(request.user_id);
3032
3033    let RemoveChannelMemberResult {
3034        membership_update,
3035        notification_id,
3036    } = db
3037        .remove_channel_member(channel_id, member_id, session.user_id())
3038        .await?;
3039
3040    let mut connection_pool = session.connection_pool().await;
3041    notify_membership_updated(
3042        &mut connection_pool,
3043        membership_update,
3044        member_id,
3045        &session.peer,
3046    );
3047    for connection_id in connection_pool.user_connection_ids(member_id) {
3048        if let Some(notification_id) = notification_id {
3049            session
3050                .peer
3051                .send(
3052                    connection_id,
3053                    proto::DeleteNotification {
3054                        notification_id: notification_id.to_proto(),
3055                    },
3056                )
3057                .trace_err();
3058        }
3059    }
3060
3061    response.send(proto::Ack {})?;
3062    Ok(())
3063}
3064
3065/// Toggle the channel between public and private.
3066/// Care is taken to maintain the invariant that public channels only descend from public channels,
3067/// (though members-only channels can appear at any point in the hierarchy).
3068async fn set_channel_visibility(
3069    request: proto::SetChannelVisibility,
3070    response: Response<proto::SetChannelVisibility>,
3071    session: Session,
3072) -> Result<()> {
3073    let db = session.db().await;
3074    let channel_id = ChannelId::from_proto(request.channel_id);
3075    let visibility = request.visibility().into();
3076
3077    let channel_model = db
3078        .set_channel_visibility(channel_id, visibility, session.user_id())
3079        .await?;
3080    let root_id = channel_model.root_id();
3081    let channel = Channel::from_model(channel_model);
3082
3083    let mut connection_pool = session.connection_pool().await;
3084    for (user_id, role) in connection_pool
3085        .channel_user_ids(root_id)
3086        .collect::<Vec<_>>()
3087        .into_iter()
3088    {
3089        let update = if role.can_see_channel(channel.visibility) {
3090            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3091            proto::UpdateChannels {
3092                channels: vec![channel.to_proto()],
3093                ..Default::default()
3094            }
3095        } else {
3096            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3097            proto::UpdateChannels {
3098                delete_channels: vec![channel.id.to_proto()],
3099                ..Default::default()
3100            }
3101        };
3102
3103        for connection_id in connection_pool.user_connection_ids(user_id) {
3104            session.peer.send(connection_id, update.clone())?;
3105        }
3106    }
3107
3108    response.send(proto::Ack {})?;
3109    Ok(())
3110}
3111
3112/// Alter the role for a user in the channel.
3113async fn set_channel_member_role(
3114    request: proto::SetChannelMemberRole,
3115    response: Response<proto::SetChannelMemberRole>,
3116    session: Session,
3117) -> Result<()> {
3118    let db = session.db().await;
3119    let channel_id = ChannelId::from_proto(request.channel_id);
3120    let member_id = UserId::from_proto(request.user_id);
3121    let result = db
3122        .set_channel_member_role(
3123            channel_id,
3124            session.user_id(),
3125            member_id,
3126            request.role().into(),
3127        )
3128        .await?;
3129
3130    match result {
3131        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3132            let mut connection_pool = session.connection_pool().await;
3133            notify_membership_updated(
3134                &mut connection_pool,
3135                membership_update,
3136                member_id,
3137                &session.peer,
3138            )
3139        }
3140        db::SetMemberRoleResult::InviteUpdated(channel) => {
3141            let update = proto::UpdateChannels {
3142                channel_invitations: vec![channel.to_proto()],
3143                ..Default::default()
3144            };
3145
3146            for connection_id in session
3147                .connection_pool()
3148                .await
3149                .user_connection_ids(member_id)
3150            {
3151                session.peer.send(connection_id, update.clone())?;
3152            }
3153        }
3154    }
3155
3156    response.send(proto::Ack {})?;
3157    Ok(())
3158}
3159
3160/// Change the name of a channel
3161async fn rename_channel(
3162    request: proto::RenameChannel,
3163    response: Response<proto::RenameChannel>,
3164    session: Session,
3165) -> Result<()> {
3166    let db = session.db().await;
3167    let channel_id = ChannelId::from_proto(request.channel_id);
3168    let channel_model = db
3169        .rename_channel(channel_id, session.user_id(), &request.name)
3170        .await?;
3171    let root_id = channel_model.root_id();
3172    let channel = Channel::from_model(channel_model);
3173
3174    response.send(proto::RenameChannelResponse {
3175        channel: Some(channel.to_proto()),
3176    })?;
3177
3178    let connection_pool = session.connection_pool().await;
3179    let update = proto::UpdateChannels {
3180        channels: vec![channel.to_proto()],
3181        ..Default::default()
3182    };
3183    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3184        if role.can_see_channel(channel.visibility) {
3185            session.peer.send(connection_id, update.clone())?;
3186        }
3187    }
3188
3189    Ok(())
3190}
3191
3192/// Move a channel to a new parent.
3193async fn move_channel(
3194    request: proto::MoveChannel,
3195    response: Response<proto::MoveChannel>,
3196    session: Session,
3197) -> Result<()> {
3198    let channel_id = ChannelId::from_proto(request.channel_id);
3199    let to = ChannelId::from_proto(request.to);
3200
3201    let (root_id, channels) = session
3202        .db()
3203        .await
3204        .move_channel(channel_id, to, session.user_id())
3205        .await?;
3206
3207    let connection_pool = session.connection_pool().await;
3208    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3209        let channels = channels
3210            .iter()
3211            .filter_map(|channel| {
3212                if role.can_see_channel(channel.visibility) {
3213                    Some(channel.to_proto())
3214                } else {
3215                    None
3216                }
3217            })
3218            .collect::<Vec<_>>();
3219        if channels.is_empty() {
3220            continue;
3221        }
3222
3223        let update = proto::UpdateChannels {
3224            channels,
3225            ..Default::default()
3226        };
3227
3228        session.peer.send(connection_id, update.clone())?;
3229    }
3230
3231    response.send(Ack {})?;
3232    Ok(())
3233}
3234
3235async fn reorder_channel(
3236    request: proto::ReorderChannel,
3237    response: Response<proto::ReorderChannel>,
3238    session: Session,
3239) -> Result<()> {
3240    let channel_id = ChannelId::from_proto(request.channel_id);
3241    let direction = request.direction();
3242
3243    let updated_channels = session
3244        .db()
3245        .await
3246        .reorder_channel(channel_id, direction, session.user_id())
3247        .await?;
3248
3249    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3250        let connection_pool = session.connection_pool().await;
3251        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3252            let channels = updated_channels
3253                .iter()
3254                .filter_map(|channel| {
3255                    if role.can_see_channel(channel.visibility) {
3256                        Some(channel.to_proto())
3257                    } else {
3258                        None
3259                    }
3260                })
3261                .collect::<Vec<_>>();
3262
3263            if channels.is_empty() {
3264                continue;
3265            }
3266
3267            let update = proto::UpdateChannels {
3268                channels,
3269                ..Default::default()
3270            };
3271
3272            session.peer.send(connection_id, update.clone())?;
3273        }
3274    }
3275
3276    response.send(Ack {})?;
3277    Ok(())
3278}
3279
3280/// Get the list of channel members
3281async fn get_channel_members(
3282    request: proto::GetChannelMembers,
3283    response: Response<proto::GetChannelMembers>,
3284    session: Session,
3285) -> Result<()> {
3286    let db = session.db().await;
3287    let channel_id = ChannelId::from_proto(request.channel_id);
3288    let limit = if request.limit == 0 {
3289        u16::MAX as u64
3290    } else {
3291        request.limit
3292    };
3293    let (members, users) = db
3294        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3295        .await?;
3296    response.send(proto::GetChannelMembersResponse { members, users })?;
3297    Ok(())
3298}
3299
3300/// Accept or decline a channel invitation.
3301async fn respond_to_channel_invite(
3302    request: proto::RespondToChannelInvite,
3303    response: Response<proto::RespondToChannelInvite>,
3304    session: Session,
3305) -> Result<()> {
3306    let db = session.db().await;
3307    let channel_id = ChannelId::from_proto(request.channel_id);
3308    let RespondToChannelInvite {
3309        membership_update,
3310        notifications,
3311    } = db
3312        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3313        .await?;
3314
3315    let mut connection_pool = session.connection_pool().await;
3316    if let Some(membership_update) = membership_update {
3317        notify_membership_updated(
3318            &mut connection_pool,
3319            membership_update,
3320            session.user_id(),
3321            &session.peer,
3322        );
3323    } else {
3324        let update = proto::UpdateChannels {
3325            remove_channel_invitations: vec![channel_id.to_proto()],
3326            ..Default::default()
3327        };
3328
3329        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3330            session.peer.send(connection_id, update.clone())?;
3331        }
3332    };
3333
3334    send_notifications(&connection_pool, &session.peer, notifications);
3335
3336    response.send(proto::Ack {})?;
3337
3338    Ok(())
3339}
3340
3341/// Join the channels' room
3342async fn join_channel(
3343    request: proto::JoinChannel,
3344    response: Response<proto::JoinChannel>,
3345    session: Session,
3346) -> Result<()> {
3347    let channel_id = ChannelId::from_proto(request.channel_id);
3348    join_channel_internal(channel_id, Box::new(response), session).await
3349}
3350
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` serve both `JoinChannel` and `JoinRoom` requests.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// `JoinChannel` requests are answered with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` via the fully qualified path.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// `JoinRoom` requests are answered with the same `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` via the fully qualified path.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3364
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel in the database, obtaining the room, an optional membership update,
///    and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token and `can_publish = false`;
///    - every other role gets a room token and `can_publish = true`.
///    A token error is passed through `trace_err` and results in no LiveKit connection info.
/// 4. Reply via `response`, apply any membership update, and broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// `response` may be any handle implementing `JoinChannelInternalResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The db handle is released around `leave_room_for_session` and re-acquired
            // afterwards (presumably because that call takes its own handle).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when LiveKit is not configured or when minting the token failed
        // (`trace_err()?` turns the error into `None` inside this closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may only subscribe; every other role may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` go out of scope here.
        joined_room
    };

    // NOTE(review): if the database returned no channel, this errors only after the client
    // has already received a successful `JoinRoomResponse` above. Confirm that is acceptable.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3458
3459/// Start editing the channel notes
3460async fn join_channel_buffer(
3461    request: proto::JoinChannelBuffer,
3462    response: Response<proto::JoinChannelBuffer>,
3463    session: Session,
3464) -> Result<()> {
3465    let db = session.db().await;
3466    let channel_id = ChannelId::from_proto(request.channel_id);
3467
3468    let open_response = db
3469        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3470        .await?;
3471
3472    let collaborators = open_response.collaborators.clone();
3473    response.send(open_response)?;
3474
3475    let update = UpdateChannelBufferCollaborators {
3476        channel_id: channel_id.to_proto(),
3477        collaborators: collaborators.clone(),
3478    };
3479    channel_buffer_updated(
3480        session.connection_id,
3481        collaborators
3482            .iter()
3483            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3484        &update,
3485        &session.peer,
3486    );
3487
3488    Ok(())
3489}
3490
3491/// Edit the channel notes
3492async fn update_channel_buffer(
3493    request: proto::UpdateChannelBuffer,
3494    session: Session,
3495) -> Result<()> {
3496    let db = session.db().await;
3497    let channel_id = ChannelId::from_proto(request.channel_id);
3498
3499    let (collaborators, epoch, version) = db
3500        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3501        .await?;
3502
3503    channel_buffer_updated(
3504        session.connection_id,
3505        collaborators.clone(),
3506        &proto::UpdateChannelBuffer {
3507            channel_id: channel_id.to_proto(),
3508            operations: request.operations,
3509        },
3510        &session.peer,
3511    );
3512
3513    let pool = &*session.connection_pool().await;
3514
3515    let non_collaborators =
3516        pool.channel_connection_ids(channel_id)
3517            .filter_map(|(connection_id, _)| {
3518                if collaborators.contains(&connection_id) {
3519                    None
3520                } else {
3521                    Some(connection_id)
3522                }
3523            });
3524
3525    broadcast(None, non_collaborators, |peer_id| {
3526        session.peer.send(
3527            peer_id,
3528            proto::UpdateChannels {
3529                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3530                    channel_id: channel_id.to_proto(),
3531                    epoch: epoch as u64,
3532                    version: version.clone(),
3533                }],
3534                ..Default::default()
3535            },
3536        )
3537    });
3538
3539    Ok(())
3540}
3541
3542/// Rejoin the channel notes after a connection blip
3543async fn rejoin_channel_buffers(
3544    request: proto::RejoinChannelBuffers,
3545    response: Response<proto::RejoinChannelBuffers>,
3546    session: Session,
3547) -> Result<()> {
3548    let db = session.db().await;
3549    let buffers = db
3550        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3551        .await?;
3552
3553    for rejoined_buffer in &buffers {
3554        let collaborators_to_notify = rejoined_buffer
3555            .buffer
3556            .collaborators
3557            .iter()
3558            .filter_map(|c| Some(c.peer_id?.into()));
3559        channel_buffer_updated(
3560            session.connection_id,
3561            collaborators_to_notify,
3562            &proto::UpdateChannelBufferCollaborators {
3563                channel_id: rejoined_buffer.buffer.channel_id,
3564                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3565            },
3566            &session.peer,
3567        );
3568    }
3569
3570    response.send(proto::RejoinChannelBuffersResponse {
3571        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3572    })?;
3573
3574    Ok(())
3575}
3576
3577/// Stop editing the channel notes
3578async fn leave_channel_buffer(
3579    request: proto::LeaveChannelBuffer,
3580    response: Response<proto::LeaveChannelBuffer>,
3581    session: Session,
3582) -> Result<()> {
3583    let db = session.db().await;
3584    let channel_id = ChannelId::from_proto(request.channel_id);
3585
3586    let left_buffer = db
3587        .leave_channel_buffer(channel_id, session.connection_id)
3588        .await?;
3589
3590    response.send(Ack {})?;
3591
3592    channel_buffer_updated(
3593        session.connection_id,
3594        left_buffer.connections,
3595        &proto::UpdateChannelBufferCollaborators {
3596            channel_id: channel_id.to_proto(),
3597            collaborators: left_buffer.collaborators,
3598        },
3599        &session.peer,
3600    );
3601
3602    Ok(())
3603}
3604
3605fn channel_buffer_updated<T: EnvelopedMessage>(
3606    sender_id: ConnectionId,
3607    collaborators: impl IntoIterator<Item = ConnectionId>,
3608    message: &T,
3609    peer: &Peer,
3610) {
3611    broadcast(Some(sender_id), collaborators, |peer_id| {
3612        peer.send(peer_id, message.clone())
3613    });
3614}
3615
3616fn send_notifications(
3617    connection_pool: &ConnectionPool,
3618    peer: &Peer,
3619    notifications: db::NotificationBatch,
3620) {
3621    for (user_id, notification) in notifications {
3622        for connection_id in connection_pool.user_connection_ids(user_id) {
3623            if let Err(error) = peer.send(
3624                connection_id,
3625                proto::AddNotification {
3626                    notification: Some(notification.clone()),
3627                },
3628            ) {
3629                tracing::error!(
3630                    "failed to send notification to {:?} {}",
3631                    connection_id,
3632                    error
3633                );
3634            }
3635        }
3636    }
3637}
3638
3639/// Send a message to the channel
3640async fn send_channel_message(
3641    request: proto::SendChannelMessage,
3642    response: Response<proto::SendChannelMessage>,
3643    session: Session,
3644) -> Result<()> {
3645    // Validate the message body.
3646    let body = request.body.trim().to_string();
3647    if body.len() > MAX_MESSAGE_LEN {
3648        return Err(anyhow!("message is too long"))?;
3649    }
3650    if body.is_empty() {
3651        return Err(anyhow!("message can't be blank"))?;
3652    }
3653
3654    // TODO: adjust mentions if body is trimmed
3655
3656    let timestamp = OffsetDateTime::now_utc();
3657    let nonce = request.nonce.context("nonce can't be blank")?;
3658
3659    let channel_id = ChannelId::from_proto(request.channel_id);
3660    let CreatedChannelMessage {
3661        message_id,
3662        participant_connection_ids,
3663        notifications,
3664    } = session
3665        .db()
3666        .await
3667        .create_channel_message(
3668            channel_id,
3669            session.user_id(),
3670            &body,
3671            &request.mentions,
3672            timestamp,
3673            nonce.clone().into(),
3674            request.reply_to_message_id.map(MessageId::from_proto),
3675        )
3676        .await?;
3677
3678    let message = proto::ChannelMessage {
3679        sender_id: session.user_id().to_proto(),
3680        id: message_id.to_proto(),
3681        body,
3682        mentions: request.mentions,
3683        timestamp: timestamp.unix_timestamp() as u64,
3684        nonce: Some(nonce),
3685        reply_to_message_id: request.reply_to_message_id,
3686        edited_at: None,
3687    };
3688    broadcast(
3689        Some(session.connection_id),
3690        participant_connection_ids.clone(),
3691        |connection| {
3692            session.peer.send(
3693                connection,
3694                proto::ChannelMessageSent {
3695                    channel_id: channel_id.to_proto(),
3696                    message: Some(message.clone()),
3697                },
3698            )
3699        },
3700    );
3701    response.send(proto::SendChannelMessageResponse {
3702        message: Some(message),
3703    })?;
3704
3705    let pool = &*session.connection_pool().await;
3706    let non_participants =
3707        pool.channel_connection_ids(channel_id)
3708            .filter_map(|(connection_id, _)| {
3709                if participant_connection_ids.contains(&connection_id) {
3710                    None
3711                } else {
3712                    Some(connection_id)
3713                }
3714            });
3715    broadcast(None, non_participants, |peer_id| {
3716        session.peer.send(
3717            peer_id,
3718            proto::UpdateChannels {
3719                latest_channel_message_ids: vec![proto::ChannelMessageId {
3720                    channel_id: channel_id.to_proto(),
3721                    message_id: message_id.to_proto(),
3722                }],
3723                ..Default::default()
3724            },
3725        )
3726    });
3727    send_notifications(pool, &session.peer, notifications);
3728
3729    Ok(())
3730}
3731
3732/// Delete a channel message
3733async fn remove_channel_message(
3734    request: proto::RemoveChannelMessage,
3735    response: Response<proto::RemoveChannelMessage>,
3736    session: Session,
3737) -> Result<()> {
3738    let channel_id = ChannelId::from_proto(request.channel_id);
3739    let message_id = MessageId::from_proto(request.message_id);
3740    let (connection_ids, existing_notification_ids) = session
3741        .db()
3742        .await
3743        .remove_channel_message(channel_id, message_id, session.user_id())
3744        .await?;
3745
3746    broadcast(
3747        Some(session.connection_id),
3748        connection_ids,
3749        move |connection| {
3750            session.peer.send(connection, request.clone())?;
3751
3752            for notification_id in &existing_notification_ids {
3753                session.peer.send(
3754                    connection,
3755                    proto::DeleteNotification {
3756                        notification_id: (*notification_id).to_proto(),
3757                    },
3758                )?;
3759            }
3760
3761            Ok(())
3762        },
3763    );
3764    response.send(proto::Ack {})?;
3765    Ok(())
3766}
3767
3768async fn update_channel_message(
3769    request: proto::UpdateChannelMessage,
3770    response: Response<proto::UpdateChannelMessage>,
3771    session: Session,
3772) -> Result<()> {
3773    let channel_id = ChannelId::from_proto(request.channel_id);
3774    let message_id = MessageId::from_proto(request.message_id);
3775    let updated_at = OffsetDateTime::now_utc();
3776    let UpdatedChannelMessage {
3777        message_id,
3778        participant_connection_ids,
3779        notifications,
3780        reply_to_message_id,
3781        timestamp,
3782        deleted_mention_notification_ids,
3783        updated_mention_notifications,
3784    } = session
3785        .db()
3786        .await
3787        .update_channel_message(
3788            channel_id,
3789            message_id,
3790            session.user_id(),
3791            request.body.as_str(),
3792            &request.mentions,
3793            updated_at,
3794        )
3795        .await?;
3796
3797    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3798
3799    let message = proto::ChannelMessage {
3800        sender_id: session.user_id().to_proto(),
3801        id: message_id.to_proto(),
3802        body: request.body.clone(),
3803        mentions: request.mentions.clone(),
3804        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3805        nonce: Some(nonce),
3806        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3807        edited_at: Some(updated_at.unix_timestamp() as u64),
3808    };
3809
3810    response.send(proto::Ack {})?;
3811
3812    let pool = &*session.connection_pool().await;
3813    broadcast(
3814        Some(session.connection_id),
3815        participant_connection_ids,
3816        |connection| {
3817            session.peer.send(
3818                connection,
3819                proto::ChannelMessageUpdate {
3820                    channel_id: channel_id.to_proto(),
3821                    message: Some(message.clone()),
3822                },
3823            )?;
3824
3825            for notification_id in &deleted_mention_notification_ids {
3826                session.peer.send(
3827                    connection,
3828                    proto::DeleteNotification {
3829                        notification_id: (*notification_id).to_proto(),
3830                    },
3831                )?;
3832            }
3833
3834            for notification in &updated_mention_notifications {
3835                session.peer.send(
3836                    connection,
3837                    proto::UpdateNotification {
3838                        notification: Some(notification.clone()),
3839                    },
3840                )?;
3841            }
3842
3843            Ok(())
3844        },
3845    );
3846
3847    send_notifications(pool, &session.peer, notifications);
3848
3849    Ok(())
3850}
3851
3852/// Mark a channel message as read
3853async fn acknowledge_channel_message(
3854    request: proto::AckChannelMessage,
3855    session: Session,
3856) -> Result<()> {
3857    let channel_id = ChannelId::from_proto(request.channel_id);
3858    let message_id = MessageId::from_proto(request.message_id);
3859    let notifications = session
3860        .db()
3861        .await
3862        .observe_channel_message(channel_id, session.user_id(), message_id)
3863        .await?;
3864    send_notifications(
3865        &*session.connection_pool().await,
3866        &session.peer,
3867        notifications,
3868    );
3869    Ok(())
3870}
3871
3872/// Mark a buffer version as synced
3873async fn acknowledge_buffer_version(
3874    request: proto::AckBufferOperation,
3875    session: Session,
3876) -> Result<()> {
3877    let buffer_id = BufferId::from_proto(request.buffer_id);
3878    session
3879        .db()
3880        .await
3881        .observe_buffer_version(
3882            buffer_id,
3883            session.user_id(),
3884            request.epoch as i32,
3885            &request.version,
3886        )
3887        .await?;
3888    Ok(())
3889}
3890
3891/// Get a Supermaven API key for the user
3892async fn get_supermaven_api_key(
3893    _request: proto::GetSupermavenApiKey,
3894    response: Response<proto::GetSupermavenApiKey>,
3895    session: Session,
3896) -> Result<()> {
3897    let user_id: String = session.user_id().to_string();
3898    if !session.is_staff() {
3899        return Err(anyhow!("supermaven not enabled for this account"))?;
3900    }
3901
3902    let email = session.email().context("user must have an email")?;
3903
3904    let supermaven_admin_api = session
3905        .supermaven_client
3906        .as_ref()
3907        .context("supermaven not configured")?;
3908
3909    let result = supermaven_admin_api
3910        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3911        .await?;
3912
3913    response.send(proto::GetSupermavenApiKeyResponse {
3914        api_key: result.api_key,
3915    })?;
3916
3917    Ok(())
3918}
3919
3920/// Start receiving chat updates for a channel
3921async fn join_channel_chat(
3922    request: proto::JoinChannelChat,
3923    response: Response<proto::JoinChannelChat>,
3924    session: Session,
3925) -> Result<()> {
3926    let channel_id = ChannelId::from_proto(request.channel_id);
3927
3928    let db = session.db().await;
3929    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3930        .await?;
3931    let messages = db
3932        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3933        .await?;
3934    response.send(proto::JoinChannelChatResponse {
3935        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3936        messages,
3937    })?;
3938    Ok(())
3939}
3940
3941/// Stop receiving chat updates for a channel
3942async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3943    let channel_id = ChannelId::from_proto(request.channel_id);
3944    session
3945        .db()
3946        .await
3947        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3948        .await?;
3949    Ok(())
3950}
3951
3952/// Retrieve the chat history for a channel
3953async fn get_channel_messages(
3954    request: proto::GetChannelMessages,
3955    response: Response<proto::GetChannelMessages>,
3956    session: Session,
3957) -> Result<()> {
3958    let channel_id = ChannelId::from_proto(request.channel_id);
3959    let messages = session
3960        .db()
3961        .await
3962        .get_channel_messages(
3963            channel_id,
3964            session.user_id(),
3965            MESSAGE_COUNT_PER_PAGE,
3966            Some(MessageId::from_proto(request.before_message_id)),
3967        )
3968        .await?;
3969    response.send(proto::GetChannelMessagesResponse {
3970        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3971        messages,
3972    })?;
3973    Ok(())
3974}
3975
3976/// Retrieve specific chat messages
3977async fn get_channel_messages_by_id(
3978    request: proto::GetChannelMessagesById,
3979    response: Response<proto::GetChannelMessagesById>,
3980    session: Session,
3981) -> Result<()> {
3982    let message_ids = request
3983        .message_ids
3984        .iter()
3985        .map(|id| MessageId::from_proto(*id))
3986        .collect::<Vec<_>>();
3987    let messages = session
3988        .db()
3989        .await
3990        .get_channel_messages_by_id(session.user_id(), &message_ids)
3991        .await?;
3992    response.send(proto::GetChannelMessagesResponse {
3993        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3994        messages,
3995    })?;
3996    Ok(())
3997}
3998
3999/// Retrieve the current users notifications
4000async fn get_notifications(
4001    request: proto::GetNotifications,
4002    response: Response<proto::GetNotifications>,
4003    session: Session,
4004) -> Result<()> {
4005    let notifications = session
4006        .db()
4007        .await
4008        .get_notifications(
4009            session.user_id(),
4010            NOTIFICATION_COUNT_PER_PAGE,
4011            request.before_id.map(db::NotificationId::from_proto),
4012        )
4013        .await?;
4014    response.send(proto::GetNotificationsResponse {
4015        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4016        notifications,
4017    })?;
4018    Ok(())
4019}
4020
4021/// Mark notifications as read
4022async fn mark_notification_as_read(
4023    request: proto::MarkNotificationRead,
4024    response: Response<proto::MarkNotificationRead>,
4025    session: Session,
4026) -> Result<()> {
4027    let database = &session.db().await;
4028    let notifications = database
4029        .mark_notification_as_read_by_id(
4030            session.user_id(),
4031            NotificationId::from_proto(request.notification_id),
4032        )
4033        .await?;
4034    send_notifications(
4035        &*session.connection_pool().await,
4036        &session.peer,
4037        notifications,
4038    );
4039    response.send(proto::Ack {})?;
4040    Ok(())
4041}
4042
4043/// Get the current users information
4044async fn get_private_user_info(
4045    _request: proto::GetPrivateUserInfo,
4046    response: Response<proto::GetPrivateUserInfo>,
4047    session: Session,
4048) -> Result<()> {
4049    let db = session.db().await;
4050
4051    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4052    let user = db
4053        .get_user_by_id(session.user_id())
4054        .await?
4055        .context("user not found")?;
4056    let flags = db.get_user_flags(session.user_id()).await?;
4057
4058    response.send(proto::GetPrivateUserInfoResponse {
4059        metrics_id,
4060        staff: user.admin,
4061        flags,
4062        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4063    })?;
4064    Ok(())
4065}
4066
4067/// Accept the terms of service (tos) on behalf of the current user
4068async fn accept_terms_of_service(
4069    _request: proto::AcceptTermsOfService,
4070    response: Response<proto::AcceptTermsOfService>,
4071    session: Session,
4072) -> Result<()> {
4073    let db = session.db().await;
4074
4075    let accepted_tos_at = Utc::now();
4076    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4077        .await?;
4078
4079    response.send(proto::AcceptTermsOfServiceResponse {
4080        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4081    })?;
4082    Ok(())
4083}
4084
/// Issue an LLM API token for the current user.
///
/// The user must have accepted the terms of service. Before the token is created, this
/// ensures billing state exists, which may have external side effects:
/// - If there is no billing customer record, a Stripe customer is found or created by email,
///   and the matching billing customer is looked up or created.
/// - If there is no active billing subscription, the customer is subscribed to Zed Free in
///   Stripe and the subscription is recorded in the database.
///
/// The token claims are built from the user, staff status, billing customer, preferences,
/// subscription, feature flags, system id and app config.
///
/// # Errors
///
/// Fails when:
/// - the user is missing or has not accepted the terms of service;
/// - the Stripe client or billing object is not configured;
/// - no billing customer can be found after creating the Stripe customer;
/// - any database or Stripe call fails.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    // NOTE(review): this db handle stays held across the Stripe network calls below.
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Use the existing billing customer, or create one from a Stripe customer matched by email.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Use the active subscription, or subscribe the customer to Zed Free and record it.
    // NOTE(review): this is check-then-create; two concurrent requests for a user without a
    // subscription could both reach `subscribe_to_zed_free`. Confirm duplicates are handled.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4168
4169fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4170    let message = match message {
4171        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4172        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4173        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4174        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4175        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4176            code: frame.code.into(),
4177            reason: frame.reason.as_str().to_owned().into(),
4178        })),
4179        // We should never receive a frame while reading the message, according
4180        // to the `tungstenite` maintainers:
4181        //
4182        // > It cannot occur when you read messages from the WebSocket, but it
4183        // > can be used when you want to send the raw frames (e.g. you want to
4184        // > send the frames to the WebSocket without composing the full message first).
4185        // >
4186        // > — https://github.com/snapview/tungstenite-rs/issues/268
4187        TungsteniteMessage::Frame(_) => {
4188            bail!("received an unexpected frame while reading the message")
4189        }
4190    };
4191
4192    Ok(message)
4193}
4194
4195fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4196    match message {
4197        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4198        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4199        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4200        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4201        AxumMessage::Close(frame) => {
4202            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4203                code: frame.code.into(),
4204                reason: frame.reason.as_ref().into(),
4205            }))
4206        }
4207    }
4208}
4209
4210fn notify_membership_updated(
4211    connection_pool: &mut ConnectionPool,
4212    result: MembershipUpdated,
4213    user_id: UserId,
4214    peer: &Peer,
4215) {
4216    for membership in &result.new_channels.channel_memberships {
4217        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4218    }
4219    for channel_id in &result.removed_channels {
4220        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4221    }
4222
4223    let user_channels_update = proto::UpdateUserChannels {
4224        channel_memberships: result
4225            .new_channels
4226            .channel_memberships
4227            .iter()
4228            .map(|cm| proto::ChannelMembership {
4229                channel_id: cm.channel_id.to_proto(),
4230                role: cm.role.into(),
4231            })
4232            .collect(),
4233        ..Default::default()
4234    };
4235
4236    let mut update = build_channels_update(result.new_channels);
4237    update.delete_channels = result
4238        .removed_channels
4239        .into_iter()
4240        .map(|id| id.to_proto())
4241        .collect();
4242    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4243
4244    for connection_id in connection_pool.user_connection_ids(user_id) {
4245        peer.send(connection_id, user_channels_update.clone())
4246            .trace_err();
4247        peer.send(connection_id, update.clone()).trace_err();
4248    }
4249}
4250
4251fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4252    proto::UpdateUserChannels {
4253        channel_memberships: channels
4254            .channel_memberships
4255            .iter()
4256            .map(|m| proto::ChannelMembership {
4257                channel_id: m.channel_id.to_proto(),
4258                role: m.role.into(),
4259            })
4260            .collect(),
4261        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4262        observed_channel_message_id: channels.observed_channel_messages.clone(),
4263    }
4264}
4265
4266fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4267    let mut update = proto::UpdateChannels::default();
4268
4269    for channel in channels.channels {
4270        update.channels.push(channel.to_proto());
4271    }
4272
4273    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4274    update.latest_channel_message_ids = channels.latest_channel_messages;
4275
4276    for (channel_id, participants) in channels.channel_participants {
4277        update
4278            .channel_participants
4279            .push(proto::ChannelParticipants {
4280                channel_id: channel_id.to_proto(),
4281                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4282            });
4283    }
4284
4285    for channel in channels.invited_channels {
4286        update.channel_invitations.push(channel.to_proto());
4287    }
4288
4289    update
4290}
4291
4292fn build_initial_contacts_update(
4293    contacts: Vec<db::Contact>,
4294    pool: &ConnectionPool,
4295) -> proto::UpdateContacts {
4296    let mut update = proto::UpdateContacts::default();
4297
4298    for contact in contacts {
4299        match contact {
4300            db::Contact::Accepted { user_id, busy } => {
4301                update.contacts.push(contact_for_user(user_id, busy, pool));
4302            }
4303            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4304            db::Contact::Incoming { user_id } => {
4305                update
4306                    .incoming_requests
4307                    .push(proto::IncomingContactRequest {
4308                        requester_id: user_id.to_proto(),
4309                    })
4310            }
4311        }
4312    }
4313
4314    update
4315}
4316
4317fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4318    proto::Contact {
4319        user_id: user_id.to_proto(),
4320        online: pool.is_user_online(user_id),
4321        busy,
4322    }
4323}
4324
4325fn room_updated(room: &proto::Room, peer: &Peer) {
4326    broadcast(
4327        None,
4328        room.participants
4329            .iter()
4330            .filter_map(|participant| Some(participant.peer_id?.into())),
4331        |peer_id| {
4332            peer.send(
4333                peer_id,
4334                proto::RoomUpdated {
4335                    room: Some(room.clone()),
4336                },
4337            )
4338        },
4339    );
4340}
4341
4342fn channel_updated(
4343    channel: &db::channel::Model,
4344    room: &proto::Room,
4345    peer: &Peer,
4346    pool: &ConnectionPool,
4347) {
4348    let participants = room
4349        .participants
4350        .iter()
4351        .map(|p| p.user_id)
4352        .collect::<Vec<_>>();
4353
4354    broadcast(
4355        None,
4356        pool.channel_connection_ids(channel.root_id())
4357            .filter_map(|(channel_id, role)| {
4358                role.can_see_channel(channel.visibility)
4359                    .then_some(channel_id)
4360            }),
4361        |peer_id| {
4362            peer.send(
4363                peer_id,
4364                proto::UpdateChannels {
4365                    channel_participants: vec![proto::ChannelParticipants {
4366                        channel_id: channel.id.to_proto(),
4367                        participant_user_ids: participants.clone(),
4368                    }],
4369                    ..Default::default()
4370                },
4371            )
4372        },
4373    );
4374}
4375
4376async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4377    let db = session.db().await;
4378
4379    let contacts = db.get_contacts(user_id).await?;
4380    let busy = db.is_user_busy(user_id).await?;
4381
4382    let pool = session.connection_pool().await;
4383    let updated_contact = contact_for_user(user_id, busy, &pool);
4384    for contact in contacts {
4385        if let db::Contact::Accepted {
4386            user_id: contact_user_id,
4387            ..
4388        } = contact
4389        {
4390            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4391                session
4392                    .peer
4393                    .send(
4394                        contact_conn_id,
4395                        proto::UpdateContacts {
4396                            contacts: vec![updated_contact.clone()],
4397                            remove_contacts: Default::default(),
4398                            incoming_requests: Default::default(),
4399                            remove_incoming_requests: Default::default(),
4400                            outgoing_requests: Default::default(),
4401                            remove_outgoing_requests: Default::default(),
4402                        },
4403                    )
4404                    .trace_err();
4405            }
4406        }
4407    }
4408    Ok(())
4409}
4410
/// Removes `connection_id` from its room, if it is in one, and propagates the
/// consequences to other clients and to LiveKit.
///
/// Returns `Ok(())` without further work when `leave_room` yields `None`.
/// Otherwise it does the following, in this order:
/// - runs `project_left` for every project the connection left;
/// - broadcasts the updated room via `room_updated`;
/// - if the room has an associated channel, runs `channel_updated` for it;
/// - sends `CallCanceled` to each user in `canceled_calls_to_user_ids`;
/// - refreshes contact status for this session's user and for those users;
/// - if a LiveKit client is configured, removes this session's user from the
///   LiveKit room, and deletes that room when `leave_room` reported it as
///   deleted.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. An error
/// from `update_user_contacts` returns early, which skips the remaining
/// contact updates and the LiveKit cleanup. Failures of the `CallCanceled`
/// sends and of the LiveKit calls are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Each of these is assigned exactly once inside the `if let` below (the
    // `else` branch returns) and read after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `room` passed to
        // `room_updated` carries the `Default` value of `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoping `pool` to this block drops it before `update_user_contacts`,
    // which calls `session.connection_pool()` again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4484
4485async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4486    let left_channel_buffers = session
4487        .db()
4488        .await
4489        .leave_channel_buffers(session.connection_id)
4490        .await?;
4491
4492    for left_buffer in left_channel_buffers {
4493        channel_buffer_updated(
4494            session.connection_id,
4495            left_buffer.connections,
4496            &proto::UpdateChannelBufferCollaborators {
4497                channel_id: left_buffer.channel_id.to_proto(),
4498                collaborators: left_buffer.collaborators,
4499            },
4500            &session.peer,
4501        );
4502    }
4503
4504    Ok(())
4505}
4506
4507fn project_left(project: &db::LeftProject, session: &Session) {
4508    for connection_id in &project.connection_ids {
4509        if project.should_unshare {
4510            session
4511                .peer
4512                .send(
4513                    *connection_id,
4514                    proto::UnshareProject {
4515                        project_id: project.id.to_proto(),
4516                    },
4517                )
4518                .trace_err();
4519        } else {
4520            session
4521                .peer
4522                .send(
4523                    *connection_id,
4524                    proto::RemoveProjectCollaborator {
4525                        project_id: project.id.to_proto(),
4526                        peer_id: Some(session.connection_id.into()),
4527                    },
4528                )
4529                .trace_err();
4530        }
4531    }
4532}
4533
4534pub trait ResultExt {
4535    type Ok;
4536
4537    fn trace_err(self) -> Option<Self::Ok>;
4538}
4539
4540impl<T, E> ResultExt for Result<T, E>
4541where
4542    E: std::fmt::Debug,
4543{
4544    type Ok = T;
4545
4546    #[track_caller]
4547    fn trace_err(self) -> Option<T> {
4548        match self {
4549            Ok(value) => Some(value),
4550            Err(error) => {
4551                tracing::error!("{:?}", error);
4552                None
4553            }
4554        }
4555    }
4556}