rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   6use crate::{
   7    AppState, Error, Result, auth,
   8    db::{
   9        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  10        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  11        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  12        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15};
  16use anyhow::{Context as _, anyhow, bail};
  17use async_tungstenite::tungstenite::{
  18    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  19};
  20use axum::{
  21    Extension, Router, TypedHeader,
  22    body::Body,
  23    extract::{
  24        ConnectInfo, WebSocketUpgrade,
  25        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use reqwest_client::ReqwestClient;
  38use rpc::proto::split_repository_update;
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40use util::maybe;
  41
  42use futures::{
  43    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  44    stream::FuturesUnordered,
  45};
  46use prometheus::{IntGauge, register_int_gauge};
  47use rpc::{
  48    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        Arc, OnceLock,
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{MutexGuard, Semaphore, watch};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    Instrument,
  75    field::{self},
  76    info_span, instrument,
  77};
  78
/// Reconnection window of 30 seconds.
// NOTE(review): the consumer of this constant is outside this part of the file; presumably it bounds
// how long a dropped connection may take to reconnect — confirm against the caller.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up
// old resources.
/// Delay that `Server::start` sleeps for before it clears resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. NOTE(review): their uses are not visible in this part of the file, and the
// unit of `MAX_MESSAGE_LEN` (bytes vs. characters) should be confirmed at the call site.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the `Session` it arrived on and returns a boxed
/// future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    /// The GeoIP country code for the user.
 137    #[allow(unused)]
 138    geoip_country_code: Option<String>,
 139    system_id: Option<String>,
 140    _executor: Executor,
 141}
 142
 143impl Session {
 144    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 145        #[cfg(test)]
 146        tokio::task::yield_now().await;
 147        let guard = self.db.lock().await;
 148        #[cfg(test)]
 149        tokio::task::yield_now().await;
 150        guard
 151    }
 152
 153    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.connection_pool.lock();
 157        ConnectionPoolGuard {
 158            guard,
 159            _not_send: PhantomData,
 160        }
 161    }
 162
 163    fn is_staff(&self) -> bool {
 164        match &self.principal {
 165            Principal::User(user) => user.admin,
 166            Principal::Impersonated { .. } => true,
 167        }
 168    }
 169
 170    pub async fn has_llm_subscription(
 171        &self,
 172        db: &MutexGuard<'_, DbHandle>,
 173    ) -> anyhow::Result<bool> {
 174        if self.is_staff() {
 175            return Ok(true);
 176        }
 177
 178        let user_id = self.user_id();
 179
 180        Ok(db.has_active_billing_subscription(user_id).await?)
 181    }
 182
 183    pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 184        if self.is_staff() {
 185            return Ok(proto::Plan::ZedPro);
 186        }
 187
 188        let user_id = self.user_id();
 189
 190        let subscription = db.get_active_billing_subscription(user_id).await?;
 191        let subscription_kind = subscription.and_then(|subscription| subscription.kind);
 192
 193        let plan = if let Some(subscription_kind) = subscription_kind {
 194            match subscription_kind {
 195                SubscriptionKind::ZedPro => proto::Plan::ZedPro,
 196                SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
 197                SubscriptionKind::ZedFree => proto::Plan::Free,
 198            }
 199        } else {
 200            proto::Plan::Free
 201        };
 202
 203        Ok(plan)
 204    }
 205
 206    fn user_id(&self) -> UserId {
 207        match &self.principal {
 208            Principal::User(user) => user.id,
 209            Principal::Impersonated { user, .. } => user.id,
 210        }
 211    }
 212
 213    pub fn email(&self) -> Option<String> {
 214        match &self.principal {
 215            Principal::User(user) => user.email_address.clone(),
 216            Principal::Impersonated { user, .. } => user.email_address.clone(),
 217        }
 218    }
 219}
 220
 221impl Debug for Session {
 222    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 223        let mut result = f.debug_struct("Session");
 224        match &self.principal {
 225            Principal::User(user) => {
 226                result.field("user", &user.github_login);
 227            }
 228            Principal::Impersonated { user, admin } => {
 229                result.field("user", &user.github_login);
 230                result.field("impersonator", &admin.github_login);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct DbHandle(Arc<Database>);
 238
 239impl Deref for DbHandle {
 240    type Target = Database;
 241
 242    fn deref(&self) -> &Self::Target {
 243        self.0.as_ref()
 244    }
 245}
 246
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    // Kept behind a mutex so the test-only `reset` can swap in a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message type, keyed by the `TypeId` of the message.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 255
/// A locked view of the [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`.
// NOTE(review): presumably this keeps the lock from being held across an `.await` in a task
// that must be `Send` — confirm the intent.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 260
/// Serializable view of a server: a borrowed peer plus a locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized through its `Deref` target by `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 267
 268pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 269where
 270    S: Serializer,
 271    T: Deref<Target = U>,
 272    U: Serialize,
 273{
 274    Serialize::serialize(value.deref(), serializer)
 275}
 276
 277impl Server {
 278    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 279        let mut server = Self {
 280            id: parking_lot::Mutex::new(id),
 281            peer: Peer::new(id.0 as u32),
 282            app_state: app_state.clone(),
 283            connection_pool: Default::default(),
 284            handlers: Default::default(),
 285            teardown: watch::channel(false).0,
 286        };
 287
 288        server
 289            .add_request_handler(ping)
 290            .add_request_handler(create_room)
 291            .add_request_handler(join_room)
 292            .add_request_handler(rejoin_room)
 293            .add_request_handler(leave_room)
 294            .add_request_handler(set_room_participant_role)
 295            .add_request_handler(call)
 296            .add_request_handler(cancel_call)
 297            .add_message_handler(decline_call)
 298            .add_request_handler(update_participant_location)
 299            .add_request_handler(share_project)
 300            .add_message_handler(unshare_project)
 301            .add_request_handler(join_project)
 302            .add_message_handler(leave_project)
 303            .add_request_handler(update_project)
 304            .add_request_handler(update_worktree)
 305            .add_request_handler(update_repository)
 306            .add_request_handler(remove_repository)
 307            .add_message_handler(start_language_server)
 308            .add_message_handler(update_language_server)
 309            .add_message_handler(update_diagnostic_summary)
 310            .add_message_handler(update_worktree_settings)
 311            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 313            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 315            .add_request_handler(forward_find_search_candidates_request)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 318            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 319            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 320            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 321            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 322            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 323            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 324            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 325            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 326            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 327            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 328            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 329            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 330            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 331            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 332            .add_request_handler(
 333                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 334            )
 335            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 336            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 337            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 338            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 339            .add_request_handler(
 340                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 341            )
 342            .add_request_handler(
 343                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 344            )
 345            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 346            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 347            .add_request_handler(
 348                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 349            )
 350            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 355            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 356            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 357            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 358            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 359            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 360            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 362            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 363            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 366            .add_request_handler(
 367                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 368            )
 369            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 370            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 371            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 372            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 373            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 374            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 375            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 376            .add_message_handler(create_buffer_for_peer)
 377            .add_request_handler(update_buffer)
 378            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 379            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 380            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 381            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 382            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 384            .add_request_handler(get_users)
 385            .add_request_handler(fuzzy_search_users)
 386            .add_request_handler(request_contact)
 387            .add_request_handler(remove_contact)
 388            .add_request_handler(respond_to_contact_request)
 389            .add_message_handler(subscribe_to_channels)
 390            .add_request_handler(create_channel)
 391            .add_request_handler(delete_channel)
 392            .add_request_handler(invite_channel_member)
 393            .add_request_handler(remove_channel_member)
 394            .add_request_handler(set_channel_member_role)
 395            .add_request_handler(set_channel_visibility)
 396            .add_request_handler(rename_channel)
 397            .add_request_handler(join_channel_buffer)
 398            .add_request_handler(leave_channel_buffer)
 399            .add_message_handler(update_channel_buffer)
 400            .add_request_handler(rejoin_channel_buffers)
 401            .add_request_handler(get_channel_members)
 402            .add_request_handler(respond_to_channel_invite)
 403            .add_request_handler(join_channel)
 404            .add_request_handler(join_channel_chat)
 405            .add_message_handler(leave_channel_chat)
 406            .add_request_handler(send_channel_message)
 407            .add_request_handler(remove_channel_message)
 408            .add_request_handler(update_channel_message)
 409            .add_request_handler(get_channel_messages)
 410            .add_request_handler(get_channel_messages_by_id)
 411            .add_request_handler(get_notifications)
 412            .add_request_handler(mark_notification_as_read)
 413            .add_request_handler(move_channel)
 414            .add_request_handler(follow)
 415            .add_message_handler(unfollow)
 416            .add_message_handler(update_followers)
 417            .add_request_handler(get_private_user_info)
 418            .add_request_handler(get_llm_api_token)
 419            .add_request_handler(accept_terms_of_service)
 420            .add_message_handler(acknowledge_channel_message)
 421            .add_message_handler(acknowledge_buffer_version)
 422            .add_request_handler(get_supermaven_api_key)
 423            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 424            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 425            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 426            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 427            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 428            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 429            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 430            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 431            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 432            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 433            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 434            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 435            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 436            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 437            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 438            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 439            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 440            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 441            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 442            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 443            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 444            .add_message_handler(update_context);
 445
 446        Arc::new(server)
 447    }
 448
 449    pub async fn start(&self) -> Result<()> {
 450        let server_id = *self.id.lock();
 451        let app_state = self.app_state.clone();
 452        let peer = self.peer.clone();
 453        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 454        let pool = self.connection_pool.clone();
 455        let livekit_client = self.app_state.livekit_client.clone();
 456
 457        let span = info_span!("start server");
 458        self.app_state.executor.spawn_detached(
 459            async move {
 460                tracing::info!("waiting for cleanup timeout");
 461                timeout.await;
 462                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 463                if let Some((room_ids, channel_ids)) = app_state
 464                    .db
 465                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 466                    .await
 467                    .trace_err()
 468                {
 469                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 470                    tracing::info!(
 471                        stale_channel_buffer_count = channel_ids.len(),
 472                        "retrieved stale channel buffers"
 473                    );
 474
 475                    for channel_id in channel_ids {
 476                        if let Some(refreshed_channel_buffer) = app_state
 477                            .db
 478                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 479                            .await
 480                            .trace_err()
 481                        {
 482                            for connection_id in refreshed_channel_buffer.connection_ids {
 483                                peer.send(
 484                                    connection_id,
 485                                    proto::UpdateChannelBufferCollaborators {
 486                                        channel_id: channel_id.to_proto(),
 487                                        collaborators: refreshed_channel_buffer
 488                                            .collaborators
 489                                            .clone(),
 490                                    },
 491                                )
 492                                .trace_err();
 493                            }
 494                        }
 495                    }
 496
 497                    for room_id in room_ids {
 498                        let mut contacts_to_update = HashSet::default();
 499                        let mut canceled_calls_to_user_ids = Vec::new();
 500                        let mut livekit_room = String::new();
 501                        let mut delete_livekit_room = false;
 502
 503                        if let Some(mut refreshed_room) = app_state
 504                            .db
 505                            .clear_stale_room_participants(room_id, server_id)
 506                            .await
 507                            .trace_err()
 508                        {
 509                            tracing::info!(
 510                                room_id = room_id.0,
 511                                new_participant_count = refreshed_room.room.participants.len(),
 512                                "refreshed room"
 513                            );
 514                            room_updated(&refreshed_room.room, &peer);
 515                            if let Some(channel) = refreshed_room.channel.as_ref() {
 516                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 517                            }
 518                            contacts_to_update
 519                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 520                            contacts_to_update
 521                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 522                            canceled_calls_to_user_ids =
 523                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 524                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 525                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 526                        }
 527
 528                        {
 529                            let pool = pool.lock();
 530                            for canceled_user_id in canceled_calls_to_user_ids {
 531                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 532                                    peer.send(
 533                                        connection_id,
 534                                        proto::CallCanceled {
 535                                            room_id: room_id.to_proto(),
 536                                        },
 537                                    )
 538                                    .trace_err();
 539                                }
 540                            }
 541                        }
 542
 543                        for user_id in contacts_to_update {
 544                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 545                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 546                            if let Some((busy, contacts)) = busy.zip(contacts) {
 547                                let pool = pool.lock();
 548                                let updated_contact = contact_for_user(user_id, busy, &pool);
 549                                for contact in contacts {
 550                                    if let db::Contact::Accepted {
 551                                        user_id: contact_user_id,
 552                                        ..
 553                                    } = contact
 554                                    {
 555                                        for contact_conn_id in
 556                                            pool.user_connection_ids(contact_user_id)
 557                                        {
 558                                            peer.send(
 559                                                contact_conn_id,
 560                                                proto::UpdateContacts {
 561                                                    contacts: vec![updated_contact.clone()],
 562                                                    remove_contacts: Default::default(),
 563                                                    incoming_requests: Default::default(),
 564                                                    remove_incoming_requests: Default::default(),
 565                                                    outgoing_requests: Default::default(),
 566                                                    remove_outgoing_requests: Default::default(),
 567                                                },
 568                                            )
 569                                            .trace_err();
 570                                        }
 571                                    }
 572                                }
 573                            }
 574                        }
 575
 576                        if let Some(live_kit) = livekit_client.as_ref() {
 577                            if delete_livekit_room {
 578                                live_kit.delete_room(livekit_room).await.trace_err();
 579                            }
 580                        }
 581                    }
 582                }
 583
 584                app_state
 585                    .db
 586                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 587                    .await
 588                    .trace_err();
 589            }
 590            .instrument(span),
 591        );
 592        Ok(())
 593    }
 594
 595    pub fn teardown(&self) {
 596        self.peer.teardown();
 597        self.connection_pool.lock().reset();
 598        let _ = self.teardown.send(true);
 599    }
 600
 601    #[cfg(test)]
 602    pub fn reset(&self, id: ServerId) {
 603        self.teardown();
 604        *self.id.lock() = id;
 605        self.peer.reset(id.0 as u32);
 606        let _ = self.teardown.send(false);
 607    }
 608
 609    #[cfg(test)]
 610    pub fn id(&self) -> ServerId {
 611        *self.id.lock()
 612    }
 613
 614    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 615    where
 616        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 617        Fut: 'static + Send + Future<Output = Result<()>>,
 618        M: EnvelopedMessage,
 619    {
 620        let prev_handler = self.handlers.insert(
 621            TypeId::of::<M>(),
 622            Box::new(move |envelope, session| {
 623                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 624                let received_at = envelope.received_at;
 625                tracing::info!("message received");
 626                let start_time = Instant::now();
 627                let future = (handler)(*envelope, session);
 628                async move {
 629                    let result = future.await;
 630                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 631                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 632                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 633                    let payload_type = M::NAME;
 634
 635                    match result {
 636                        Err(error) => {
 637                            tracing::error!(
 638                                ?error,
 639                                total_duration_ms,
 640                                processing_duration_ms,
 641                                queue_duration_ms,
 642                                payload_type,
 643                                "error handling message"
 644                            )
 645                        }
 646                        Ok(()) => tracing::info!(
 647                            total_duration_ms,
 648                            processing_duration_ms,
 649                            queue_duration_ms,
 650                            "finished handling message"
 651                        ),
 652                    }
 653                }
 654                .boxed()
 655            }),
 656        );
 657        if prev_handler.is_some() {
 658            panic!("registered a handler for the same message twice");
 659        }
 660        self
 661    }
 662
 663    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 664    where
 665        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 666        Fut: 'static + Send + Future<Output = Result<()>>,
 667        M: EnvelopedMessage,
 668    {
 669        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 670        self
 671    }
 672
 673    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 674    where
 675        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 676        Fut: Send + Future<Output = Result<()>>,
 677        M: RequestMessage,
 678    {
 679        let handler = Arc::new(handler);
 680        self.add_handler(move |envelope, session| {
 681            let receipt = envelope.receipt();
 682            let handler = handler.clone();
 683            async move {
 684                let peer = session.peer.clone();
 685                let responded = Arc::new(AtomicBool::default());
 686                let response = Response {
 687                    peer: peer.clone(),
 688                    responded: responded.clone(),
 689                    receipt,
 690                };
 691                match (handler)(envelope.payload, response, session).await {
 692                    Ok(()) => {
 693                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 694                            Ok(())
 695                        } else {
 696                            Err(anyhow!("handler did not send a response"))?
 697                        }
 698                    }
 699                    Err(error) => {
 700                        let proto_err = match &error {
 701                            Error::Internal(err) => err.to_proto(),
 702                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 703                        };
 704                        peer.respond_with_error(receipt, proto_err)?;
 705                        Err(error)
 706                    }
 707                }
 708            }
 709        })
 710    }
 711
    /// Builds the future that drives a single client connection.
    ///
    /// The tracing span is created and populated synchronously; everything else
    /// happens when the returned future is awaited:
    /// 1. return immediately if the teardown flag is already set,
    /// 2. register `connection` with the peer and build the `Session`,
    /// 3. send the initial client update (returning early if that fails),
    /// 4. dispatch incoming messages to the registered handlers until the I/O
    ///    future completes or the incoming stream ends,
    /// 5. sign the connection out via `connection_lost`.
    ///
    /// If teardown is signalled during step 4 the future returns right away and
    /// step 5 is skipped.
    ///
    /// `send_connection_id` is forwarded to `send_initial_client_update`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared as `field::Empty` are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been flagged.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only constructed when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Bounds the number of in-flight handlers for this connection: a permit is
            // acquired before each message is read and dropped when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    // Teardown: return without running the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive pending foreground handlers; their output is `()`.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before moving it into `instrument` below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drops any foreground handlers that have not completed yet.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 859
 860    async fn send_initial_client_update(
 861        &self,
 862        connection_id: ConnectionId,
 863        principal: &Principal,
 864        zed_version: ZedVersion,
 865        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 866        session: &Session,
 867    ) -> Result<()> {
 868        self.peer.send(
 869            connection_id,
 870            proto::Hello {
 871                peer_id: Some(connection_id.into()),
 872            },
 873        )?;
 874        tracing::info!("sent hello message");
 875        if let Some(send_connection_id) = send_connection_id.take() {
 876            let _ = send_connection_id.send(connection_id);
 877        }
 878
 879        match principal {
 880            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 881                if !user.connected_once {
 882                    self.peer.send(connection_id, proto::ShowContacts {})?;
 883                    self.app_state
 884                        .db
 885                        .set_user_connected_once(user.id, true)
 886                        .await?;
 887                }
 888
 889                update_user_plan(user.id, session).await?;
 890
 891                let contacts = self.app_state.db.get_contacts(user.id).await?;
 892
 893                {
 894                    let mut pool = self.connection_pool.lock();
 895                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 896                    self.peer.send(
 897                        connection_id,
 898                        build_initial_contacts_update(contacts, &pool),
 899                    )?;
 900                }
 901
 902                if should_auto_subscribe_to_channels(zed_version) {
 903                    subscribe_user_to_channels(user.id, session).await?;
 904                }
 905
 906                if let Some(incoming_call) =
 907                    self.app_state.db.incoming_call_for_user(user.id).await?
 908                {
 909                    self.peer.send(connection_id, incoming_call)?;
 910                }
 911
 912                update_user_contacts(user.id, session).await?;
 913            }
 914        }
 915
 916        Ok(())
 917    }
 918
 919    pub async fn invite_code_redeemed(
 920        self: &Arc<Self>,
 921        inviter_id: UserId,
 922        invitee_id: UserId,
 923    ) -> Result<()> {
 924        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 925            if let Some(code) = &user.invite_code {
 926                let pool = self.connection_pool.lock();
 927                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 928                for connection_id in pool.user_connection_ids(inviter_id) {
 929                    self.peer.send(
 930                        connection_id,
 931                        proto::UpdateContacts {
 932                            contacts: vec![invitee_contact.clone()],
 933                            ..Default::default()
 934                        },
 935                    )?;
 936                    self.peer.send(
 937                        connection_id,
 938                        proto::UpdateInviteInfo {
 939                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 940                            count: user.invite_count as u32,
 941                        },
 942                    )?;
 943                }
 944            }
 945        }
 946        Ok(())
 947    }
 948
 949    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 950        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 951            if let Some(invite_code) = &user.invite_code {
 952                let pool = self.connection_pool.lock();
 953                for connection_id in pool.user_connection_ids(user_id) {
 954                    self.peer.send(
 955                        connection_id,
 956                        proto::UpdateInviteInfo {
 957                            url: format!(
 958                                "{}{}",
 959                                self.app_state.config.invite_link_prefix, invite_code
 960                            ),
 961                            count: user.invite_count as u32,
 962                        },
 963                    )?;
 964                }
 965            }
 966        }
 967        Ok(())
 968    }
 969
 970    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 971        let pool = self.connection_pool.lock();
 972        for connection_id in pool.user_connection_ids(user_id) {
 973            self.peer
 974                .send(connection_id, proto::RefreshLlmToken {})
 975                .trace_err();
 976        }
 977    }
 978
 979    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 980        ServerSnapshot {
 981            connection_pool: ConnectionPoolGuard {
 982                guard: self.connection_pool.lock(),
 983                _not_send: PhantomData,
 984            },
 985            peer: &self.peer,
 986        }
 987    }
 988}
 989
 990impl Deref for ConnectionPoolGuard<'_> {
 991    type Target = ConnectionPool;
 992
 993    fn deref(&self) -> &Self::Target {
 994        &self.guard
 995    }
 996}
 997
 998impl DerefMut for ConnectionPoolGuard<'_> {
 999    fn deref_mut(&mut self) -> &mut Self::Target {
1000        &mut self.guard
1001    }
1002}
1003
/// In test builds, checks the pool's invariants whenever a guard is released.
/// In other builds the body is empty.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1010
1011fn broadcast<F>(
1012    sender_id: Option<ConnectionId>,
1013    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1014    mut f: F,
1015) where
1016    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1017{
1018    for receiver_id in receiver_ids {
1019        if Some(receiver_id) != sender_id {
1020            if let Err(error) = f(receiver_id) {
1021                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1022            }
1023        }
1024    }
1025}
1026
1027pub struct ProtocolVersion(u32);
1028
1029impl Header for ProtocolVersion {
1030    fn name() -> &'static HeaderName {
1031        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1033    }
1034
1035    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036    where
1037        Self: Sized,
1038        I: Iterator<Item = &'i axum::http::HeaderValue>,
1039    {
1040        let version = values
1041            .next()
1042            .ok_or_else(axum::headers::Error::invalid)?
1043            .to_str()
1044            .map_err(|_| axum::headers::Error::invalid())?
1045            .parse()
1046            .map_err(|_| axum::headers::Error::invalid())?;
1047        Ok(Self(version))
1048    }
1049
1050    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051        values.extend([self.0.to_string().parse().unwrap()]);
1052    }
1053}
1054
1055pub struct AppVersionHeader(SemanticVersion);
1056impl Header for AppVersionHeader {
1057    fn name() -> &'static HeaderName {
1058        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1059        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1060    }
1061
1062    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1063    where
1064        Self: Sized,
1065        I: Iterator<Item = &'i axum::http::HeaderValue>,
1066    {
1067        let version = values
1068            .next()
1069            .ok_or_else(axum::headers::Error::invalid)?
1070            .to_str()
1071            .map_err(|_| axum::headers::Error::invalid())?
1072            .parse()
1073            .map_err(|_| axum::headers::Error::invalid())?;
1074        Ok(Self(version))
1075    }
1076
1077    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1078        values.extend([self.0.to_string().parse().unwrap()]);
1079    }
1080}
1081
1082pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1083    Router::new()
1084        .route("/rpc", get(handle_websocket_request))
1085        .layer(
1086            ServiceBuilder::new()
1087                .layer(Extension(server.app_state.clone()))
1088                .layer(middleware::from_fn(auth::validate_header)),
1089        )
1090        .route("/metrics", get(handle_metrics))
1091        .layer(Extension(server))
1092}
1093
1094pub async fn handle_websocket_request(
1095    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1096    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1097    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1098    Extension(server): Extension<Arc<Server>>,
1099    Extension(principal): Extension<Principal>,
1100    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1101    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1102    ws: WebSocketUpgrade,
1103) -> axum::response::Response {
1104    if protocol_version != rpc::PROTOCOL_VERSION {
1105        return (
1106            StatusCode::UPGRADE_REQUIRED,
1107            "client must be upgraded".to_string(),
1108        )
1109            .into_response();
1110    }
1111
1112    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1113        return (
1114            StatusCode::UPGRADE_REQUIRED,
1115            "no version header found".to_string(),
1116        )
1117            .into_response();
1118    };
1119
1120    if !version.can_collaborate() {
1121        return (
1122            StatusCode::UPGRADE_REQUIRED,
1123            "client must be upgraded".to_string(),
1124        )
1125            .into_response();
1126    }
1127
1128    let socket_address = socket_address.to_string();
1129    ws.on_upgrade(move |socket| {
1130        let socket = socket
1131            .map_ok(to_tungstenite_message)
1132            .err_into()
1133            .with(|message| async move { to_axum_message(message) });
1134        let connection = Connection::new(Box::pin(socket));
1135        async move {
1136            server
1137                .handle_connection(
1138                    connection,
1139                    socket_address,
1140                    principal,
1141                    version,
1142                    country_code_header.map(|header| header.to_string()),
1143                    system_id_header.map(|header| header.to_string()),
1144                    None,
1145                    Executor::Production,
1146                )
1147                .await;
1148        }
1149    })
1150}
1151
1152pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1153    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1154    let connections_metric = CONNECTIONS_METRIC
1155        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1156
1157    let connections = server
1158        .connection_pool
1159        .lock()
1160        .connections()
1161        .filter(|connection| !connection.admin)
1162        .count();
1163    connections_metric.set(connections as _);
1164
1165    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1166    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1167        register_int_gauge!(
1168            "shared_projects",
1169            "number of open projects with one or more guests"
1170        )
1171        .unwrap()
1172    });
1173
1174    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1175    shared_projects_metric.set(shared_projects as _);
1176
1177    let encoder = prometheus::TextEncoder::new();
1178    let metric_families = prometheus::gather();
1179    let encoded_metrics = encoder
1180        .encode_to_string(&metric_families)
1181        .map_err(|err| anyhow!("{}", err))?;
1182    Ok(encoded_metrics)
1183}
1184
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool
/// (propagating a failure) and records the loss in the database (a failure
/// there is passed to `trace_err` and not propagated).
///
/// It then waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel
///   buffers, a pending call is declined if the user has no remaining online
///   connection, and the user's contacts are refreshed;
/// - teardown is signalled: no further cleanup is performed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // The connection was not re-established within the timeout: release everything it held.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a call when none of the user's connections remain online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1231
1232/// Acknowledges a ping from a client, used to keep the connection alive.
1233async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1234    response.send(proto::Ack {})?;
1235    Ok(())
1236}
1237
1238/// Creates a new room for calling (outside of channels)
1239async fn create_room(
1240    _request: proto::CreateRoom,
1241    response: Response<proto::CreateRoom>,
1242    session: Session,
1243) -> Result<()> {
1244    let livekit_room = nanoid::nanoid!(30);
1245
1246    let live_kit_connection_info = util::maybe!(async {
1247        let live_kit = session.app_state.livekit_client.as_ref();
1248        let live_kit = live_kit?;
1249        let user_id = session.user_id().to_string();
1250
1251        let token = live_kit
1252            .room_token(&livekit_room, &user_id.to_string())
1253            .trace_err()?;
1254
1255        Some(proto::LiveKitConnectionInfo {
1256            server_url: live_kit.url().into(),
1257            token,
1258            can_publish: true,
1259        })
1260    })
1261    .await;
1262
1263    let room = session
1264        .db()
1265        .await
1266        .create_room(session.user_id(), session.connection_id, &livekit_room)
1267        .await?;
1268
1269    response.send(proto::CreateRoomResponse {
1270        room: Some(room.clone()),
1271        live_kit_connection_info,
1272    })?;
1273
1274    update_user_contacts(session.user_id(), &session).await?;
1275    Ok(())
1276}
1277
1278/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1279async fn join_room(
1280    request: proto::JoinRoom,
1281    response: Response<proto::JoinRoom>,
1282    session: Session,
1283) -> Result<()> {
1284    let room_id = RoomId::from_proto(request.id);
1285
1286    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1287
1288    if let Some(channel_id) = channel_id {
1289        return join_channel_internal(channel_id, Box::new(response), session).await;
1290    }
1291
1292    let joined_room = {
1293        let room = session
1294            .db()
1295            .await
1296            .join_room(room_id, session.user_id(), session.connection_id)
1297            .await?;
1298        room_updated(&room.room, &session.peer);
1299        room.into_inner()
1300    };
1301
1302    for connection_id in session
1303        .connection_pool()
1304        .await
1305        .user_connection_ids(session.user_id())
1306    {
1307        session
1308            .peer
1309            .send(
1310                connection_id,
1311                proto::CallCanceled {
1312                    room_id: room_id.to_proto(),
1313                },
1314            )
1315            .trace_err();
1316    }
1317
1318    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1319    {
1320        live_kit
1321            .room_token(
1322                &joined_room.room.livekit_room,
1323                &session.user_id().to_string(),
1324            )
1325            .trace_err()
1326            .map(|token| proto::LiveKitConnectionInfo {
1327                server_url: live_kit.url().into(),
1328                token,
1329                can_publish: true,
1330            })
1331    } else {
1332        None
1333    };
1334
1335    response.send(proto::JoinRoomResponse {
1336        room: Some(joined_room.room),
1337        channel_id: None,
1338        live_kit_connection_info,
1339    })?;
1340
1341    update_user_contacts(session.user_id(), &session).await?;
1342    Ok(())
1343}
1344
1345/// Rejoin room is used to reconnect to a room after connection errors.
1346async fn rejoin_room(
1347    request: proto::RejoinRoom,
1348    response: Response<proto::RejoinRoom>,
1349    session: Session,
1350) -> Result<()> {
1351    let room;
1352    let channel;
1353    {
1354        let mut rejoined_room = session
1355            .db()
1356            .await
1357            .rejoin_room(request, session.user_id(), session.connection_id)
1358            .await?;
1359
1360        response.send(proto::RejoinRoomResponse {
1361            room: Some(rejoined_room.room.clone()),
1362            reshared_projects: rejoined_room
1363                .reshared_projects
1364                .iter()
1365                .map(|project| proto::ResharedProject {
1366                    id: project.id.to_proto(),
1367                    collaborators: project
1368                        .collaborators
1369                        .iter()
1370                        .map(|collaborator| collaborator.to_proto())
1371                        .collect(),
1372                })
1373                .collect(),
1374            rejoined_projects: rejoined_room
1375                .rejoined_projects
1376                .iter()
1377                .map(|rejoined_project| rejoined_project.to_proto())
1378                .collect(),
1379        })?;
1380        room_updated(&rejoined_room.room, &session.peer);
1381
1382        for project in &rejoined_room.reshared_projects {
1383            for collaborator in &project.collaborators {
1384                session
1385                    .peer
1386                    .send(
1387                        collaborator.connection_id,
1388                        proto::UpdateProjectCollaborator {
1389                            project_id: project.id.to_proto(),
1390                            old_peer_id: Some(project.old_connection_id.into()),
1391                            new_peer_id: Some(session.connection_id.into()),
1392                        },
1393                    )
1394                    .trace_err();
1395            }
1396
1397            broadcast(
1398                Some(session.connection_id),
1399                project
1400                    .collaborators
1401                    .iter()
1402                    .map(|collaborator| collaborator.connection_id),
1403                |connection_id| {
1404                    session.peer.forward_send(
1405                        session.connection_id,
1406                        connection_id,
1407                        proto::UpdateProject {
1408                            project_id: project.id.to_proto(),
1409                            worktrees: project.worktrees.clone(),
1410                        },
1411                    )
1412                },
1413            );
1414        }
1415
1416        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1417
1418        let rejoined_room = rejoined_room.into_inner();
1419
1420        room = rejoined_room.room;
1421        channel = rejoined_room.channel;
1422    }
1423
1424    if let Some(channel) = channel {
1425        channel_updated(
1426            &channel,
1427            &room,
1428            &session.peer,
1429            &*session.connection_pool().await,
1430        );
1431    }
1432
1433    update_user_contacts(session.user_id(), &session).await?;
1434    Ok(())
1435}
1436
1437fn notify_rejoined_projects(
1438    rejoined_projects: &mut Vec<RejoinedProject>,
1439    session: &Session,
1440) -> Result<()> {
1441    for project in rejoined_projects.iter() {
1442        for collaborator in &project.collaborators {
1443            session
1444                .peer
1445                .send(
1446                    collaborator.connection_id,
1447                    proto::UpdateProjectCollaborator {
1448                        project_id: project.id.to_proto(),
1449                        old_peer_id: Some(project.old_connection_id.into()),
1450                        new_peer_id: Some(session.connection_id.into()),
1451                    },
1452                )
1453                .trace_err();
1454        }
1455    }
1456
1457    for project in rejoined_projects {
1458        for worktree in mem::take(&mut project.worktrees) {
1459            // Stream this worktree's entries.
1460            let message = proto::UpdateWorktree {
1461                project_id: project.id.to_proto(),
1462                worktree_id: worktree.id,
1463                abs_path: worktree.abs_path.clone(),
1464                root_name: worktree.root_name,
1465                updated_entries: worktree.updated_entries,
1466                removed_entries: worktree.removed_entries,
1467                scan_id: worktree.scan_id,
1468                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1469                updated_repositories: worktree.updated_repositories,
1470                removed_repositories: worktree.removed_repositories,
1471            };
1472            for update in proto::split_worktree_update(message) {
1473                session.peer.send(session.connection_id, update)?;
1474            }
1475
1476            // Stream this worktree's diagnostics.
1477            for summary in worktree.diagnostic_summaries {
1478                session.peer.send(
1479                    session.connection_id,
1480                    proto::UpdateDiagnosticSummary {
1481                        project_id: project.id.to_proto(),
1482                        worktree_id: worktree.id,
1483                        summary: Some(summary),
1484                    },
1485                )?;
1486            }
1487
1488            for settings_file in worktree.settings_files {
1489                session.peer.send(
1490                    session.connection_id,
1491                    proto::UpdateWorktreeSettings {
1492                        project_id: project.id.to_proto(),
1493                        worktree_id: worktree.id,
1494                        path: settings_file.path,
1495                        content: Some(settings_file.content),
1496                        kind: Some(settings_file.kind.to_proto().into()),
1497                    },
1498                )?;
1499            }
1500        }
1501
1502        for repository in mem::take(&mut project.updated_repositories) {
1503            for update in split_repository_update(repository) {
1504                session.peer.send(session.connection_id, update)?;
1505            }
1506        }
1507
1508        for id in mem::take(&mut project.removed_repositories) {
1509            session.peer.send(
1510                session.connection_id,
1511                proto::RemoveRepository {
1512                    project_id: project.id.to_proto(),
1513                    id,
1514                },
1515            )?;
1516        }
1517    }
1518
1519    Ok(())
1520}
1521
1522/// leave room disconnects from the room.
1523async fn leave_room(
1524    _: proto::LeaveRoom,
1525    response: Response<proto::LeaveRoom>,
1526    session: Session,
1527) -> Result<()> {
1528    leave_room_for_session(&session, session.connection_id).await?;
1529    response.send(proto::Ack {})?;
1530    Ok(())
1531}
1532
1533/// Updates the permissions of someone else in the room.
1534async fn set_room_participant_role(
1535    request: proto::SetRoomParticipantRole,
1536    response: Response<proto::SetRoomParticipantRole>,
1537    session: Session,
1538) -> Result<()> {
1539    let user_id = UserId::from_proto(request.user_id);
1540    let role = ChannelRole::from(request.role());
1541
1542    let (livekit_room, can_publish) = {
1543        let room = session
1544            .db()
1545            .await
1546            .set_room_participant_role(
1547                session.user_id(),
1548                RoomId::from_proto(request.room_id),
1549                user_id,
1550                role,
1551            )
1552            .await?;
1553
1554        let livekit_room = room.livekit_room.clone();
1555        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1556        room_updated(&room, &session.peer);
1557        (livekit_room, can_publish)
1558    };
1559
1560    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1561        live_kit
1562            .update_participant(
1563                livekit_room.clone(),
1564                request.user_id.to_string(),
1565                livekit_api::proto::ParticipantPermission {
1566                    can_subscribe: true,
1567                    can_publish,
1568                    can_publish_data: can_publish,
1569                    hidden: false,
1570                    recorder: false,
1571                },
1572            )
1573            .await
1574            .trace_err();
1575    }
1576
1577    response.send(proto::Ack {})?;
1578    Ok(())
1579}
1580
1581/// Call someone else into the current room
1582async fn call(
1583    request: proto::Call,
1584    response: Response<proto::Call>,
1585    session: Session,
1586) -> Result<()> {
1587    let room_id = RoomId::from_proto(request.room_id);
1588    let calling_user_id = session.user_id();
1589    let calling_connection_id = session.connection_id;
1590    let called_user_id = UserId::from_proto(request.called_user_id);
1591    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1592    if !session
1593        .db()
1594        .await
1595        .has_contact(calling_user_id, called_user_id)
1596        .await?
1597    {
1598        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1599    }
1600
1601    let incoming_call = {
1602        let (room, incoming_call) = &mut *session
1603            .db()
1604            .await
1605            .call(
1606                room_id,
1607                calling_user_id,
1608                calling_connection_id,
1609                called_user_id,
1610                initial_project_id,
1611            )
1612            .await?;
1613        room_updated(room, &session.peer);
1614        mem::take(incoming_call)
1615    };
1616    update_user_contacts(called_user_id, &session).await?;
1617
1618    let mut calls = session
1619        .connection_pool()
1620        .await
1621        .user_connection_ids(called_user_id)
1622        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1623        .collect::<FuturesUnordered<_>>();
1624
1625    while let Some(call_response) = calls.next().await {
1626        match call_response.as_ref() {
1627            Ok(_) => {
1628                response.send(proto::Ack {})?;
1629                return Ok(());
1630            }
1631            Err(_) => {
1632                call_response.trace_err();
1633            }
1634        }
1635    }
1636
1637    {
1638        let room = session
1639            .db()
1640            .await
1641            .call_failed(room_id, called_user_id)
1642            .await?;
1643        room_updated(&room, &session.peer);
1644    }
1645    update_user_contacts(called_user_id, &session).await?;
1646
1647    Err(anyhow!("failed to ring user"))?
1648}
1649
1650/// Cancel an outgoing call.
1651async fn cancel_call(
1652    request: proto::CancelCall,
1653    response: Response<proto::CancelCall>,
1654    session: Session,
1655) -> Result<()> {
1656    let called_user_id = UserId::from_proto(request.called_user_id);
1657    let room_id = RoomId::from_proto(request.room_id);
1658    {
1659        let room = session
1660            .db()
1661            .await
1662            .cancel_call(room_id, session.connection_id, called_user_id)
1663            .await?;
1664        room_updated(&room, &session.peer);
1665    }
1666
1667    for connection_id in session
1668        .connection_pool()
1669        .await
1670        .user_connection_ids(called_user_id)
1671    {
1672        session
1673            .peer
1674            .send(
1675                connection_id,
1676                proto::CallCanceled {
1677                    room_id: room_id.to_proto(),
1678                },
1679            )
1680            .trace_err();
1681    }
1682    response.send(proto::Ack {})?;
1683
1684    update_user_contacts(called_user_id, &session).await?;
1685    Ok(())
1686}
1687
1688/// Decline an incoming call.
1689async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1690    let room_id = RoomId::from_proto(message.room_id);
1691    {
1692        let room = session
1693            .db()
1694            .await
1695            .decline_call(Some(room_id), session.user_id())
1696            .await?
1697            .ok_or_else(|| anyhow!("failed to decline call"))?;
1698        room_updated(&room, &session.peer);
1699    }
1700
1701    for connection_id in session
1702        .connection_pool()
1703        .await
1704        .user_connection_ids(session.user_id())
1705    {
1706        session
1707            .peer
1708            .send(
1709                connection_id,
1710                proto::CallCanceled {
1711                    room_id: room_id.to_proto(),
1712                },
1713            )
1714            .trace_err();
1715    }
1716    update_user_contacts(session.user_id(), &session).await?;
1717    Ok(())
1718}
1719
1720/// Updates other participants in the room with your current location.
1721async fn update_participant_location(
1722    request: proto::UpdateParticipantLocation,
1723    response: Response<proto::UpdateParticipantLocation>,
1724    session: Session,
1725) -> Result<()> {
1726    let room_id = RoomId::from_proto(request.room_id);
1727    let location = request
1728        .location
1729        .ok_or_else(|| anyhow!("invalid location"))?;
1730
1731    let db = session.db().await;
1732    let room = db
1733        .update_room_participant_location(room_id, session.connection_id, location)
1734        .await?;
1735
1736    room_updated(&room, &session.peer);
1737    response.send(proto::Ack {})?;
1738    Ok(())
1739}
1740
1741/// Share a project into the room.
1742async fn share_project(
1743    request: proto::ShareProject,
1744    response: Response<proto::ShareProject>,
1745    session: Session,
1746) -> Result<()> {
1747    let (project_id, room) = &*session
1748        .db()
1749        .await
1750        .share_project(
1751            RoomId::from_proto(request.room_id),
1752            session.connection_id,
1753            &request.worktrees,
1754            request.is_ssh_project,
1755        )
1756        .await?;
1757    response.send(proto::ShareProjectResponse {
1758        project_id: project_id.to_proto(),
1759    })?;
1760    room_updated(room, &session.peer);
1761
1762    Ok(())
1763}
1764
1765/// Unshare a project from the room.
1766async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1767    let project_id = ProjectId::from_proto(message.project_id);
1768    unshare_project_internal(project_id, session.connection_id, &session).await
1769}
1770
1771async fn unshare_project_internal(
1772    project_id: ProjectId,
1773    connection_id: ConnectionId,
1774    session: &Session,
1775) -> Result<()> {
1776    let delete = {
1777        let room_guard = session
1778            .db()
1779            .await
1780            .unshare_project(project_id, connection_id)
1781            .await?;
1782
1783        let (delete, room, guest_connection_ids) = &*room_guard;
1784
1785        let message = proto::UnshareProject {
1786            project_id: project_id.to_proto(),
1787        };
1788
1789        broadcast(
1790            Some(connection_id),
1791            guest_connection_ids.iter().copied(),
1792            |conn_id| session.peer.send(conn_id, message.clone()),
1793        );
1794        if let Some(room) = room {
1795            room_updated(room, &session.peer);
1796        }
1797
1798        *delete
1799    };
1800
1801    if delete {
1802        let db = session.db().await;
1803        db.delete_project(project_id).await?;
1804    }
1805
1806    Ok(())
1807}
1808
1809/// Join someone elses shared project.
1810async fn join_project(
1811    request: proto::JoinProject,
1812    response: Response<proto::JoinProject>,
1813    session: Session,
1814) -> Result<()> {
1815    let project_id = ProjectId::from_proto(request.project_id);
1816
1817    tracing::info!(%project_id, "join project");
1818
1819    let db = session.db().await;
1820    let (project, replica_id) = &mut *db
1821        .join_project(project_id, session.connection_id, session.user_id())
1822        .await?;
1823    drop(db);
1824    tracing::info!(%project_id, "join remote project");
1825    join_project_internal(response, session, project, replica_id)
1826}
1827
/// Abstraction over "something that can deliver a `proto::JoinProjectResponse`",
/// accepted by `join_project_internal` as `impl JoinProjectInternalResponse`.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result` as the reply.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified on purpose: this calls `Response`'s own `send`
        // rather than this trait method of the same name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1836
1837fn join_project_internal(
1838    response: impl JoinProjectInternalResponse,
1839    session: Session,
1840    project: &mut Project,
1841    replica_id: &ReplicaId,
1842) -> Result<()> {
1843    let collaborators = project
1844        .collaborators
1845        .iter()
1846        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1847        .map(|collaborator| collaborator.to_proto())
1848        .collect::<Vec<_>>();
1849    let project_id = project.id;
1850    let guest_user_id = session.user_id();
1851
1852    let worktrees = project
1853        .worktrees
1854        .iter()
1855        .map(|(id, worktree)| proto::WorktreeMetadata {
1856            id: *id,
1857            root_name: worktree.root_name.clone(),
1858            visible: worktree.visible,
1859            abs_path: worktree.abs_path.clone(),
1860        })
1861        .collect::<Vec<_>>();
1862
1863    let add_project_collaborator = proto::AddProjectCollaborator {
1864        project_id: project_id.to_proto(),
1865        collaborator: Some(proto::Collaborator {
1866            peer_id: Some(session.connection_id.into()),
1867            replica_id: replica_id.0 as u32,
1868            user_id: guest_user_id.to_proto(),
1869            is_host: false,
1870        }),
1871    };
1872
1873    for collaborator in &collaborators {
1874        session
1875            .peer
1876            .send(
1877                collaborator.peer_id.unwrap().into(),
1878                add_project_collaborator.clone(),
1879            )
1880            .trace_err();
1881    }
1882
1883    // First, we send the metadata associated with each worktree.
1884    response.send(proto::JoinProjectResponse {
1885        project_id: project.id.0 as u64,
1886        worktrees: worktrees.clone(),
1887        replica_id: replica_id.0 as u32,
1888        collaborators: collaborators.clone(),
1889        language_servers: project.language_servers.clone(),
1890        role: project.role.into(),
1891    })?;
1892
1893    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1894        // Stream this worktree's entries.
1895        let message = proto::UpdateWorktree {
1896            project_id: project_id.to_proto(),
1897            worktree_id,
1898            abs_path: worktree.abs_path.clone(),
1899            root_name: worktree.root_name,
1900            updated_entries: worktree.entries,
1901            removed_entries: Default::default(),
1902            scan_id: worktree.scan_id,
1903            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1904            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1905            removed_repositories: Default::default(),
1906        };
1907        for update in proto::split_worktree_update(message) {
1908            session.peer.send(session.connection_id, update.clone())?;
1909        }
1910
1911        // Stream this worktree's diagnostics.
1912        for summary in worktree.diagnostic_summaries {
1913            session.peer.send(
1914                session.connection_id,
1915                proto::UpdateDiagnosticSummary {
1916                    project_id: project_id.to_proto(),
1917                    worktree_id: worktree.id,
1918                    summary: Some(summary),
1919                },
1920            )?;
1921        }
1922
1923        for settings_file in worktree.settings_files {
1924            session.peer.send(
1925                session.connection_id,
1926                proto::UpdateWorktreeSettings {
1927                    project_id: project_id.to_proto(),
1928                    worktree_id: worktree.id,
1929                    path: settings_file.path,
1930                    content: Some(settings_file.content),
1931                    kind: Some(settings_file.kind.to_proto() as i32),
1932                },
1933            )?;
1934        }
1935    }
1936
1937    for repository in mem::take(&mut project.repositories) {
1938        for update in split_repository_update(repository) {
1939            session.peer.send(session.connection_id, update)?;
1940        }
1941    }
1942
1943    for language_server in &project.language_servers {
1944        session.peer.send(
1945            session.connection_id,
1946            proto::UpdateLanguageServer {
1947                project_id: project_id.to_proto(),
1948                language_server_id: language_server.id,
1949                variant: Some(
1950                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1951                        proto::LspDiskBasedDiagnosticsUpdated {},
1952                    ),
1953                ),
1954            },
1955        )?;
1956    }
1957
1958    Ok(())
1959}
1960
1961/// Leave someone elses shared project.
1962async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1963    let sender_id = session.connection_id;
1964    let project_id = ProjectId::from_proto(request.project_id);
1965    let db = session.db().await;
1966
1967    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1968    tracing::info!(
1969        %project_id,
1970        "leave project"
1971    );
1972
1973    project_left(project, &session);
1974    if let Some(room) = room {
1975        room_updated(room, &session.peer);
1976    }
1977
1978    Ok(())
1979}
1980
1981/// Updates other participants with changes to the project
1982async fn update_project(
1983    request: proto::UpdateProject,
1984    response: Response<proto::UpdateProject>,
1985    session: Session,
1986) -> Result<()> {
1987    let project_id = ProjectId::from_proto(request.project_id);
1988    let (room, guest_connection_ids) = &*session
1989        .db()
1990        .await
1991        .update_project(project_id, session.connection_id, &request.worktrees)
1992        .await?;
1993    broadcast(
1994        Some(session.connection_id),
1995        guest_connection_ids.iter().copied(),
1996        |connection_id| {
1997            session
1998                .peer
1999                .forward_send(session.connection_id, connection_id, request.clone())
2000        },
2001    );
2002    if let Some(room) = room {
2003        room_updated(room, &session.peer);
2004    }
2005    response.send(proto::Ack {})?;
2006
2007    Ok(())
2008}
2009
2010/// Updates other participants with changes to the worktree
2011async fn update_worktree(
2012    request: proto::UpdateWorktree,
2013    response: Response<proto::UpdateWorktree>,
2014    session: Session,
2015) -> Result<()> {
2016    let guest_connection_ids = session
2017        .db()
2018        .await
2019        .update_worktree(&request, session.connection_id)
2020        .await?;
2021
2022    broadcast(
2023        Some(session.connection_id),
2024        guest_connection_ids.iter().copied(),
2025        |connection_id| {
2026            session
2027                .peer
2028                .forward_send(session.connection_id, connection_id, request.clone())
2029        },
2030    );
2031    response.send(proto::Ack {})?;
2032    Ok(())
2033}
2034
2035async fn update_repository(
2036    request: proto::UpdateRepository,
2037    response: Response<proto::UpdateRepository>,
2038    session: Session,
2039) -> Result<()> {
2040    let guest_connection_ids = session
2041        .db()
2042        .await
2043        .update_repository(&request, session.connection_id)
2044        .await?;
2045
2046    broadcast(
2047        Some(session.connection_id),
2048        guest_connection_ids.iter().copied(),
2049        |connection_id| {
2050            session
2051                .peer
2052                .forward_send(session.connection_id, connection_id, request.clone())
2053        },
2054    );
2055    response.send(proto::Ack {})?;
2056    Ok(())
2057}
2058
2059async fn remove_repository(
2060    request: proto::RemoveRepository,
2061    response: Response<proto::RemoveRepository>,
2062    session: Session,
2063) -> Result<()> {
2064    let guest_connection_ids = session
2065        .db()
2066        .await
2067        .remove_repository(&request, session.connection_id)
2068        .await?;
2069
2070    broadcast(
2071        Some(session.connection_id),
2072        guest_connection_ids.iter().copied(),
2073        |connection_id| {
2074            session
2075                .peer
2076                .forward_send(session.connection_id, connection_id, request.clone())
2077        },
2078    );
2079    response.send(proto::Ack {})?;
2080    Ok(())
2081}
2082
2083/// Updates other participants with changes to the diagnostics
2084async fn update_diagnostic_summary(
2085    message: proto::UpdateDiagnosticSummary,
2086    session: Session,
2087) -> Result<()> {
2088    let guest_connection_ids = session
2089        .db()
2090        .await
2091        .update_diagnostic_summary(&message, session.connection_id)
2092        .await?;
2093
2094    broadcast(
2095        Some(session.connection_id),
2096        guest_connection_ids.iter().copied(),
2097        |connection_id| {
2098            session
2099                .peer
2100                .forward_send(session.connection_id, connection_id, message.clone())
2101        },
2102    );
2103
2104    Ok(())
2105}
2106
2107/// Updates other participants with changes to the worktree settings
2108async fn update_worktree_settings(
2109    message: proto::UpdateWorktreeSettings,
2110    session: Session,
2111) -> Result<()> {
2112    let guest_connection_ids = session
2113        .db()
2114        .await
2115        .update_worktree_settings(&message, session.connection_id)
2116        .await?;
2117
2118    broadcast(
2119        Some(session.connection_id),
2120        guest_connection_ids.iter().copied(),
2121        |connection_id| {
2122            session
2123                .peer
2124                .forward_send(session.connection_id, connection_id, message.clone())
2125        },
2126    );
2127
2128    Ok(())
2129}
2130
2131/// Notify other participants that a language server has started.
2132async fn start_language_server(
2133    request: proto::StartLanguageServer,
2134    session: Session,
2135) -> Result<()> {
2136    let guest_connection_ids = session
2137        .db()
2138        .await
2139        .start_language_server(&request, session.connection_id)
2140        .await?;
2141
2142    broadcast(
2143        Some(session.connection_id),
2144        guest_connection_ids.iter().copied(),
2145        |connection_id| {
2146            session
2147                .peer
2148                .forward_send(session.connection_id, connection_id, request.clone())
2149        },
2150    );
2151    Ok(())
2152}
2153
2154/// Notify other participants that a language server has changed.
2155async fn update_language_server(
2156    request: proto::UpdateLanguageServer,
2157    session: Session,
2158) -> Result<()> {
2159    let project_id = ProjectId::from_proto(request.project_id);
2160    let project_connection_ids = session
2161        .db()
2162        .await
2163        .project_connection_ids(project_id, session.connection_id, true)
2164        .await?;
2165    broadcast(
2166        Some(session.connection_id),
2167        project_connection_ids.iter().copied(),
2168        |connection_id| {
2169            session
2170                .peer
2171                .forward_send(session.connection_id, connection_id, request.clone())
2172        },
2173    );
2174    Ok(())
2175}
2176
2177/// forward a project request to the host. These requests should be read only
2178/// as guests are allowed to send them.
2179async fn forward_read_only_project_request<T>(
2180    request: T,
2181    response: Response<T>,
2182    session: Session,
2183) -> Result<()>
2184where
2185    T: EntityMessage + RequestMessage,
2186{
2187    let project_id = ProjectId::from_proto(request.remote_entity_id());
2188    let host_connection_id = session
2189        .db()
2190        .await
2191        .host_for_read_only_project_request(project_id, session.connection_id)
2192        .await?;
2193    let payload = session
2194        .peer
2195        .forward_request(session.connection_id, host_connection_id, request)
2196        .await?;
2197    response.send(payload)?;
2198    Ok(())
2199}
2200
2201async fn forward_find_search_candidates_request(
2202    request: proto::FindSearchCandidates,
2203    response: Response<proto::FindSearchCandidates>,
2204    session: Session,
2205) -> Result<()> {
2206    let project_id = ProjectId::from_proto(request.remote_entity_id());
2207    let host_connection_id = session
2208        .db()
2209        .await
2210        .host_for_read_only_project_request(project_id, session.connection_id)
2211        .await?;
2212    let payload = session
2213        .peer
2214        .forward_request(session.connection_id, host_connection_id, request)
2215        .await?;
2216    response.send(payload)?;
2217    Ok(())
2218}
2219
2220/// forward a project request to the host. These requests are disallowed
2221/// for guests.
2222async fn forward_mutating_project_request<T>(
2223    request: T,
2224    response: Response<T>,
2225    session: Session,
2226) -> Result<()>
2227where
2228    T: EntityMessage + RequestMessage,
2229{
2230    let project_id = ProjectId::from_proto(request.remote_entity_id());
2231
2232    let host_connection_id = session
2233        .db()
2234        .await
2235        .host_for_mutating_project_request(project_id, session.connection_id)
2236        .await?;
2237    let payload = session
2238        .peer
2239        .forward_request(session.connection_id, host_connection_id, request)
2240        .await?;
2241    response.send(payload)?;
2242    Ok(())
2243}
2244
2245/// Notify other participants that a new buffer has been created
2246async fn create_buffer_for_peer(
2247    request: proto::CreateBufferForPeer,
2248    session: Session,
2249) -> Result<()> {
2250    session
2251        .db()
2252        .await
2253        .check_user_is_project_host(
2254            ProjectId::from_proto(request.project_id),
2255            session.connection_id,
2256        )
2257        .await?;
2258    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2259    session
2260        .peer
2261        .forward_send(session.connection_id, peer_id.into(), request)?;
2262    Ok(())
2263}
2264
2265/// Notify other participants that a buffer has been updated. This is
2266/// allowed for guests as long as the update is limited to selections.
2267async fn update_buffer(
2268    request: proto::UpdateBuffer,
2269    response: Response<proto::UpdateBuffer>,
2270    session: Session,
2271) -> Result<()> {
2272    let project_id = ProjectId::from_proto(request.project_id);
2273    let mut capability = Capability::ReadOnly;
2274
2275    for op in request.operations.iter() {
2276        match op.variant {
2277            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2278            Some(_) => capability = Capability::ReadWrite,
2279        }
2280    }
2281
2282    let host = {
2283        let guard = session
2284            .db()
2285            .await
2286            .connections_for_buffer_update(project_id, session.connection_id, capability)
2287            .await?;
2288
2289        let (host, guests) = &*guard;
2290
2291        broadcast(
2292            Some(session.connection_id),
2293            guests.clone(),
2294            |connection_id| {
2295                session
2296                    .peer
2297                    .forward_send(session.connection_id, connection_id, request.clone())
2298            },
2299        );
2300
2301        *host
2302    };
2303
2304    if host != session.connection_id {
2305        session
2306            .peer
2307            .forward_request(session.connection_id, host, request.clone())
2308            .await?;
2309    }
2310
2311    response.send(proto::Ack {})?;
2312    Ok(())
2313}
2314
2315async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2316    let project_id = ProjectId::from_proto(message.project_id);
2317
2318    let operation = message.operation.as_ref().context("invalid operation")?;
2319    let capability = match operation.variant.as_ref() {
2320        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2321            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2322                match buffer_op.variant {
2323                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2324                        Capability::ReadOnly
2325                    }
2326                    _ => Capability::ReadWrite,
2327                }
2328            } else {
2329                Capability::ReadWrite
2330            }
2331        }
2332        Some(_) => Capability::ReadWrite,
2333        None => Capability::ReadOnly,
2334    };
2335
2336    let guard = session
2337        .db()
2338        .await
2339        .connections_for_buffer_update(project_id, session.connection_id, capability)
2340        .await?;
2341
2342    let (host, guests) = &*guard;
2343
2344    broadcast(
2345        Some(session.connection_id),
2346        guests.iter().chain([host]).copied(),
2347        |connection_id| {
2348            session
2349                .peer
2350                .forward_send(session.connection_id, connection_id, message.clone())
2351        },
2352    );
2353
2354    Ok(())
2355}
2356
2357/// Notify other participants that a project has been updated.
2358async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2359    request: T,
2360    session: Session,
2361) -> Result<()> {
2362    let project_id = ProjectId::from_proto(request.remote_entity_id());
2363    let project_connection_ids = session
2364        .db()
2365        .await
2366        .project_connection_ids(project_id, session.connection_id, false)
2367        .await?;
2368
2369    broadcast(
2370        Some(session.connection_id),
2371        project_connection_ids.iter().copied(),
2372        |connection_id| {
2373            session
2374                .peer
2375                .forward_send(session.connection_id, connection_id, request.clone())
2376        },
2377    );
2378    Ok(())
2379}
2380
2381/// Start following another user in a call.
2382async fn follow(
2383    request: proto::Follow,
2384    response: Response<proto::Follow>,
2385    session: Session,
2386) -> Result<()> {
2387    let room_id = RoomId::from_proto(request.room_id);
2388    let project_id = request.project_id.map(ProjectId::from_proto);
2389    let leader_id = request
2390        .leader_id
2391        .ok_or_else(|| anyhow!("invalid leader id"))?
2392        .into();
2393    let follower_id = session.connection_id;
2394
2395    session
2396        .db()
2397        .await
2398        .check_room_participants(room_id, leader_id, session.connection_id)
2399        .await?;
2400
2401    let response_payload = session
2402        .peer
2403        .forward_request(session.connection_id, leader_id, request)
2404        .await?;
2405    response.send(response_payload)?;
2406
2407    if let Some(project_id) = project_id {
2408        let room = session
2409            .db()
2410            .await
2411            .follow(room_id, project_id, leader_id, follower_id)
2412            .await?;
2413        room_updated(&room, &session.peer);
2414    }
2415
2416    Ok(())
2417}
2418
2419/// Stop following another user in a call.
2420async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2421    let room_id = RoomId::from_proto(request.room_id);
2422    let project_id = request.project_id.map(ProjectId::from_proto);
2423    let leader_id = request
2424        .leader_id
2425        .ok_or_else(|| anyhow!("invalid leader id"))?
2426        .into();
2427    let follower_id = session.connection_id;
2428
2429    session
2430        .db()
2431        .await
2432        .check_room_participants(room_id, leader_id, session.connection_id)
2433        .await?;
2434
2435    session
2436        .peer
2437        .forward_send(session.connection_id, leader_id, request)?;
2438
2439    if let Some(project_id) = project_id {
2440        let room = session
2441            .db()
2442            .await
2443            .unfollow(room_id, project_id, leader_id, follower_id)
2444            .await?;
2445        room_updated(&room, &session.peer);
2446    }
2447
2448    Ok(())
2449}
2450
2451/// Notify everyone following you of your current location.
2452async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2453    let room_id = RoomId::from_proto(request.room_id);
2454    let database = session.db.lock().await;
2455
2456    let connection_ids = if let Some(project_id) = request.project_id {
2457        let project_id = ProjectId::from_proto(project_id);
2458        database
2459            .project_connection_ids(project_id, session.connection_id, true)
2460            .await?
2461    } else {
2462        database
2463            .room_connection_ids(room_id, session.connection_id)
2464            .await?
2465    };
2466
2467    // For now, don't send view update messages back to that view's current leader.
2468    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2469        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2470        _ => None,
2471    });
2472
2473    for connection_id in connection_ids.iter().cloned() {
2474        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2475            session
2476                .peer
2477                .forward_send(session.connection_id, connection_id, request.clone())?;
2478        }
2479    }
2480    Ok(())
2481}
2482
2483/// Get public data about users.
2484async fn get_users(
2485    request: proto::GetUsers,
2486    response: Response<proto::GetUsers>,
2487    session: Session,
2488) -> Result<()> {
2489    let user_ids = request
2490        .user_ids
2491        .into_iter()
2492        .map(UserId::from_proto)
2493        .collect();
2494    let users = session
2495        .db()
2496        .await
2497        .get_users_by_ids(user_ids)
2498        .await?
2499        .into_iter()
2500        .map(|user| proto::User {
2501            id: user.id.to_proto(),
2502            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2503            github_login: user.github_login,
2504            email: user.email_address,
2505            name: user.name,
2506        })
2507        .collect();
2508    response.send(proto::UsersResponse { users })?;
2509    Ok(())
2510}
2511
2512/// Search for users (to invite) buy Github login
2513async fn fuzzy_search_users(
2514    request: proto::FuzzySearchUsers,
2515    response: Response<proto::FuzzySearchUsers>,
2516    session: Session,
2517) -> Result<()> {
2518    let query = request.query;
2519    let users = match query.len() {
2520        0 => vec![],
2521        1 | 2 => session
2522            .db()
2523            .await
2524            .get_user_by_github_login(&query)
2525            .await?
2526            .into_iter()
2527            .collect(),
2528        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2529    };
2530    let users = users
2531        .into_iter()
2532        .filter(|user| user.id != session.user_id())
2533        .map(|user| proto::User {
2534            id: user.id.to_proto(),
2535            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2536            github_login: user.github_login,
2537            name: user.name,
2538            email: user.email_address,
2539        })
2540        .collect();
2541    response.send(proto::UsersResponse { users })?;
2542    Ok(())
2543}
2544
2545/// Send a contact request to another user.
2546async fn request_contact(
2547    request: proto::RequestContact,
2548    response: Response<proto::RequestContact>,
2549    session: Session,
2550) -> Result<()> {
2551    let requester_id = session.user_id();
2552    let responder_id = UserId::from_proto(request.responder_id);
2553    if requester_id == responder_id {
2554        return Err(anyhow!("cannot add yourself as a contact"))?;
2555    }
2556
2557    let notifications = session
2558        .db()
2559        .await
2560        .send_contact_request(requester_id, responder_id)
2561        .await?;
2562
2563    // Update outgoing contact requests of requester
2564    let mut update = proto::UpdateContacts::default();
2565    update.outgoing_requests.push(responder_id.to_proto());
2566    for connection_id in session
2567        .connection_pool()
2568        .await
2569        .user_connection_ids(requester_id)
2570    {
2571        session.peer.send(connection_id, update.clone())?;
2572    }
2573
2574    // Update incoming contact requests of responder
2575    let mut update = proto::UpdateContacts::default();
2576    update
2577        .incoming_requests
2578        .push(proto::IncomingContactRequest {
2579            requester_id: requester_id.to_proto(),
2580        });
2581    let connection_pool = session.connection_pool().await;
2582    for connection_id in connection_pool.user_connection_ids(responder_id) {
2583        session.peer.send(connection_id, update.clone())?;
2584    }
2585
2586    send_notifications(&connection_pool, &session.peer, notifications);
2587
2588    response.send(proto::Ack {})?;
2589    Ok(())
2590}
2591
2592/// Accept or decline a contact request
2593async fn respond_to_contact_request(
2594    request: proto::RespondToContactRequest,
2595    response: Response<proto::RespondToContactRequest>,
2596    session: Session,
2597) -> Result<()> {
2598    let responder_id = session.user_id();
2599    let requester_id = UserId::from_proto(request.requester_id);
2600    let db = session.db().await;
2601    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2602        db.dismiss_contact_notification(responder_id, requester_id)
2603            .await?;
2604    } else {
2605        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2606
2607        let notifications = db
2608            .respond_to_contact_request(responder_id, requester_id, accept)
2609            .await?;
2610        let requester_busy = db.is_user_busy(requester_id).await?;
2611        let responder_busy = db.is_user_busy(responder_id).await?;
2612
2613        let pool = session.connection_pool().await;
2614        // Update responder with new contact
2615        let mut update = proto::UpdateContacts::default();
2616        if accept {
2617            update
2618                .contacts
2619                .push(contact_for_user(requester_id, requester_busy, &pool));
2620        }
2621        update
2622            .remove_incoming_requests
2623            .push(requester_id.to_proto());
2624        for connection_id in pool.user_connection_ids(responder_id) {
2625            session.peer.send(connection_id, update.clone())?;
2626        }
2627
2628        // Update requester with new contact
2629        let mut update = proto::UpdateContacts::default();
2630        if accept {
2631            update
2632                .contacts
2633                .push(contact_for_user(responder_id, responder_busy, &pool));
2634        }
2635        update
2636            .remove_outgoing_requests
2637            .push(responder_id.to_proto());
2638
2639        for connection_id in pool.user_connection_ids(requester_id) {
2640            session.peer.send(connection_id, update.clone())?;
2641        }
2642
2643        send_notifications(&pool, &session.peer, notifications);
2644    }
2645
2646    response.send(proto::Ack {})?;
2647    Ok(())
2648}
2649
2650/// Remove a contact.
2651async fn remove_contact(
2652    request: proto::RemoveContact,
2653    response: Response<proto::RemoveContact>,
2654    session: Session,
2655) -> Result<()> {
2656    let requester_id = session.user_id();
2657    let responder_id = UserId::from_proto(request.user_id);
2658    let db = session.db().await;
2659    let (contact_accepted, deleted_notification_id) =
2660        db.remove_contact(requester_id, responder_id).await?;
2661
2662    let pool = session.connection_pool().await;
2663    // Update outgoing contact requests of requester
2664    let mut update = proto::UpdateContacts::default();
2665    if contact_accepted {
2666        update.remove_contacts.push(responder_id.to_proto());
2667    } else {
2668        update
2669            .remove_outgoing_requests
2670            .push(responder_id.to_proto());
2671    }
2672    for connection_id in pool.user_connection_ids(requester_id) {
2673        session.peer.send(connection_id, update.clone())?;
2674    }
2675
2676    // Update incoming contact requests of responder
2677    let mut update = proto::UpdateContacts::default();
2678    if contact_accepted {
2679        update.remove_contacts.push(requester_id.to_proto());
2680    } else {
2681        update
2682            .remove_incoming_requests
2683            .push(requester_id.to_proto());
2684    }
2685    for connection_id in pool.user_connection_ids(responder_id) {
2686        session.peer.send(connection_id, update.clone())?;
2687        if let Some(notification_id) = deleted_notification_id {
2688            session.peer.send(
2689                connection_id,
2690                proto::DeleteNotification {
2691                    notification_id: notification_id.to_proto(),
2692                },
2693            )?;
2694        }
2695    }
2696
2697    response.send(proto::Ack {})?;
2698    Ok(())
2699}
2700
2701fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2702    version.0.minor() < 139
2703}
2704
2705async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2706    let db = session.db().await;
2707
2708    let feature_flags = db.get_user_flags(user_id).await?;
2709    let plan = session.current_plan(&db).await?;
2710    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2711    let billing_preferences = db.get_billing_preferences(user_id).await?;
2712
2713    let usage = if let Some(llm_db) = session.app_state.llm_db.clone() {
2714        let subscription = db.get_active_billing_subscription(user_id).await?;
2715
2716        let subscription_period = maybe!({
2717            let subscription = subscription?;
2718            let period_start_at = subscription.current_period_start_at()?;
2719            let period_end_at = subscription.current_period_end_at()?;
2720
2721            Some((period_start_at, period_end_at))
2722        });
2723
2724        if let Some((period_start_at, period_end_at)) = subscription_period {
2725            llm_db
2726                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2727                .await?
2728        } else {
2729            None
2730        }
2731    } else {
2732        None
2733    };
2734
2735    session
2736        .peer
2737        .send(
2738            session.connection_id,
2739            proto::UpdateUserPlan {
2740                plan: plan.into(),
2741                trial_started_at: billing_customer
2742                    .and_then(|billing_customer| billing_customer.trial_started_at)
2743                    .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2744                is_usage_based_billing_enabled: billing_preferences
2745                    .map(|preferences| preferences.model_request_overages_enabled),
2746                usage: usage.map(|usage| {
2747                    let plan = match plan {
2748                        proto::Plan::Free => zed_llm_client::Plan::Free,
2749                        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2750                        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2751                    };
2752
2753                    let model_requests_limit = match plan.model_requests_limit() {
2754                        zed_llm_client::UsageLimit::Limited(limit) => {
2755                            let limit = if plan == zed_llm_client::Plan::ZedProTrial
2756                                && feature_flags
2757                                    .iter()
2758                                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2759                            {
2760                                1_000
2761                            } else {
2762                                limit
2763                            };
2764
2765                            zed_llm_client::UsageLimit::Limited(limit)
2766                        }
2767                        zed_llm_client::UsageLimit::Unlimited => {
2768                            zed_llm_client::UsageLimit::Unlimited
2769                        }
2770                    };
2771
2772                    proto::SubscriptionUsage {
2773                        model_requests_usage_amount: usage.model_requests as u32,
2774                        model_requests_usage_limit: Some(proto::UsageLimit {
2775                            variant: Some(match model_requests_limit {
2776                                zed_llm_client::UsageLimit::Limited(limit) => {
2777                                    proto::usage_limit::Variant::Limited(
2778                                        proto::usage_limit::Limited {
2779                                            limit: limit as u32,
2780                                        },
2781                                    )
2782                                }
2783                                zed_llm_client::UsageLimit::Unlimited => {
2784                                    proto::usage_limit::Variant::Unlimited(
2785                                        proto::usage_limit::Unlimited {},
2786                                    )
2787                                }
2788                            }),
2789                        }),
2790                        edit_predictions_usage_amount: usage.edit_predictions as u32,
2791                        edit_predictions_usage_limit: Some(proto::UsageLimit {
2792                            variant: Some(match plan.edit_predictions_limit() {
2793                                zed_llm_client::UsageLimit::Limited(limit) => {
2794                                    proto::usage_limit::Variant::Limited(
2795                                        proto::usage_limit::Limited {
2796                                            limit: limit as u32,
2797                                        },
2798                                    )
2799                                }
2800                                zed_llm_client::UsageLimit::Unlimited => {
2801                                    proto::usage_limit::Variant::Unlimited(
2802                                        proto::usage_limit::Unlimited {},
2803                                    )
2804                                }
2805                            }),
2806                        }),
2807                    }
2808                }),
2809            },
2810        )
2811        .trace_err();
2812
2813    Ok(())
2814}
2815
2816async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2817    subscribe_user_to_channels(session.user_id(), &session).await?;
2818    Ok(())
2819}
2820
2821async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2822    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2823    let mut pool = session.connection_pool().await;
2824    for membership in &channels_for_user.channel_memberships {
2825        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2826    }
2827    session.peer.send(
2828        session.connection_id,
2829        build_update_user_channels(&channels_for_user),
2830    )?;
2831    session.peer.send(
2832        session.connection_id,
2833        build_channels_update(channels_for_user),
2834    )?;
2835    Ok(())
2836}
2837
2838/// Creates a new channel.
2839async fn create_channel(
2840    request: proto::CreateChannel,
2841    response: Response<proto::CreateChannel>,
2842    session: Session,
2843) -> Result<()> {
2844    let db = session.db().await;
2845
2846    let parent_id = request.parent_id.map(ChannelId::from_proto);
2847    let (channel, membership) = db
2848        .create_channel(&request.name, parent_id, session.user_id())
2849        .await?;
2850
2851    let root_id = channel.root_id();
2852    let channel = Channel::from_model(channel);
2853
2854    response.send(proto::CreateChannelResponse {
2855        channel: Some(channel.to_proto()),
2856        parent_id: request.parent_id,
2857    })?;
2858
2859    let mut connection_pool = session.connection_pool().await;
2860    if let Some(membership) = membership {
2861        connection_pool.subscribe_to_channel(
2862            membership.user_id,
2863            membership.channel_id,
2864            membership.role,
2865        );
2866        let update = proto::UpdateUserChannels {
2867            channel_memberships: vec![proto::ChannelMembership {
2868                channel_id: membership.channel_id.to_proto(),
2869                role: membership.role.into(),
2870            }],
2871            ..Default::default()
2872        };
2873        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2874            session.peer.send(connection_id, update.clone())?;
2875        }
2876    }
2877
2878    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2879        if !role.can_see_channel(channel.visibility) {
2880            continue;
2881        }
2882
2883        let update = proto::UpdateChannels {
2884            channels: vec![channel.to_proto()],
2885            ..Default::default()
2886        };
2887        session.peer.send(connection_id, update.clone())?;
2888    }
2889
2890    Ok(())
2891}
2892
2893/// Delete a channel
2894async fn delete_channel(
2895    request: proto::DeleteChannel,
2896    response: Response<proto::DeleteChannel>,
2897    session: Session,
2898) -> Result<()> {
2899    let db = session.db().await;
2900
2901    let channel_id = request.channel_id;
2902    let (root_channel, removed_channels) = db
2903        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2904        .await?;
2905    response.send(proto::Ack {})?;
2906
2907    // Notify members of removed channels
2908    let mut update = proto::UpdateChannels::default();
2909    update
2910        .delete_channels
2911        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2912
2913    let connection_pool = session.connection_pool().await;
2914    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2915        session.peer.send(connection_id, update.clone())?;
2916    }
2917
2918    Ok(())
2919}
2920
2921/// Invite someone to join a channel.
2922async fn invite_channel_member(
2923    request: proto::InviteChannelMember,
2924    response: Response<proto::InviteChannelMember>,
2925    session: Session,
2926) -> Result<()> {
2927    let db = session.db().await;
2928    let channel_id = ChannelId::from_proto(request.channel_id);
2929    let invitee_id = UserId::from_proto(request.user_id);
2930    let InviteMemberResult {
2931        channel,
2932        notifications,
2933    } = db
2934        .invite_channel_member(
2935            channel_id,
2936            invitee_id,
2937            session.user_id(),
2938            request.role().into(),
2939        )
2940        .await?;
2941
2942    let update = proto::UpdateChannels {
2943        channel_invitations: vec![channel.to_proto()],
2944        ..Default::default()
2945    };
2946
2947    let connection_pool = session.connection_pool().await;
2948    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2949        session.peer.send(connection_id, update.clone())?;
2950    }
2951
2952    send_notifications(&connection_pool, &session.peer, notifications);
2953
2954    response.send(proto::Ack {})?;
2955    Ok(())
2956}
2957
2958/// remove someone from a channel
2959async fn remove_channel_member(
2960    request: proto::RemoveChannelMember,
2961    response: Response<proto::RemoveChannelMember>,
2962    session: Session,
2963) -> Result<()> {
2964    let db = session.db().await;
2965    let channel_id = ChannelId::from_proto(request.channel_id);
2966    let member_id = UserId::from_proto(request.user_id);
2967
2968    let RemoveChannelMemberResult {
2969        membership_update,
2970        notification_id,
2971    } = db
2972        .remove_channel_member(channel_id, member_id, session.user_id())
2973        .await?;
2974
2975    let mut connection_pool = session.connection_pool().await;
2976    notify_membership_updated(
2977        &mut connection_pool,
2978        membership_update,
2979        member_id,
2980        &session.peer,
2981    );
2982    for connection_id in connection_pool.user_connection_ids(member_id) {
2983        if let Some(notification_id) = notification_id {
2984            session
2985                .peer
2986                .send(
2987                    connection_id,
2988                    proto::DeleteNotification {
2989                        notification_id: notification_id.to_proto(),
2990                    },
2991                )
2992                .trace_err();
2993        }
2994    }
2995
2996    response.send(proto::Ack {})?;
2997    Ok(())
2998}
2999
3000/// Toggle the channel between public and private.
3001/// Care is taken to maintain the invariant that public channels only descend from public channels,
3002/// (though members-only channels can appear at any point in the hierarchy).
3003async fn set_channel_visibility(
3004    request: proto::SetChannelVisibility,
3005    response: Response<proto::SetChannelVisibility>,
3006    session: Session,
3007) -> Result<()> {
3008    let db = session.db().await;
3009    let channel_id = ChannelId::from_proto(request.channel_id);
3010    let visibility = request.visibility().into();
3011
3012    let channel_model = db
3013        .set_channel_visibility(channel_id, visibility, session.user_id())
3014        .await?;
3015    let root_id = channel_model.root_id();
3016    let channel = Channel::from_model(channel_model);
3017
3018    let mut connection_pool = session.connection_pool().await;
3019    for (user_id, role) in connection_pool
3020        .channel_user_ids(root_id)
3021        .collect::<Vec<_>>()
3022        .into_iter()
3023    {
3024        let update = if role.can_see_channel(channel.visibility) {
3025            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3026            proto::UpdateChannels {
3027                channels: vec![channel.to_proto()],
3028                ..Default::default()
3029            }
3030        } else {
3031            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3032            proto::UpdateChannels {
3033                delete_channels: vec![channel.id.to_proto()],
3034                ..Default::default()
3035            }
3036        };
3037
3038        for connection_id in connection_pool.user_connection_ids(user_id) {
3039            session.peer.send(connection_id, update.clone())?;
3040        }
3041    }
3042
3043    response.send(proto::Ack {})?;
3044    Ok(())
3045}
3046
3047/// Alter the role for a user in the channel.
3048async fn set_channel_member_role(
3049    request: proto::SetChannelMemberRole,
3050    response: Response<proto::SetChannelMemberRole>,
3051    session: Session,
3052) -> Result<()> {
3053    let db = session.db().await;
3054    let channel_id = ChannelId::from_proto(request.channel_id);
3055    let member_id = UserId::from_proto(request.user_id);
3056    let result = db
3057        .set_channel_member_role(
3058            channel_id,
3059            session.user_id(),
3060            member_id,
3061            request.role().into(),
3062        )
3063        .await?;
3064
3065    match result {
3066        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3067            let mut connection_pool = session.connection_pool().await;
3068            notify_membership_updated(
3069                &mut connection_pool,
3070                membership_update,
3071                member_id,
3072                &session.peer,
3073            )
3074        }
3075        db::SetMemberRoleResult::InviteUpdated(channel) => {
3076            let update = proto::UpdateChannels {
3077                channel_invitations: vec![channel.to_proto()],
3078                ..Default::default()
3079            };
3080
3081            for connection_id in session
3082                .connection_pool()
3083                .await
3084                .user_connection_ids(member_id)
3085            {
3086                session.peer.send(connection_id, update.clone())?;
3087            }
3088        }
3089    }
3090
3091    response.send(proto::Ack {})?;
3092    Ok(())
3093}
3094
3095/// Change the name of a channel
3096async fn rename_channel(
3097    request: proto::RenameChannel,
3098    response: Response<proto::RenameChannel>,
3099    session: Session,
3100) -> Result<()> {
3101    let db = session.db().await;
3102    let channel_id = ChannelId::from_proto(request.channel_id);
3103    let channel_model = db
3104        .rename_channel(channel_id, session.user_id(), &request.name)
3105        .await?;
3106    let root_id = channel_model.root_id();
3107    let channel = Channel::from_model(channel_model);
3108
3109    response.send(proto::RenameChannelResponse {
3110        channel: Some(channel.to_proto()),
3111    })?;
3112
3113    let connection_pool = session.connection_pool().await;
3114    let update = proto::UpdateChannels {
3115        channels: vec![channel.to_proto()],
3116        ..Default::default()
3117    };
3118    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3119        if role.can_see_channel(channel.visibility) {
3120            session.peer.send(connection_id, update.clone())?;
3121        }
3122    }
3123
3124    Ok(())
3125}
3126
3127/// Move a channel to a new parent.
3128async fn move_channel(
3129    request: proto::MoveChannel,
3130    response: Response<proto::MoveChannel>,
3131    session: Session,
3132) -> Result<()> {
3133    let channel_id = ChannelId::from_proto(request.channel_id);
3134    let to = ChannelId::from_proto(request.to);
3135
3136    let (root_id, channels) = session
3137        .db()
3138        .await
3139        .move_channel(channel_id, to, session.user_id())
3140        .await?;
3141
3142    let connection_pool = session.connection_pool().await;
3143    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3144        let channels = channels
3145            .iter()
3146            .filter_map(|channel| {
3147                if role.can_see_channel(channel.visibility) {
3148                    Some(channel.to_proto())
3149                } else {
3150                    None
3151                }
3152            })
3153            .collect::<Vec<_>>();
3154        if channels.is_empty() {
3155            continue;
3156        }
3157
3158        let update = proto::UpdateChannels {
3159            channels,
3160            ..Default::default()
3161        };
3162
3163        session.peer.send(connection_id, update.clone())?;
3164    }
3165
3166    response.send(Ack {})?;
3167    Ok(())
3168}
3169
3170/// Get the list of channel members
3171async fn get_channel_members(
3172    request: proto::GetChannelMembers,
3173    response: Response<proto::GetChannelMembers>,
3174    session: Session,
3175) -> Result<()> {
3176    let db = session.db().await;
3177    let channel_id = ChannelId::from_proto(request.channel_id);
3178    let limit = if request.limit == 0 {
3179        u16::MAX as u64
3180    } else {
3181        request.limit
3182    };
3183    let (members, users) = db
3184        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3185        .await?;
3186    response.send(proto::GetChannelMembersResponse { members, users })?;
3187    Ok(())
3188}
3189
3190/// Accept or decline a channel invitation.
3191async fn respond_to_channel_invite(
3192    request: proto::RespondToChannelInvite,
3193    response: Response<proto::RespondToChannelInvite>,
3194    session: Session,
3195) -> Result<()> {
3196    let db = session.db().await;
3197    let channel_id = ChannelId::from_proto(request.channel_id);
3198    let RespondToChannelInvite {
3199        membership_update,
3200        notifications,
3201    } = db
3202        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3203        .await?;
3204
3205    let mut connection_pool = session.connection_pool().await;
3206    if let Some(membership_update) = membership_update {
3207        notify_membership_updated(
3208            &mut connection_pool,
3209            membership_update,
3210            session.user_id(),
3211            &session.peer,
3212        );
3213    } else {
3214        let update = proto::UpdateChannels {
3215            remove_channel_invitations: vec![channel_id.to_proto()],
3216            ..Default::default()
3217        };
3218
3219        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3220            session.peer.send(connection_id, update.clone())?;
3221        }
3222    };
3223
3224    send_notifications(&connection_pool, &session.peer, notifications);
3225
3226    response.send(proto::Ack {})?;
3227
3228    Ok(())
3229}
3230
3231/// Join the channels' room
3232async fn join_channel(
3233    request: proto::JoinChannel,
3234    response: Response<proto::JoinChannel>,
3235    session: Session,
3236) -> Result<()> {
3237    let channel_id = ChannelId::from_proto(request.channel_id);
3238    join_channel_internal(channel_id, Box::new(response), session).await
3239}
3240
3241trait JoinChannelInternalResponse {
3242    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3243}
3244impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3245    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3246        Response::<proto::JoinChannel>::send(self, result)
3247    }
3248}
3249impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3250    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3251        Response::<proto::JoinRoom>::send(self, result)
3252    }
3253}
3254
3255async fn join_channel_internal(
3256    channel_id: ChannelId,
3257    response: Box<impl JoinChannelInternalResponse>,
3258    session: Session,
3259) -> Result<()> {
3260    let joined_room = {
3261        let mut db = session.db().await;
3262        // If zed quits without leaving the room, and the user re-opens zed before the
3263        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3264        // room they were in.
3265        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3266            tracing::info!(
3267                stale_connection_id = %connection,
3268                "cleaning up stale connection",
3269            );
3270            drop(db);
3271            leave_room_for_session(&session, connection).await?;
3272            db = session.db().await;
3273        }
3274
3275        let (joined_room, membership_updated, role) = db
3276            .join_channel(channel_id, session.user_id(), session.connection_id)
3277            .await?;
3278
3279        let live_kit_connection_info =
3280            session
3281                .app_state
3282                .livekit_client
3283                .as_ref()
3284                .and_then(|live_kit| {
3285                    let (can_publish, token) = if role == ChannelRole::Guest {
3286                        (
3287                            false,
3288                            live_kit
3289                                .guest_token(
3290                                    &joined_room.room.livekit_room,
3291                                    &session.user_id().to_string(),
3292                                )
3293                                .trace_err()?,
3294                        )
3295                    } else {
3296                        (
3297                            true,
3298                            live_kit
3299                                .room_token(
3300                                    &joined_room.room.livekit_room,
3301                                    &session.user_id().to_string(),
3302                                )
3303                                .trace_err()?,
3304                        )
3305                    };
3306
3307                    Some(LiveKitConnectionInfo {
3308                        server_url: live_kit.url().into(),
3309                        token,
3310                        can_publish,
3311                    })
3312                });
3313
3314        response.send(proto::JoinRoomResponse {
3315            room: Some(joined_room.room.clone()),
3316            channel_id: joined_room
3317                .channel
3318                .as_ref()
3319                .map(|channel| channel.id.to_proto()),
3320            live_kit_connection_info,
3321        })?;
3322
3323        let mut connection_pool = session.connection_pool().await;
3324        if let Some(membership_updated) = membership_updated {
3325            notify_membership_updated(
3326                &mut connection_pool,
3327                membership_updated,
3328                session.user_id(),
3329                &session.peer,
3330            );
3331        }
3332
3333        room_updated(&joined_room.room, &session.peer);
3334
3335        joined_room
3336    };
3337
3338    channel_updated(
3339        &joined_room
3340            .channel
3341            .ok_or_else(|| anyhow!("channel not returned"))?,
3342        &joined_room.room,
3343        &session.peer,
3344        &*session.connection_pool().await,
3345    );
3346
3347    update_user_contacts(session.user_id(), &session).await?;
3348    Ok(())
3349}
3350
3351/// Start editing the channel notes
3352async fn join_channel_buffer(
3353    request: proto::JoinChannelBuffer,
3354    response: Response<proto::JoinChannelBuffer>,
3355    session: Session,
3356) -> Result<()> {
3357    let db = session.db().await;
3358    let channel_id = ChannelId::from_proto(request.channel_id);
3359
3360    let open_response = db
3361        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3362        .await?;
3363
3364    let collaborators = open_response.collaborators.clone();
3365    response.send(open_response)?;
3366
3367    let update = UpdateChannelBufferCollaborators {
3368        channel_id: channel_id.to_proto(),
3369        collaborators: collaborators.clone(),
3370    };
3371    channel_buffer_updated(
3372        session.connection_id,
3373        collaborators
3374            .iter()
3375            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3376        &update,
3377        &session.peer,
3378    );
3379
3380    Ok(())
3381}
3382
3383/// Edit the channel notes
3384async fn update_channel_buffer(
3385    request: proto::UpdateChannelBuffer,
3386    session: Session,
3387) -> Result<()> {
3388    let db = session.db().await;
3389    let channel_id = ChannelId::from_proto(request.channel_id);
3390
3391    let (collaborators, epoch, version) = db
3392        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3393        .await?;
3394
3395    channel_buffer_updated(
3396        session.connection_id,
3397        collaborators.clone(),
3398        &proto::UpdateChannelBuffer {
3399            channel_id: channel_id.to_proto(),
3400            operations: request.operations,
3401        },
3402        &session.peer,
3403    );
3404
3405    let pool = &*session.connection_pool().await;
3406
3407    let non_collaborators =
3408        pool.channel_connection_ids(channel_id)
3409            .filter_map(|(connection_id, _)| {
3410                if collaborators.contains(&connection_id) {
3411                    None
3412                } else {
3413                    Some(connection_id)
3414                }
3415            });
3416
3417    broadcast(None, non_collaborators, |peer_id| {
3418        session.peer.send(
3419            peer_id,
3420            proto::UpdateChannels {
3421                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3422                    channel_id: channel_id.to_proto(),
3423                    epoch: epoch as u64,
3424                    version: version.clone(),
3425                }],
3426                ..Default::default()
3427            },
3428        )
3429    });
3430
3431    Ok(())
3432}
3433
3434/// Rejoin the channel notes after a connection blip
3435async fn rejoin_channel_buffers(
3436    request: proto::RejoinChannelBuffers,
3437    response: Response<proto::RejoinChannelBuffers>,
3438    session: Session,
3439) -> Result<()> {
3440    let db = session.db().await;
3441    let buffers = db
3442        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3443        .await?;
3444
3445    for rejoined_buffer in &buffers {
3446        let collaborators_to_notify = rejoined_buffer
3447            .buffer
3448            .collaborators
3449            .iter()
3450            .filter_map(|c| Some(c.peer_id?.into()));
3451        channel_buffer_updated(
3452            session.connection_id,
3453            collaborators_to_notify,
3454            &proto::UpdateChannelBufferCollaborators {
3455                channel_id: rejoined_buffer.buffer.channel_id,
3456                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3457            },
3458            &session.peer,
3459        );
3460    }
3461
3462    response.send(proto::RejoinChannelBuffersResponse {
3463        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3464    })?;
3465
3466    Ok(())
3467}
3468
3469/// Stop editing the channel notes
3470async fn leave_channel_buffer(
3471    request: proto::LeaveChannelBuffer,
3472    response: Response<proto::LeaveChannelBuffer>,
3473    session: Session,
3474) -> Result<()> {
3475    let db = session.db().await;
3476    let channel_id = ChannelId::from_proto(request.channel_id);
3477
3478    let left_buffer = db
3479        .leave_channel_buffer(channel_id, session.connection_id)
3480        .await?;
3481
3482    response.send(Ack {})?;
3483
3484    channel_buffer_updated(
3485        session.connection_id,
3486        left_buffer.connections,
3487        &proto::UpdateChannelBufferCollaborators {
3488            channel_id: channel_id.to_proto(),
3489            collaborators: left_buffer.collaborators,
3490        },
3491        &session.peer,
3492    );
3493
3494    Ok(())
3495}
3496
3497fn channel_buffer_updated<T: EnvelopedMessage>(
3498    sender_id: ConnectionId,
3499    collaborators: impl IntoIterator<Item = ConnectionId>,
3500    message: &T,
3501    peer: &Peer,
3502) {
3503    broadcast(Some(sender_id), collaborators, |peer_id| {
3504        peer.send(peer_id, message.clone())
3505    });
3506}
3507
3508fn send_notifications(
3509    connection_pool: &ConnectionPool,
3510    peer: &Peer,
3511    notifications: db::NotificationBatch,
3512) {
3513    for (user_id, notification) in notifications {
3514        for connection_id in connection_pool.user_connection_ids(user_id) {
3515            if let Err(error) = peer.send(
3516                connection_id,
3517                proto::AddNotification {
3518                    notification: Some(notification.clone()),
3519                },
3520            ) {
3521                tracing::error!(
3522                    "failed to send notification to {:?} {}",
3523                    connection_id,
3524                    error
3525                );
3526            }
3527        }
3528    }
3529}
3530
3531/// Send a message to the channel
3532async fn send_channel_message(
3533    request: proto::SendChannelMessage,
3534    response: Response<proto::SendChannelMessage>,
3535    session: Session,
3536) -> Result<()> {
3537    // Validate the message body.
3538    let body = request.body.trim().to_string();
3539    if body.len() > MAX_MESSAGE_LEN {
3540        return Err(anyhow!("message is too long"))?;
3541    }
3542    if body.is_empty() {
3543        return Err(anyhow!("message can't be blank"))?;
3544    }
3545
3546    // TODO: adjust mentions if body is trimmed
3547
3548    let timestamp = OffsetDateTime::now_utc();
3549    let nonce = request
3550        .nonce
3551        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3552
3553    let channel_id = ChannelId::from_proto(request.channel_id);
3554    let CreatedChannelMessage {
3555        message_id,
3556        participant_connection_ids,
3557        notifications,
3558    } = session
3559        .db()
3560        .await
3561        .create_channel_message(
3562            channel_id,
3563            session.user_id(),
3564            &body,
3565            &request.mentions,
3566            timestamp,
3567            nonce.clone().into(),
3568            request.reply_to_message_id.map(MessageId::from_proto),
3569        )
3570        .await?;
3571
3572    let message = proto::ChannelMessage {
3573        sender_id: session.user_id().to_proto(),
3574        id: message_id.to_proto(),
3575        body,
3576        mentions: request.mentions,
3577        timestamp: timestamp.unix_timestamp() as u64,
3578        nonce: Some(nonce),
3579        reply_to_message_id: request.reply_to_message_id,
3580        edited_at: None,
3581    };
3582    broadcast(
3583        Some(session.connection_id),
3584        participant_connection_ids.clone(),
3585        |connection| {
3586            session.peer.send(
3587                connection,
3588                proto::ChannelMessageSent {
3589                    channel_id: channel_id.to_proto(),
3590                    message: Some(message.clone()),
3591                },
3592            )
3593        },
3594    );
3595    response.send(proto::SendChannelMessageResponse {
3596        message: Some(message),
3597    })?;
3598
3599    let pool = &*session.connection_pool().await;
3600    let non_participants =
3601        pool.channel_connection_ids(channel_id)
3602            .filter_map(|(connection_id, _)| {
3603                if participant_connection_ids.contains(&connection_id) {
3604                    None
3605                } else {
3606                    Some(connection_id)
3607                }
3608            });
3609    broadcast(None, non_participants, |peer_id| {
3610        session.peer.send(
3611            peer_id,
3612            proto::UpdateChannels {
3613                latest_channel_message_ids: vec![proto::ChannelMessageId {
3614                    channel_id: channel_id.to_proto(),
3615                    message_id: message_id.to_proto(),
3616                }],
3617                ..Default::default()
3618            },
3619        )
3620    });
3621    send_notifications(pool, &session.peer, notifications);
3622
3623    Ok(())
3624}
3625
3626/// Delete a channel message
3627async fn remove_channel_message(
3628    request: proto::RemoveChannelMessage,
3629    response: Response<proto::RemoveChannelMessage>,
3630    session: Session,
3631) -> Result<()> {
3632    let channel_id = ChannelId::from_proto(request.channel_id);
3633    let message_id = MessageId::from_proto(request.message_id);
3634    let (connection_ids, existing_notification_ids) = session
3635        .db()
3636        .await
3637        .remove_channel_message(channel_id, message_id, session.user_id())
3638        .await?;
3639
3640    broadcast(
3641        Some(session.connection_id),
3642        connection_ids,
3643        move |connection| {
3644            session.peer.send(connection, request.clone())?;
3645
3646            for notification_id in &existing_notification_ids {
3647                session.peer.send(
3648                    connection,
3649                    proto::DeleteNotification {
3650                        notification_id: (*notification_id).to_proto(),
3651                    },
3652                )?;
3653            }
3654
3655            Ok(())
3656        },
3657    );
3658    response.send(proto::Ack {})?;
3659    Ok(())
3660}
3661
3662async fn update_channel_message(
3663    request: proto::UpdateChannelMessage,
3664    response: Response<proto::UpdateChannelMessage>,
3665    session: Session,
3666) -> Result<()> {
3667    let channel_id = ChannelId::from_proto(request.channel_id);
3668    let message_id = MessageId::from_proto(request.message_id);
3669    let updated_at = OffsetDateTime::now_utc();
3670    let UpdatedChannelMessage {
3671        message_id,
3672        participant_connection_ids,
3673        notifications,
3674        reply_to_message_id,
3675        timestamp,
3676        deleted_mention_notification_ids,
3677        updated_mention_notifications,
3678    } = session
3679        .db()
3680        .await
3681        .update_channel_message(
3682            channel_id,
3683            message_id,
3684            session.user_id(),
3685            request.body.as_str(),
3686            &request.mentions,
3687            updated_at,
3688        )
3689        .await?;
3690
3691    let nonce = request
3692        .nonce
3693        .clone()
3694        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3695
3696    let message = proto::ChannelMessage {
3697        sender_id: session.user_id().to_proto(),
3698        id: message_id.to_proto(),
3699        body: request.body.clone(),
3700        mentions: request.mentions.clone(),
3701        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3702        nonce: Some(nonce),
3703        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3704        edited_at: Some(updated_at.unix_timestamp() as u64),
3705    };
3706
3707    response.send(proto::Ack {})?;
3708
3709    let pool = &*session.connection_pool().await;
3710    broadcast(
3711        Some(session.connection_id),
3712        participant_connection_ids,
3713        |connection| {
3714            session.peer.send(
3715                connection,
3716                proto::ChannelMessageUpdate {
3717                    channel_id: channel_id.to_proto(),
3718                    message: Some(message.clone()),
3719                },
3720            )?;
3721
3722            for notification_id in &deleted_mention_notification_ids {
3723                session.peer.send(
3724                    connection,
3725                    proto::DeleteNotification {
3726                        notification_id: (*notification_id).to_proto(),
3727                    },
3728                )?;
3729            }
3730
3731            for notification in &updated_mention_notifications {
3732                session.peer.send(
3733                    connection,
3734                    proto::UpdateNotification {
3735                        notification: Some(notification.clone()),
3736                    },
3737                )?;
3738            }
3739
3740            Ok(())
3741        },
3742    );
3743
3744    send_notifications(pool, &session.peer, notifications);
3745
3746    Ok(())
3747}
3748
3749/// Mark a channel message as read
3750async fn acknowledge_channel_message(
3751    request: proto::AckChannelMessage,
3752    session: Session,
3753) -> Result<()> {
3754    let channel_id = ChannelId::from_proto(request.channel_id);
3755    let message_id = MessageId::from_proto(request.message_id);
3756    let notifications = session
3757        .db()
3758        .await
3759        .observe_channel_message(channel_id, session.user_id(), message_id)
3760        .await?;
3761    send_notifications(
3762        &*session.connection_pool().await,
3763        &session.peer,
3764        notifications,
3765    );
3766    Ok(())
3767}
3768
3769/// Mark a buffer version as synced
3770async fn acknowledge_buffer_version(
3771    request: proto::AckBufferOperation,
3772    session: Session,
3773) -> Result<()> {
3774    let buffer_id = BufferId::from_proto(request.buffer_id);
3775    session
3776        .db()
3777        .await
3778        .observe_buffer_version(
3779            buffer_id,
3780            session.user_id(),
3781            request.epoch as i32,
3782            &request.version,
3783        )
3784        .await?;
3785    Ok(())
3786}
3787
3788/// Get a Supermaven API key for the user
3789async fn get_supermaven_api_key(
3790    _request: proto::GetSupermavenApiKey,
3791    response: Response<proto::GetSupermavenApiKey>,
3792    session: Session,
3793) -> Result<()> {
3794    let user_id: String = session.user_id().to_string();
3795    if !session.is_staff() {
3796        return Err(anyhow!("supermaven not enabled for this account"))?;
3797    }
3798
3799    let email = session
3800        .email()
3801        .ok_or_else(|| anyhow!("user must have an email"))?;
3802
3803    let supermaven_admin_api = session
3804        .supermaven_client
3805        .as_ref()
3806        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3807
3808    let result = supermaven_admin_api
3809        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3810        .await?;
3811
3812    response.send(proto::GetSupermavenApiKeyResponse {
3813        api_key: result.api_key,
3814    })?;
3815
3816    Ok(())
3817}
3818
3819/// Start receiving chat updates for a channel
3820async fn join_channel_chat(
3821    request: proto::JoinChannelChat,
3822    response: Response<proto::JoinChannelChat>,
3823    session: Session,
3824) -> Result<()> {
3825    let channel_id = ChannelId::from_proto(request.channel_id);
3826
3827    let db = session.db().await;
3828    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3829        .await?;
3830    let messages = db
3831        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3832        .await?;
3833    response.send(proto::JoinChannelChatResponse {
3834        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3835        messages,
3836    })?;
3837    Ok(())
3838}
3839
3840/// Stop receiving chat updates for a channel
3841async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3842    let channel_id = ChannelId::from_proto(request.channel_id);
3843    session
3844        .db()
3845        .await
3846        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3847        .await?;
3848    Ok(())
3849}
3850
3851/// Retrieve the chat history for a channel
3852async fn get_channel_messages(
3853    request: proto::GetChannelMessages,
3854    response: Response<proto::GetChannelMessages>,
3855    session: Session,
3856) -> Result<()> {
3857    let channel_id = ChannelId::from_proto(request.channel_id);
3858    let messages = session
3859        .db()
3860        .await
3861        .get_channel_messages(
3862            channel_id,
3863            session.user_id(),
3864            MESSAGE_COUNT_PER_PAGE,
3865            Some(MessageId::from_proto(request.before_message_id)),
3866        )
3867        .await?;
3868    response.send(proto::GetChannelMessagesResponse {
3869        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3870        messages,
3871    })?;
3872    Ok(())
3873}
3874
3875/// Retrieve specific chat messages
3876async fn get_channel_messages_by_id(
3877    request: proto::GetChannelMessagesById,
3878    response: Response<proto::GetChannelMessagesById>,
3879    session: Session,
3880) -> Result<()> {
3881    let message_ids = request
3882        .message_ids
3883        .iter()
3884        .map(|id| MessageId::from_proto(*id))
3885        .collect::<Vec<_>>();
3886    let messages = session
3887        .db()
3888        .await
3889        .get_channel_messages_by_id(session.user_id(), &message_ids)
3890        .await?;
3891    response.send(proto::GetChannelMessagesResponse {
3892        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3893        messages,
3894    })?;
3895    Ok(())
3896}
3897
3898/// Retrieve the current users notifications
3899async fn get_notifications(
3900    request: proto::GetNotifications,
3901    response: Response<proto::GetNotifications>,
3902    session: Session,
3903) -> Result<()> {
3904    let notifications = session
3905        .db()
3906        .await
3907        .get_notifications(
3908            session.user_id(),
3909            NOTIFICATION_COUNT_PER_PAGE,
3910            request.before_id.map(db::NotificationId::from_proto),
3911        )
3912        .await?;
3913    response.send(proto::GetNotificationsResponse {
3914        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3915        notifications,
3916    })?;
3917    Ok(())
3918}
3919
3920/// Mark notifications as read
3921async fn mark_notification_as_read(
3922    request: proto::MarkNotificationRead,
3923    response: Response<proto::MarkNotificationRead>,
3924    session: Session,
3925) -> Result<()> {
3926    let database = &session.db().await;
3927    let notifications = database
3928        .mark_notification_as_read_by_id(
3929            session.user_id(),
3930            NotificationId::from_proto(request.notification_id),
3931        )
3932        .await?;
3933    send_notifications(
3934        &*session.connection_pool().await,
3935        &session.peer,
3936        notifications,
3937    );
3938    response.send(proto::Ack {})?;
3939    Ok(())
3940}
3941
3942/// Get the current users information
3943async fn get_private_user_info(
3944    _request: proto::GetPrivateUserInfo,
3945    response: Response<proto::GetPrivateUserInfo>,
3946    session: Session,
3947) -> Result<()> {
3948    let db = session.db().await;
3949
3950    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3951    let user = db
3952        .get_user_by_id(session.user_id())
3953        .await?
3954        .ok_or_else(|| anyhow!("user not found"))?;
3955    let flags = db.get_user_flags(session.user_id()).await?;
3956
3957    response.send(proto::GetPrivateUserInfoResponse {
3958        metrics_id,
3959        staff: user.admin,
3960        flags,
3961        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3962    })?;
3963    Ok(())
3964}
3965
3966/// Accept the terms of service (tos) on behalf of the current user
3967async fn accept_terms_of_service(
3968    _request: proto::AcceptTermsOfService,
3969    response: Response<proto::AcceptTermsOfService>,
3970    session: Session,
3971) -> Result<()> {
3972    let db = session.db().await;
3973
3974    let accepted_tos_at = Utc::now();
3975    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3976        .await?;
3977
3978    response.send(proto::AcceptTermsOfServiceResponse {
3979        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3980    })?;
3981    Ok(())
3982}
3983
/// The minimum age (30 days) an account must have in order to use the LLM service.
///
/// NOTE(review): this constant is not referenced by the handlers visible next
/// to it; presumably the check is enforced elsewhere — confirm.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3986
3987async fn get_llm_api_token(
3988    _request: proto::GetLlmToken,
3989    response: Response<proto::GetLlmToken>,
3990    session: Session,
3991) -> Result<()> {
3992    let db = session.db().await;
3993
3994    let flags = db.get_user_flags(session.user_id()).await?;
3995    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3996
3997    if !session.is_staff() && !has_language_models_feature_flag {
3998        Err(anyhow!("permission denied"))?
3999    }
4000
4001    let user_id = session.user_id();
4002    let user = db
4003        .get_user_by_id(user_id)
4004        .await?
4005        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4006
4007    if user.accepted_tos_at.is_none() {
4008        Err(anyhow!("terms of service not accepted"))?
4009    }
4010
4011    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4012    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4013    let billing_preferences = db.get_billing_preferences(user.id).await?;
4014
4015    let token = LlmTokenClaims::create(
4016        &user,
4017        session.is_staff(),
4018        billing_preferences,
4019        &flags,
4020        has_legacy_llm_subscription,
4021        billing_subscription,
4022        session.system_id.clone(),
4023        &session.app_state.config,
4024    )?;
4025    response.send(proto::GetLlmTokenResponse { token })?;
4026    Ok(())
4027}
4028
4029fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4030    let message = match message {
4031        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4032        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4033        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4034        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4035        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4036            code: frame.code.into(),
4037            reason: frame.reason.as_str().to_owned().into(),
4038        })),
4039        // We should never receive a frame while reading the message, according
4040        // to the `tungstenite` maintainers:
4041        //
4042        // > It cannot occur when you read messages from the WebSocket, but it
4043        // > can be used when you want to send the raw frames (e.g. you want to
4044        // > send the frames to the WebSocket without composing the full message first).
4045        // >
4046        // > — https://github.com/snapview/tungstenite-rs/issues/268
4047        TungsteniteMessage::Frame(_) => {
4048            bail!("received an unexpected frame while reading the message")
4049        }
4050    };
4051
4052    Ok(message)
4053}
4054
4055fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4056    match message {
4057        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4058        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4059        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4060        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4061        AxumMessage::Close(frame) => {
4062            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4063                code: frame.code.into(),
4064                reason: frame.reason.as_ref().into(),
4065            }))
4066        }
4067    }
4068}
4069
4070fn notify_membership_updated(
4071    connection_pool: &mut ConnectionPool,
4072    result: MembershipUpdated,
4073    user_id: UserId,
4074    peer: &Peer,
4075) {
4076    for membership in &result.new_channels.channel_memberships {
4077        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4078    }
4079    for channel_id in &result.removed_channels {
4080        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4081    }
4082
4083    let user_channels_update = proto::UpdateUserChannels {
4084        channel_memberships: result
4085            .new_channels
4086            .channel_memberships
4087            .iter()
4088            .map(|cm| proto::ChannelMembership {
4089                channel_id: cm.channel_id.to_proto(),
4090                role: cm.role.into(),
4091            })
4092            .collect(),
4093        ..Default::default()
4094    };
4095
4096    let mut update = build_channels_update(result.new_channels);
4097    update.delete_channels = result
4098        .removed_channels
4099        .into_iter()
4100        .map(|id| id.to_proto())
4101        .collect();
4102    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4103
4104    for connection_id in connection_pool.user_connection_ids(user_id) {
4105        peer.send(connection_id, user_channels_update.clone())
4106            .trace_err();
4107        peer.send(connection_id, update.clone()).trace_err();
4108    }
4109}
4110
4111fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4112    proto::UpdateUserChannels {
4113        channel_memberships: channels
4114            .channel_memberships
4115            .iter()
4116            .map(|m| proto::ChannelMembership {
4117                channel_id: m.channel_id.to_proto(),
4118                role: m.role.into(),
4119            })
4120            .collect(),
4121        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4122        observed_channel_message_id: channels.observed_channel_messages.clone(),
4123    }
4124}
4125
4126fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4127    let mut update = proto::UpdateChannels::default();
4128
4129    for channel in channels.channels {
4130        update.channels.push(channel.to_proto());
4131    }
4132
4133    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4134    update.latest_channel_message_ids = channels.latest_channel_messages;
4135
4136    for (channel_id, participants) in channels.channel_participants {
4137        update
4138            .channel_participants
4139            .push(proto::ChannelParticipants {
4140                channel_id: channel_id.to_proto(),
4141                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4142            });
4143    }
4144
4145    for channel in channels.invited_channels {
4146        update.channel_invitations.push(channel.to_proto());
4147    }
4148
4149    update
4150}
4151
4152fn build_initial_contacts_update(
4153    contacts: Vec<db::Contact>,
4154    pool: &ConnectionPool,
4155) -> proto::UpdateContacts {
4156    let mut update = proto::UpdateContacts::default();
4157
4158    for contact in contacts {
4159        match contact {
4160            db::Contact::Accepted { user_id, busy } => {
4161                update.contacts.push(contact_for_user(user_id, busy, pool));
4162            }
4163            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4164            db::Contact::Incoming { user_id } => {
4165                update
4166                    .incoming_requests
4167                    .push(proto::IncomingContactRequest {
4168                        requester_id: user_id.to_proto(),
4169                    })
4170            }
4171        }
4172    }
4173
4174    update
4175}
4176
4177fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4178    proto::Contact {
4179        user_id: user_id.to_proto(),
4180        online: pool.is_user_online(user_id),
4181        busy,
4182    }
4183}
4184
4185fn room_updated(room: &proto::Room, peer: &Peer) {
4186    broadcast(
4187        None,
4188        room.participants
4189            .iter()
4190            .filter_map(|participant| Some(participant.peer_id?.into())),
4191        |peer_id| {
4192            peer.send(
4193                peer_id,
4194                proto::RoomUpdated {
4195                    room: Some(room.clone()),
4196                },
4197            )
4198        },
4199    );
4200}
4201
4202fn channel_updated(
4203    channel: &db::channel::Model,
4204    room: &proto::Room,
4205    peer: &Peer,
4206    pool: &ConnectionPool,
4207) {
4208    let participants = room
4209        .participants
4210        .iter()
4211        .map(|p| p.user_id)
4212        .collect::<Vec<_>>();
4213
4214    broadcast(
4215        None,
4216        pool.channel_connection_ids(channel.root_id())
4217            .filter_map(|(channel_id, role)| {
4218                role.can_see_channel(channel.visibility)
4219                    .then_some(channel_id)
4220            }),
4221        |peer_id| {
4222            peer.send(
4223                peer_id,
4224                proto::UpdateChannels {
4225                    channel_participants: vec![proto::ChannelParticipants {
4226                        channel_id: channel.id.to_proto(),
4227                        participant_user_ids: participants.clone(),
4228                    }],
4229                    ..Default::default()
4230                },
4231            )
4232        },
4233    );
4234}
4235
4236async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4237    let db = session.db().await;
4238
4239    let contacts = db.get_contacts(user_id).await?;
4240    let busy = db.is_user_busy(user_id).await?;
4241
4242    let pool = session.connection_pool().await;
4243    let updated_contact = contact_for_user(user_id, busy, &pool);
4244    for contact in contacts {
4245        if let db::Contact::Accepted {
4246            user_id: contact_user_id,
4247            ..
4248        } = contact
4249        {
4250            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4251                session
4252                    .peer
4253                    .send(
4254                        contact_conn_id,
4255                        proto::UpdateContacts {
4256                            contacts: vec![updated_contact.clone()],
4257                            remove_contacts: Default::default(),
4258                            incoming_requests: Default::default(),
4259                            remove_incoming_requests: Default::default(),
4260                            outgoing_requests: Default::default(),
4261                            remove_outgoing_requests: Default::default(),
4262                        },
4263                    )
4264                    .trace_err();
4265            }
4266        }
4267    }
4268    Ok(())
4269}
4270
/// Makes `connection_id` leave its room in the database and notifies everyone affected.
///
/// If `leave_room` returns `None`, this returns `Ok(())` without doing anything else.
/// Otherwise, in order:
/// 1. calls `project_left` for every project in `left_projects`;
/// 2. broadcasts the updated room via `room_updated`;
/// 3. if the room belongs to a channel, broadcasts via `channel_updated`;
/// 4. sends `CallCanceled` to every connection of each user whose call was canceled;
/// 5. calls `update_user_contacts` for the session's user and for each canceled user;
/// 6. if a LiveKit client is configured, removes the session's user from the LiveKit
///    room and deletes that room when `left_room.deleted` is set.
///
/// # Errors
///
/// Fails if `leave_room` or any `update_user_contacts` call fails. Errors from
/// sending messages and from LiveKit calls are passed to `trace_err` instead of
/// being returned.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so the values outlive
    // `left_room`, which only exists inside that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move the needed fields out of `left_room`, leaving defaults behind.
        // NOTE(review): presumably this lets `left_room` be dropped at the end of this
        // block rather than being held across the awaits further down — confirm.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is taken before `room`, so the `room` value
        // broadcast below carries a default `livekit_room` — confirm this is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // `leave_room` reported nothing to leave.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` runs; that function
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // This `connection_id` shadows the function parameter; it is one of the
            // canceled user's connections.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): the `?` returns on the first failure, which skips the remaining
    // contacts and the LiveKit cleanup below — confirm this is acceptable.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4344
4345async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4346    let left_channel_buffers = session
4347        .db()
4348        .await
4349        .leave_channel_buffers(session.connection_id)
4350        .await?;
4351
4352    for left_buffer in left_channel_buffers {
4353        channel_buffer_updated(
4354            session.connection_id,
4355            left_buffer.connections,
4356            &proto::UpdateChannelBufferCollaborators {
4357                channel_id: left_buffer.channel_id.to_proto(),
4358                collaborators: left_buffer.collaborators,
4359            },
4360            &session.peer,
4361        );
4362    }
4363
4364    Ok(())
4365}
4366
4367fn project_left(project: &db::LeftProject, session: &Session) {
4368    for connection_id in &project.connection_ids {
4369        if project.should_unshare {
4370            session
4371                .peer
4372                .send(
4373                    *connection_id,
4374                    proto::UnshareProject {
4375                        project_id: project.id.to_proto(),
4376                    },
4377                )
4378                .trace_err();
4379        } else {
4380            session
4381                .peer
4382                .send(
4383                    *connection_id,
4384                    proto::RemoveProjectCollaborator {
4385                        project_id: project.id.to_proto(),
4386                        peer_id: Some(session.connection_id.into()),
4387                    },
4388                )
4389                .trace_err();
4390        }
4391    }
4392}
4393
/// Extension trait for turning a fallible value into an `Option`.
///
/// The implementation for `Result` in this file logs the error with
/// `tracing::error!` and yields `None` in the failure case.
pub trait ResultExt {
    /// The success type carried by the implementing value.
    type Ok;

    /// Consumes `self`, returning the success value as `Some` or `None` otherwise.
    fn trace_err(self) -> Option<Self::Ok>;
}
4399
4400impl<T, E> ResultExt for Result<T, E>
4401where
4402    E: std::fmt::Debug,
4403{
4404    type Ok = T;
4405
4406    #[track_caller]
4407    fn trace_err(self) -> Option<T> {
4408        match self {
4409            Ok(value) => Some(value),
4410            Err(error) => {
4411                tracing::error!("{:?}", error);
4412                None
4413            }
4414        }
4415    }
4416}