rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   6use crate::{
   7    AppState, Error, Result, auth,
   8    db::{
   9        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  10        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  11        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  12        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15};
  16use anyhow::{Context as _, anyhow, bail};
  17use async_tungstenite::tungstenite::{
  18    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  19};
  20use axum::{
  21    Extension, Router, TypedHeader,
  22    body::Body,
  23    extract::{
  24        ConnectInfo, WebSocketUpgrade,
  25        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use reqwest_client::ReqwestClient;
  38use rpc::proto::split_repository_update;
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{MutexGuard, Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
// NOTE(review): not referenced in the visible part of this file; judging by the name it bounds
// how long a dropped connection is given to reconnect — confirm at its use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before looking up and clearing stale
// rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in the visible part of this file;
// their meaning (page sizes for channel messages / notifications, max message length) is
// inferred from the names — verify at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler for one kind of incoming message.
///
/// Built by `Server::add_handler`: it receives the boxed envelope plus the caller's `Session`
/// and returns a boxed future that runs the typed handler to completion.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone, Debug)]
 105pub enum Principal {
 106    User(User),
 107    Impersonated { user: User, admin: User },
 108}
 109
 110impl Principal {
 111    fn update_span(&self, span: &tracing::Span) {
 112        match &self {
 113            Principal::User(user) => {
 114                span.record("user_id", user.id.0);
 115                span.record("login", &user.github_login);
 116            }
 117            Principal::Impersonated { user, admin } => {
 118                span.record("user_id", user.id.0);
 119                span.record("login", &user.github_login);
 120                span.record("impersonator", &admin.github_login);
 121            }
 122        }
 123    }
 124}
 125
/// State handed, by value, to every message handler (see `MessageHandler`).
///
/// Cloning is cheap for the shared parts: the database handle, peer, connection pool, app
/// state and Supermaven client are all held behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// The user this session acts as, possibly impersonated by an admin.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool (presumably the same one held by `Server` — confirm where
    /// sessions are constructed); lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in the visible code, hence the `allow(unused)`;
    // presumably kept for later use — confirm before removing.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // NOTE(review): likely populated from the `SystemIdHeader` imported above — confirm.
    system_id: Option<String>,
    // Held but never read directly in the visible code (leading underscore).
    _executor: Executor,
}
 141
 142impl Session {
 143    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 144        #[cfg(test)]
 145        tokio::task::yield_now().await;
 146        let guard = self.db.lock().await;
 147        #[cfg(test)]
 148        tokio::task::yield_now().await;
 149        guard
 150    }
 151
 152    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.connection_pool.lock();
 156        ConnectionPoolGuard {
 157            guard,
 158            _not_send: PhantomData,
 159        }
 160    }
 161
 162    fn is_staff(&self) -> bool {
 163        match &self.principal {
 164            Principal::User(user) => user.admin,
 165            Principal::Impersonated { .. } => true,
 166        }
 167    }
 168
 169    pub async fn has_llm_subscription(
 170        &self,
 171        db: &MutexGuard<'_, DbHandle>,
 172    ) -> anyhow::Result<bool> {
 173        if self.is_staff() {
 174            return Ok(true);
 175        }
 176
 177        let user_id = self.user_id();
 178
 179        Ok(db.has_active_billing_subscription(user_id).await?)
 180    }
 181
 182    pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 183        if self.is_staff() {
 184            return Ok(proto::Plan::ZedPro);
 185        }
 186
 187        let user_id = self.user_id();
 188
 189        let subscription = db.get_active_billing_subscription(user_id).await?;
 190        let subscription_kind = subscription.and_then(|subscription| subscription.kind);
 191
 192        let plan = if let Some(subscription_kind) = subscription_kind {
 193            match subscription_kind {
 194                SubscriptionKind::ZedPro => proto::Plan::ZedPro,
 195                SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
 196                SubscriptionKind::ZedFree => proto::Plan::Free,
 197            }
 198        } else {
 199            proto::Plan::Free
 200        };
 201
 202        Ok(plan)
 203    }
 204
 205    fn user_id(&self) -> UserId {
 206        match &self.principal {
 207            Principal::User(user) => user.id,
 208            Principal::Impersonated { user, .. } => user.id,
 209        }
 210    }
 211
 212    pub fn email(&self) -> Option<String> {
 213        match &self.principal {
 214            Principal::User(user) => user.email_address.clone(),
 215            Principal::Impersonated { user, .. } => user.email_address.clone(),
 216        }
 217    }
 218}
 219
 220impl Debug for Session {
 221    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 222        let mut result = f.debug_struct("Session");
 223        match &self.principal {
 224            Principal::User(user) => {
 225                result.field("user", &user.github_login);
 226            }
 227            Principal::Impersonated { user, admin } => {
 228                result.field("user", &user.github_login);
 229                result.field("impersonator", &admin.github_login);
 230            }
 231        }
 232        result.field("connection_id", &self.connection_id).finish()
 233    }
 234}
 235
 236struct DbHandle(Arc<Database>);
 237
 238impl Deref for DbHandle {
 239    type Target = Database;
 240
 241    fn deref(&self) -> &Self::Target {
 242        self.0.as_ref()
 243    }
 244}
 245
/// RPC server that dispatches incoming messages to the handlers registered in `Server::new`.
pub struct Server {
    /// This server's id; mutable because test builds replace it in `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by the test-only `reset`).
    teardown: watch::Sender<bool>,
}
 254
/// Lock guard for the shared `ConnectionPool`.
///
/// `PhantomData<Rc<()>>` makes the guard `!Send`. Handler futures are required to be `Send`
/// (see `add_handler`), so holding this synchronous lock across an `.await` inside a handler
/// becomes a compile error instead of a potential stall.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 259
/// Serializable view of the server's peer and its connection pool.
///
/// The pool is serialized through its guard's `Deref` target via `serialize_deref`; because
/// the snapshot owns a `ConnectionPoolGuard`, the pool stays locked for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 266
 267pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 268where
 269    S: Serializer,
 270    T: Deref<Target = U>,
 271    U: Serialize,
 272{
 273    Serialize::serialize(value.deref(), serializer)
 274}
 275
 276impl Server {
 277    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 278        let mut server = Self {
 279            id: parking_lot::Mutex::new(id),
 280            peer: Peer::new(id.0 as u32),
 281            app_state: app_state.clone(),
 282            connection_pool: Default::default(),
 283            handlers: Default::default(),
 284            teardown: watch::channel(false).0,
 285        };
 286
 287        server
 288            .add_request_handler(ping)
 289            .add_request_handler(create_room)
 290            .add_request_handler(join_room)
 291            .add_request_handler(rejoin_room)
 292            .add_request_handler(leave_room)
 293            .add_request_handler(set_room_participant_role)
 294            .add_request_handler(call)
 295            .add_request_handler(cancel_call)
 296            .add_message_handler(decline_call)
 297            .add_request_handler(update_participant_location)
 298            .add_request_handler(share_project)
 299            .add_message_handler(unshare_project)
 300            .add_request_handler(join_project)
 301            .add_message_handler(leave_project)
 302            .add_request_handler(update_project)
 303            .add_request_handler(update_worktree)
 304            .add_request_handler(update_repository)
 305            .add_request_handler(remove_repository)
 306            .add_message_handler(start_language_server)
 307            .add_message_handler(update_language_server)
 308            .add_message_handler(update_diagnostic_summary)
 309            .add_message_handler(update_worktree_settings)
 310            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 313            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 314            .add_request_handler(forward_find_search_candidates_request)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 318            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 319            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 320            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 321            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 322            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 323            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 324            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 325            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 326            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 327            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 328            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 329            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 330            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 331            .add_request_handler(
 332                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 333            )
 334            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 335            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 336            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 337            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 338            .add_request_handler(
 339                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 340            )
 341            .add_request_handler(
 342                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 343            )
 344            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 345            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 346            .add_request_handler(
 347                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 348            )
 349            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 350            .add_request_handler(
 351                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 352            )
 353            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 354            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 355            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 356            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 357            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 358            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 359            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 360            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 361            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 362            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 363            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 365            .add_request_handler(
 366                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 367            )
 368            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 369            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 370            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 371            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 372            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 373            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 374            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 375            .add_message_handler(create_buffer_for_peer)
 376            .add_request_handler(update_buffer)
 377            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 378            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 379            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 380            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 381            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 382            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 383            .add_request_handler(get_users)
 384            .add_request_handler(fuzzy_search_users)
 385            .add_request_handler(request_contact)
 386            .add_request_handler(remove_contact)
 387            .add_request_handler(respond_to_contact_request)
 388            .add_message_handler(subscribe_to_channels)
 389            .add_request_handler(create_channel)
 390            .add_request_handler(delete_channel)
 391            .add_request_handler(invite_channel_member)
 392            .add_request_handler(remove_channel_member)
 393            .add_request_handler(set_channel_member_role)
 394            .add_request_handler(set_channel_visibility)
 395            .add_request_handler(rename_channel)
 396            .add_request_handler(join_channel_buffer)
 397            .add_request_handler(leave_channel_buffer)
 398            .add_message_handler(update_channel_buffer)
 399            .add_request_handler(rejoin_channel_buffers)
 400            .add_request_handler(get_channel_members)
 401            .add_request_handler(respond_to_channel_invite)
 402            .add_request_handler(join_channel)
 403            .add_request_handler(join_channel_chat)
 404            .add_message_handler(leave_channel_chat)
 405            .add_request_handler(send_channel_message)
 406            .add_request_handler(remove_channel_message)
 407            .add_request_handler(update_channel_message)
 408            .add_request_handler(get_channel_messages)
 409            .add_request_handler(get_channel_messages_by_id)
 410            .add_request_handler(get_notifications)
 411            .add_request_handler(mark_notification_as_read)
 412            .add_request_handler(move_channel)
 413            .add_request_handler(follow)
 414            .add_message_handler(unfollow)
 415            .add_message_handler(update_followers)
 416            .add_request_handler(get_private_user_info)
 417            .add_request_handler(get_llm_api_token)
 418            .add_request_handler(accept_terms_of_service)
 419            .add_message_handler(acknowledge_channel_message)
 420            .add_message_handler(acknowledge_buffer_version)
 421            .add_request_handler(get_supermaven_api_key)
 422            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 423            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 424            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 425            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 426            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 427            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 428            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 429            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 430            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 431            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 432            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 433            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 434            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 435            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 436            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 437            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 438            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 439            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 440            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 441            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 442            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 443            .add_message_handler(update_context);
 444
 445        Arc::new(server)
 446    }
 447
    /// Schedules cleanup of resources left behind by stale servers, then returns `Ok(())`
    /// immediately.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`, then:
    /// 1. fetches stale room and channel-buffer ids for this environment and server id;
    /// 2. clears stale channel-buffer collaborators and sends the refreshed list to the
    ///    remaining connections;
    /// 3. for each stale room, clears stale participants, broadcasts the room/channel update,
    ///    sends `CallCanceled` for canceled calls, refreshes affected users' contact entries,
    ///    and deletes the LiveKit room once it has no participants;
    /// 4. deletes the stale server records.
    ///
    /// Errors at each step go through `trace_err()`, which turns them into `Option`s
    /// (presumably logging them), so a failed step is skipped rather than aborting the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here but only awaited inside the detached task below.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        // Fire-and-forget: nothing below runs before `start` returns.
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and push the refreshed
                    // collaborator list to every connection still on that buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used if clearing the room fails: nobody is notified and
                        // the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled
                            // need their contact entries refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Send `CallCanceled` to every connection of each user whose call into
                        // this room was canceled; the pool lock is scoped to this block.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to the connections of
                        // all of their accepted contacts. The pool is locked only after both
                        // database calls have completed, so it is never held across an await.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a LiveKit client is configured and
                        // the refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 593
 594    pub fn teardown(&self) {
 595        self.peer.teardown();
 596        self.connection_pool.lock().reset();
 597        let _ = self.teardown.send(true);
 598    }
 599
 600    #[cfg(test)]
 601    pub fn reset(&self, id: ServerId) {
 602        self.teardown();
 603        *self.id.lock() = id;
 604        self.peer.reset(id.0 as u32);
 605        let _ = self.teardown.send(false);
 606    }
 607
 608    #[cfg(test)]
 609    pub fn id(&self) -> ServerId {
 610        *self.id.lock()
 611    }
 612
 613    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 614    where
 615        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 616        Fut: 'static + Send + Future<Output = Result<()>>,
 617        M: EnvelopedMessage,
 618    {
 619        let prev_handler = self.handlers.insert(
 620            TypeId::of::<M>(),
 621            Box::new(move |envelope, session| {
 622                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 623                let received_at = envelope.received_at;
 624                tracing::info!("message received");
 625                let start_time = Instant::now();
 626                let future = (handler)(*envelope, session);
 627                async move {
 628                    let result = future.await;
 629                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 630                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 631                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 632                    let payload_type = M::NAME;
 633
 634                    match result {
 635                        Err(error) => {
 636                            tracing::error!(
 637                                ?error,
 638                                total_duration_ms,
 639                                processing_duration_ms,
 640                                queue_duration_ms,
 641                                payload_type,
 642                                "error handling message"
 643                            )
 644                        }
 645                        Ok(()) => tracing::info!(
 646                            total_duration_ms,
 647                            processing_duration_ms,
 648                            queue_duration_ms,
 649                            "finished handling message"
 650                        ),
 651                    }
 652                }
 653                .boxed()
 654            }),
 655        );
 656        if prev_handler.is_some() {
 657            panic!("registered a handler for the same message twice");
 658        }
 659        self
 660    }
 661
 662    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 663    where
 664        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 665        Fut: 'static + Send + Future<Output = Result<()>>,
 666        M: EnvelopedMessage,
 667    {
 668        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 669        self
 670    }
 671
 672    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 673    where
 674        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 675        Fut: Send + Future<Output = Result<()>>,
 676        M: RequestMessage,
 677    {
 678        let handler = Arc::new(handler);
 679        self.add_handler(move |envelope, session| {
 680            let receipt = envelope.receipt();
 681            let handler = handler.clone();
 682            async move {
 683                let peer = session.peer.clone();
 684                let responded = Arc::new(AtomicBool::default());
 685                let response = Response {
 686                    peer: peer.clone(),
 687                    responded: responded.clone(),
 688                    receipt,
 689                };
 690                match (handler)(envelope.payload, response, session).await {
 691                    Ok(()) => {
 692                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 693                            Ok(())
 694                        } else {
 695                            Err(anyhow!("handler did not send a response"))?
 696                        }
 697                    }
 698                    Err(error) => {
 699                        let proto_err = match &error {
 700                            Error::Internal(err) => err.to_proto(),
 701                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 702                        };
 703                        peer.respond_with_error(receipt, proto_err)?;
 704                        Err(error)
 705                    }
 706                }
 707            }
 708        })
 709    }
 710
    /// Builds the future that drives a single client connection for its whole lifetime.
    ///
    /// The returned future is instrumented with a "handle connection" span that
    /// records the address, principal fields, connection id, and (when known)
    /// the GeoIP country code. When polled, it:
    /// 1. bails out if the server's teardown flag is already set,
    /// 2. registers `connection` with the peer and, if an HTTP client can be
    ///    built, creates the `Session` (with a Supermaven admin client only when
    ///    `supermaven_admin_api_key` is configured),
    /// 3. sends the initial client update (hello, contacts, ...),
    /// 4. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or teardown is signalled,
    /// 5. runs `connection_lost` — except on the teardown path, which returns early.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` here are filled in later via `record`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before entering the async block so a teardown that happens
        // before the first poll is still observed via `borrow()` below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only build a Supermaven admin client when an API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) for
            // this connection at 256: a permit is taken before reading each message
            // and released only when that message's handler future completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order listed: teardown first, then I/O
                // completion, then progress on existing handlers, then new messages.
                // NOTE: the teardown branch returns without running `connection_lost`.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Exit the span before moving it into `instrument` below.
                                drop(span_enter);

                                // Hold the semaphore permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 858
 859    async fn send_initial_client_update(
 860        &self,
 861        connection_id: ConnectionId,
 862        principal: &Principal,
 863        zed_version: ZedVersion,
 864        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 865        session: &Session,
 866    ) -> Result<()> {
 867        self.peer.send(
 868            connection_id,
 869            proto::Hello {
 870                peer_id: Some(connection_id.into()),
 871            },
 872        )?;
 873        tracing::info!("sent hello message");
 874        if let Some(send_connection_id) = send_connection_id.take() {
 875            let _ = send_connection_id.send(connection_id);
 876        }
 877
 878        match principal {
 879            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 880                if !user.connected_once {
 881                    self.peer.send(connection_id, proto::ShowContacts {})?;
 882                    self.app_state
 883                        .db
 884                        .set_user_connected_once(user.id, true)
 885                        .await?;
 886                }
 887
 888                update_user_plan(user.id, session).await?;
 889
 890                let contacts = self.app_state.db.get_contacts(user.id).await?;
 891
 892                {
 893                    let mut pool = self.connection_pool.lock();
 894                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 895                    self.peer.send(
 896                        connection_id,
 897                        build_initial_contacts_update(contacts, &pool),
 898                    )?;
 899                }
 900
 901                if should_auto_subscribe_to_channels(zed_version) {
 902                    subscribe_user_to_channels(user.id, session).await?;
 903                }
 904
 905                if let Some(incoming_call) =
 906                    self.app_state.db.incoming_call_for_user(user.id).await?
 907                {
 908                    self.peer.send(connection_id, incoming_call)?;
 909                }
 910
 911                update_user_contacts(user.id, session).await?;
 912            }
 913        }
 914
 915        Ok(())
 916    }
 917
 918    pub async fn invite_code_redeemed(
 919        self: &Arc<Self>,
 920        inviter_id: UserId,
 921        invitee_id: UserId,
 922    ) -> Result<()> {
 923        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 924            if let Some(code) = &user.invite_code {
 925                let pool = self.connection_pool.lock();
 926                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 927                for connection_id in pool.user_connection_ids(inviter_id) {
 928                    self.peer.send(
 929                        connection_id,
 930                        proto::UpdateContacts {
 931                            contacts: vec![invitee_contact.clone()],
 932                            ..Default::default()
 933                        },
 934                    )?;
 935                    self.peer.send(
 936                        connection_id,
 937                        proto::UpdateInviteInfo {
 938                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 939                            count: user.invite_count as u32,
 940                        },
 941                    )?;
 942                }
 943            }
 944        }
 945        Ok(())
 946    }
 947
 948    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 949        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 950            if let Some(invite_code) = &user.invite_code {
 951                let pool = self.connection_pool.lock();
 952                for connection_id in pool.user_connection_ids(user_id) {
 953                    self.peer.send(
 954                        connection_id,
 955                        proto::UpdateInviteInfo {
 956                            url: format!(
 957                                "{}{}",
 958                                self.app_state.config.invite_link_prefix, invite_code
 959                            ),
 960                            count: user.invite_count as u32,
 961                        },
 962                    )?;
 963                }
 964            }
 965        }
 966        Ok(())
 967    }
 968
 969    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 970        let pool = self.connection_pool.lock();
 971        for connection_id in pool.user_connection_ids(user_id) {
 972            self.peer
 973                .send(connection_id, proto::RefreshLlmToken {})
 974                .trace_err();
 975        }
 976    }
 977
 978    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 979        ServerSnapshot {
 980            connection_pool: ConnectionPoolGuard {
 981                guard: self.connection_pool.lock(),
 982                _not_send: PhantomData,
 983            },
 984            peer: &self.peer,
 985        }
 986    }
 987}
 988
 989impl Deref for ConnectionPoolGuard<'_> {
 990    type Target = ConnectionPool;
 991
 992    fn deref(&self) -> &Self::Target {
 993        &self.guard
 994    }
 995}
 996
 997impl DerefMut for ConnectionPoolGuard<'_> {
 998    fn deref_mut(&mut self) -> &mut Self::Target {
 999        &mut self.guard
1000    }
1001}
1002
/// In test builds, checks the connection pool's invariants whenever a guard is
/// dropped. `drop` runs before the `guard` field is released, so the check
/// still sees the pool under the lock.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The call is only compiled into test builds; otherwise this is a no-op.
        #[cfg(test)]
        self.check_invariants();
    }
}
1009
1010fn broadcast<F>(
1011    sender_id: Option<ConnectionId>,
1012    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1013    mut f: F,
1014) where
1015    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1016{
1017    for receiver_id in receiver_ids {
1018        if Some(receiver_id) != sender_id {
1019            if let Err(error) = f(receiver_id) {
1020                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1021            }
1022        }
1023    }
1024}
1025
1026pub struct ProtocolVersion(u32);
1027
1028impl Header for ProtocolVersion {
1029    fn name() -> &'static HeaderName {
1030        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1032    }
1033
1034    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035    where
1036        Self: Sized,
1037        I: Iterator<Item = &'i axum::http::HeaderValue>,
1038    {
1039        let version = values
1040            .next()
1041            .ok_or_else(axum::headers::Error::invalid)?
1042            .to_str()
1043            .map_err(|_| axum::headers::Error::invalid())?
1044            .parse()
1045            .map_err(|_| axum::headers::Error::invalid())?;
1046        Ok(Self(version))
1047    }
1048
1049    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050        values.extend([self.0.to_string().parse().unwrap()]);
1051    }
1052}
1053
1054pub struct AppVersionHeader(SemanticVersion);
1055impl Header for AppVersionHeader {
1056    fn name() -> &'static HeaderName {
1057        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1058        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1059    }
1060
1061    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1062    where
1063        Self: Sized,
1064        I: Iterator<Item = &'i axum::http::HeaderValue>,
1065    {
1066        let version = values
1067            .next()
1068            .ok_or_else(axum::headers::Error::invalid)?
1069            .to_str()
1070            .map_err(|_| axum::headers::Error::invalid())?
1071            .parse()
1072            .map_err(|_| axum::headers::Error::invalid())?;
1073        Ok(Self(version))
1074    }
1075
1076    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1077        values.extend([self.0.to_string().parse().unwrap()]);
1078    }
1079}
1080
1081pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1082    Router::new()
1083        .route("/rpc", get(handle_websocket_request))
1084        .layer(
1085            ServiceBuilder::new()
1086                .layer(Extension(server.app_state.clone()))
1087                .layer(middleware::from_fn(auth::validate_header)),
1088        )
1089        .route("/metrics", get(handle_metrics))
1090        .layer(Extension(server))
1091}
1092
1093pub async fn handle_websocket_request(
1094    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1095    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1096    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1097    Extension(server): Extension<Arc<Server>>,
1098    Extension(principal): Extension<Principal>,
1099    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1100    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1101    ws: WebSocketUpgrade,
1102) -> axum::response::Response {
1103    if protocol_version != rpc::PROTOCOL_VERSION {
1104        return (
1105            StatusCode::UPGRADE_REQUIRED,
1106            "client must be upgraded".to_string(),
1107        )
1108            .into_response();
1109    }
1110
1111    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1112        return (
1113            StatusCode::UPGRADE_REQUIRED,
1114            "no version header found".to_string(),
1115        )
1116            .into_response();
1117    };
1118
1119    if !version.can_collaborate() {
1120        return (
1121            StatusCode::UPGRADE_REQUIRED,
1122            "client must be upgraded".to_string(),
1123        )
1124            .into_response();
1125    }
1126
1127    let socket_address = socket_address.to_string();
1128    ws.on_upgrade(move |socket| {
1129        let socket = socket
1130            .map_ok(to_tungstenite_message)
1131            .err_into()
1132            .with(|message| async move { to_axum_message(message) });
1133        let connection = Connection::new(Box::pin(socket));
1134        async move {
1135            server
1136                .handle_connection(
1137                    connection,
1138                    socket_address,
1139                    principal,
1140                    version,
1141                    country_code_header.map(|header| header.to_string()),
1142                    system_id_header.map(|header| header.to_string()),
1143                    None,
1144                    Executor::Production,
1145                )
1146                .await;
1147        }
1148    })
1149}
1150
1151pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1152    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1153    let connections_metric = CONNECTIONS_METRIC
1154        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1155
1156    let connections = server
1157        .connection_pool
1158        .lock()
1159        .connections()
1160        .filter(|connection| !connection.admin)
1161        .count();
1162    connections_metric.set(connections as _);
1163
1164    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1165    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1166        register_int_gauge!(
1167            "shared_projects",
1168            "number of open projects with one or more guests"
1169        )
1170        .unwrap()
1171    });
1172
1173    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1174    shared_projects_metric.set(shared_projects as _);
1175
1176    let encoder = prometheus::TextEncoder::new();
1177    let metric_families = prometheus::gather();
1178    let encoded_metrics = encoder
1179        .encode_to_string(&metric_families)
1180        .map_err(|err| anyhow!("{}", err))?;
1181    Ok(encoded_metrics)
1182}
1183
1184#[instrument(err, skip(executor))]
1185async fn connection_lost(
1186    session: Session,
1187    mut teardown: watch::Receiver<bool>,
1188    executor: Executor,
1189) -> Result<()> {
1190    session.peer.disconnect(session.connection_id);
1191    session
1192        .connection_pool()
1193        .await
1194        .remove_connection(session.connection_id)?;
1195
1196    session
1197        .db()
1198        .await
1199        .connection_lost(session.connection_id)
1200        .await
1201        .trace_err();
1202
1203    futures::select_biased! {
1204        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1205
1206            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1207            leave_room_for_session(&session, session.connection_id).await.trace_err();
1208            leave_channel_buffers_for_session(&session)
1209                .await
1210                .trace_err();
1211
1212            if !session
1213                .connection_pool()
1214                .await
1215                .is_user_online(session.user_id())
1216            {
1217                let db = session.db().await;
1218                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1219                    room_updated(&room, &session.peer);
1220                }
1221            }
1222
1223            update_user_contacts(session.user_id(), &session).await?;
1224        },
1225        _ = teardown.changed().fuse() => {}
1226    }
1227
1228    Ok(())
1229}
1230
1231/// Acknowledges a ping from a client, used to keep the connection alive.
1232async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1233    response.send(proto::Ack {})?;
1234    Ok(())
1235}
1236
1237/// Creates a new room for calling (outside of channels)
1238async fn create_room(
1239    _request: proto::CreateRoom,
1240    response: Response<proto::CreateRoom>,
1241    session: Session,
1242) -> Result<()> {
1243    let livekit_room = nanoid::nanoid!(30);
1244
1245    let live_kit_connection_info = util::maybe!(async {
1246        let live_kit = session.app_state.livekit_client.as_ref();
1247        let live_kit = live_kit?;
1248        let user_id = session.user_id().to_string();
1249
1250        let token = live_kit
1251            .room_token(&livekit_room, &user_id.to_string())
1252            .trace_err()?;
1253
1254        Some(proto::LiveKitConnectionInfo {
1255            server_url: live_kit.url().into(),
1256            token,
1257            can_publish: true,
1258        })
1259    })
1260    .await;
1261
1262    let room = session
1263        .db()
1264        .await
1265        .create_room(session.user_id(), session.connection_id, &livekit_room)
1266        .await?;
1267
1268    response.send(proto::CreateRoomResponse {
1269        room: Some(room.clone()),
1270        live_kit_connection_info,
1271    })?;
1272
1273    update_user_contacts(session.user_id(), &session).await?;
1274    Ok(())
1275}
1276
1277/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1278async fn join_room(
1279    request: proto::JoinRoom,
1280    response: Response<proto::JoinRoom>,
1281    session: Session,
1282) -> Result<()> {
1283    let room_id = RoomId::from_proto(request.id);
1284
1285    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1286
1287    if let Some(channel_id) = channel_id {
1288        return join_channel_internal(channel_id, Box::new(response), session).await;
1289    }
1290
1291    let joined_room = {
1292        let room = session
1293            .db()
1294            .await
1295            .join_room(room_id, session.user_id(), session.connection_id)
1296            .await?;
1297        room_updated(&room.room, &session.peer);
1298        room.into_inner()
1299    };
1300
1301    for connection_id in session
1302        .connection_pool()
1303        .await
1304        .user_connection_ids(session.user_id())
1305    {
1306        session
1307            .peer
1308            .send(
1309                connection_id,
1310                proto::CallCanceled {
1311                    room_id: room_id.to_proto(),
1312                },
1313            )
1314            .trace_err();
1315    }
1316
1317    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1318    {
1319        live_kit
1320            .room_token(
1321                &joined_room.room.livekit_room,
1322                &session.user_id().to_string(),
1323            )
1324            .trace_err()
1325            .map(|token| proto::LiveKitConnectionInfo {
1326                server_url: live_kit.url().into(),
1327                token,
1328                can_publish: true,
1329            })
1330    } else {
1331        None
1332    };
1333
1334    response.send(proto::JoinRoomResponse {
1335        room: Some(joined_room.room),
1336        channel_id: None,
1337        live_kit_connection_info,
1338    })?;
1339
1340    update_user_contacts(session.user_id(), &session).await?;
1341    Ok(())
1342}
1343
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoin returns the room plus two groups of projects:
/// - reshared projects: the caller's hosted projects,
/// - rejoined projects: projects the caller was collaborating on.
///
/// After replying, the function:
/// - calls `room_updated`,
/// - tells each reshared project's collaborators that the caller's peer id
///   changed from `old_connection_id` to the new connection,
/// - forwards the project's worktrees to the other collaborators,
/// - lets `notify_rejoined_projects` handle the rejoined projects,
/// - if the room belongs to a channel, calls `channel_updated`,
/// - updates the caller's contacts.
///
/// # Errors
///
/// Returns an error if the rejoin, the response send, `notify_rejoined_projects`,
/// or the contacts update fails. Per-collaborator sends are only traced.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the value returned by `rejoin_room` is dropped (via
    // `into_inner`) before the connection pool is locked below — presumably it
    // holds a database guard; confirm against `db::rejoin_room`.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Every collaborator learns the host's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees to everyone but the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1435
1436fn notify_rejoined_projects(
1437    rejoined_projects: &mut Vec<RejoinedProject>,
1438    session: &Session,
1439) -> Result<()> {
1440    for project in rejoined_projects.iter() {
1441        for collaborator in &project.collaborators {
1442            session
1443                .peer
1444                .send(
1445                    collaborator.connection_id,
1446                    proto::UpdateProjectCollaborator {
1447                        project_id: project.id.to_proto(),
1448                        old_peer_id: Some(project.old_connection_id.into()),
1449                        new_peer_id: Some(session.connection_id.into()),
1450                    },
1451                )
1452                .trace_err();
1453        }
1454    }
1455
1456    for project in rejoined_projects {
1457        for worktree in mem::take(&mut project.worktrees) {
1458            // Stream this worktree's entries.
1459            let message = proto::UpdateWorktree {
1460                project_id: project.id.to_proto(),
1461                worktree_id: worktree.id,
1462                abs_path: worktree.abs_path.clone(),
1463                root_name: worktree.root_name,
1464                updated_entries: worktree.updated_entries,
1465                removed_entries: worktree.removed_entries,
1466                scan_id: worktree.scan_id,
1467                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1468                updated_repositories: worktree.updated_repositories,
1469                removed_repositories: worktree.removed_repositories,
1470            };
1471            for update in proto::split_worktree_update(message) {
1472                session.peer.send(session.connection_id, update)?;
1473            }
1474
1475            // Stream this worktree's diagnostics.
1476            for summary in worktree.diagnostic_summaries {
1477                session.peer.send(
1478                    session.connection_id,
1479                    proto::UpdateDiagnosticSummary {
1480                        project_id: project.id.to_proto(),
1481                        worktree_id: worktree.id,
1482                        summary: Some(summary),
1483                    },
1484                )?;
1485            }
1486
1487            for settings_file in worktree.settings_files {
1488                session.peer.send(
1489                    session.connection_id,
1490                    proto::UpdateWorktreeSettings {
1491                        project_id: project.id.to_proto(),
1492                        worktree_id: worktree.id,
1493                        path: settings_file.path,
1494                        content: Some(settings_file.content),
1495                        kind: Some(settings_file.kind.to_proto().into()),
1496                    },
1497                )?;
1498            }
1499        }
1500
1501        for repository in mem::take(&mut project.updated_repositories) {
1502            for update in split_repository_update(repository) {
1503                session.peer.send(session.connection_id, update)?;
1504            }
1505        }
1506
1507        for id in mem::take(&mut project.removed_repositories) {
1508            session.peer.send(
1509                session.connection_id,
1510                proto::RemoveRepository {
1511                    project_id: project.id.to_proto(),
1512                    id,
1513                },
1514            )?;
1515        }
1516    }
1517
1518    Ok(())
1519}
1520
1521/// leave room disconnects from the room.
1522async fn leave_room(
1523    _: proto::LeaveRoom,
1524    response: Response<proto::LeaveRoom>,
1525    session: Session,
1526) -> Result<()> {
1527    leave_room_for_session(&session, session.connection_id).await?;
1528    response.send(proto::Ack {})?;
1529    Ok(())
1530}
1531
1532/// Updates the permissions of someone else in the room.
1533async fn set_room_participant_role(
1534    request: proto::SetRoomParticipantRole,
1535    response: Response<proto::SetRoomParticipantRole>,
1536    session: Session,
1537) -> Result<()> {
1538    let user_id = UserId::from_proto(request.user_id);
1539    let role = ChannelRole::from(request.role());
1540
1541    let (livekit_room, can_publish) = {
1542        let room = session
1543            .db()
1544            .await
1545            .set_room_participant_role(
1546                session.user_id(),
1547                RoomId::from_proto(request.room_id),
1548                user_id,
1549                role,
1550            )
1551            .await?;
1552
1553        let livekit_room = room.livekit_room.clone();
1554        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1555        room_updated(&room, &session.peer);
1556        (livekit_room, can_publish)
1557    };
1558
1559    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1560        live_kit
1561            .update_participant(
1562                livekit_room.clone(),
1563                request.user_id.to_string(),
1564                livekit_api::proto::ParticipantPermission {
1565                    can_subscribe: true,
1566                    can_publish,
1567                    can_publish_data: can_publish,
1568                    hidden: false,
1569                    recorder: false,
1570                },
1571            )
1572            .await
1573            .trace_err();
1574    }
1575
1576    response.send(proto::Ack {})?;
1577    Ok(())
1578}
1579
1580/// Call someone else into the current room
1581async fn call(
1582    request: proto::Call,
1583    response: Response<proto::Call>,
1584    session: Session,
1585) -> Result<()> {
1586    let room_id = RoomId::from_proto(request.room_id);
1587    let calling_user_id = session.user_id();
1588    let calling_connection_id = session.connection_id;
1589    let called_user_id = UserId::from_proto(request.called_user_id);
1590    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1591    if !session
1592        .db()
1593        .await
1594        .has_contact(calling_user_id, called_user_id)
1595        .await?
1596    {
1597        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1598    }
1599
1600    let incoming_call = {
1601        let (room, incoming_call) = &mut *session
1602            .db()
1603            .await
1604            .call(
1605                room_id,
1606                calling_user_id,
1607                calling_connection_id,
1608                called_user_id,
1609                initial_project_id,
1610            )
1611            .await?;
1612        room_updated(room, &session.peer);
1613        mem::take(incoming_call)
1614    };
1615    update_user_contacts(called_user_id, &session).await?;
1616
1617    let mut calls = session
1618        .connection_pool()
1619        .await
1620        .user_connection_ids(called_user_id)
1621        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1622        .collect::<FuturesUnordered<_>>();
1623
1624    while let Some(call_response) = calls.next().await {
1625        match call_response.as_ref() {
1626            Ok(_) => {
1627                response.send(proto::Ack {})?;
1628                return Ok(());
1629            }
1630            Err(_) => {
1631                call_response.trace_err();
1632            }
1633        }
1634    }
1635
1636    {
1637        let room = session
1638            .db()
1639            .await
1640            .call_failed(room_id, called_user_id)
1641            .await?;
1642        room_updated(&room, &session.peer);
1643    }
1644    update_user_contacts(called_user_id, &session).await?;
1645
1646    Err(anyhow!("failed to ring user"))?
1647}
1648
1649/// Cancel an outgoing call.
1650async fn cancel_call(
1651    request: proto::CancelCall,
1652    response: Response<proto::CancelCall>,
1653    session: Session,
1654) -> Result<()> {
1655    let called_user_id = UserId::from_proto(request.called_user_id);
1656    let room_id = RoomId::from_proto(request.room_id);
1657    {
1658        let room = session
1659            .db()
1660            .await
1661            .cancel_call(room_id, session.connection_id, called_user_id)
1662            .await?;
1663        room_updated(&room, &session.peer);
1664    }
1665
1666    for connection_id in session
1667        .connection_pool()
1668        .await
1669        .user_connection_ids(called_user_id)
1670    {
1671        session
1672            .peer
1673            .send(
1674                connection_id,
1675                proto::CallCanceled {
1676                    room_id: room_id.to_proto(),
1677                },
1678            )
1679            .trace_err();
1680    }
1681    response.send(proto::Ack {})?;
1682
1683    update_user_contacts(called_user_id, &session).await?;
1684    Ok(())
1685}
1686
1687/// Decline an incoming call.
1688async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1689    let room_id = RoomId::from_proto(message.room_id);
1690    {
1691        let room = session
1692            .db()
1693            .await
1694            .decline_call(Some(room_id), session.user_id())
1695            .await?
1696            .ok_or_else(|| anyhow!("failed to decline call"))?;
1697        room_updated(&room, &session.peer);
1698    }
1699
1700    for connection_id in session
1701        .connection_pool()
1702        .await
1703        .user_connection_ids(session.user_id())
1704    {
1705        session
1706            .peer
1707            .send(
1708                connection_id,
1709                proto::CallCanceled {
1710                    room_id: room_id.to_proto(),
1711                },
1712            )
1713            .trace_err();
1714    }
1715    update_user_contacts(session.user_id(), &session).await?;
1716    Ok(())
1717}
1718
1719/// Updates other participants in the room with your current location.
1720async fn update_participant_location(
1721    request: proto::UpdateParticipantLocation,
1722    response: Response<proto::UpdateParticipantLocation>,
1723    session: Session,
1724) -> Result<()> {
1725    let room_id = RoomId::from_proto(request.room_id);
1726    let location = request
1727        .location
1728        .ok_or_else(|| anyhow!("invalid location"))?;
1729
1730    let db = session.db().await;
1731    let room = db
1732        .update_room_participant_location(room_id, session.connection_id, location)
1733        .await?;
1734
1735    room_updated(&room, &session.peer);
1736    response.send(proto::Ack {})?;
1737    Ok(())
1738}
1739
1740/// Share a project into the room.
1741async fn share_project(
1742    request: proto::ShareProject,
1743    response: Response<proto::ShareProject>,
1744    session: Session,
1745) -> Result<()> {
1746    let (project_id, room) = &*session
1747        .db()
1748        .await
1749        .share_project(
1750            RoomId::from_proto(request.room_id),
1751            session.connection_id,
1752            &request.worktrees,
1753            request.is_ssh_project,
1754        )
1755        .await?;
1756    response.send(proto::ShareProjectResponse {
1757        project_id: project_id.to_proto(),
1758    })?;
1759    room_updated(room, &session.peer);
1760
1761    Ok(())
1762}
1763
1764/// Unshare a project from the room.
1765async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1766    let project_id = ProjectId::from_proto(message.project_id);
1767    unshare_project_internal(project_id, session.connection_id, &session).await
1768}
1769
1770async fn unshare_project_internal(
1771    project_id: ProjectId,
1772    connection_id: ConnectionId,
1773    session: &Session,
1774) -> Result<()> {
1775    let delete = {
1776        let room_guard = session
1777            .db()
1778            .await
1779            .unshare_project(project_id, connection_id)
1780            .await?;
1781
1782        let (delete, room, guest_connection_ids) = &*room_guard;
1783
1784        let message = proto::UnshareProject {
1785            project_id: project_id.to_proto(),
1786        };
1787
1788        broadcast(
1789            Some(connection_id),
1790            guest_connection_ids.iter().copied(),
1791            |conn_id| session.peer.send(conn_id, message.clone()),
1792        );
1793        if let Some(room) = room {
1794            room_updated(room, &session.peer);
1795        }
1796
1797        *delete
1798    };
1799
1800    if delete {
1801        let db = session.db().await;
1802        db.delete_project(project_id).await?;
1803    }
1804
1805    Ok(())
1806}
1807
1808/// Join someone elses shared project.
1809async fn join_project(
1810    request: proto::JoinProject,
1811    response: Response<proto::JoinProject>,
1812    session: Session,
1813) -> Result<()> {
1814    let project_id = ProjectId::from_proto(request.project_id);
1815
1816    tracing::info!(%project_id, "join project");
1817
1818    let db = session.db().await;
1819    let (project, replica_id) = &mut *db
1820        .join_project(project_id, session.connection_id, session.user_id())
1821        .await?;
1822    drop(db);
1823    tracing::info!(%project_id, "join remote project");
1824    join_project_internal(response, session, project, replica_id)
1825}
1826
/// Destination for the `JoinProjectResponse` produced by
/// `join_project_internal`. Only the `Response<proto::JoinProject>`
/// implementation is visible here; presumably the trait exists so other
/// responders can reuse `join_project_internal` (TODO confirm).
trait JoinProjectInternalResponse {
    /// Delivers `result` to the joining client, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to `Response`'s own `send`. The explicit type path makes it clear
// this is not a recursive call into the trait method above.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1835
1836fn join_project_internal(
1837    response: impl JoinProjectInternalResponse,
1838    session: Session,
1839    project: &mut Project,
1840    replica_id: &ReplicaId,
1841) -> Result<()> {
1842    let collaborators = project
1843        .collaborators
1844        .iter()
1845        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1846        .map(|collaborator| collaborator.to_proto())
1847        .collect::<Vec<_>>();
1848    let project_id = project.id;
1849    let guest_user_id = session.user_id();
1850
1851    let worktrees = project
1852        .worktrees
1853        .iter()
1854        .map(|(id, worktree)| proto::WorktreeMetadata {
1855            id: *id,
1856            root_name: worktree.root_name.clone(),
1857            visible: worktree.visible,
1858            abs_path: worktree.abs_path.clone(),
1859        })
1860        .collect::<Vec<_>>();
1861
1862    let add_project_collaborator = proto::AddProjectCollaborator {
1863        project_id: project_id.to_proto(),
1864        collaborator: Some(proto::Collaborator {
1865            peer_id: Some(session.connection_id.into()),
1866            replica_id: replica_id.0 as u32,
1867            user_id: guest_user_id.to_proto(),
1868            is_host: false,
1869        }),
1870    };
1871
1872    for collaborator in &collaborators {
1873        session
1874            .peer
1875            .send(
1876                collaborator.peer_id.unwrap().into(),
1877                add_project_collaborator.clone(),
1878            )
1879            .trace_err();
1880    }
1881
1882    // First, we send the metadata associated with each worktree.
1883    response.send(proto::JoinProjectResponse {
1884        project_id: project.id.0 as u64,
1885        worktrees: worktrees.clone(),
1886        replica_id: replica_id.0 as u32,
1887        collaborators: collaborators.clone(),
1888        language_servers: project.language_servers.clone(),
1889        role: project.role.into(),
1890    })?;
1891
1892    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1893        // Stream this worktree's entries.
1894        let message = proto::UpdateWorktree {
1895            project_id: project_id.to_proto(),
1896            worktree_id,
1897            abs_path: worktree.abs_path.clone(),
1898            root_name: worktree.root_name,
1899            updated_entries: worktree.entries,
1900            removed_entries: Default::default(),
1901            scan_id: worktree.scan_id,
1902            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1903            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1904            removed_repositories: Default::default(),
1905        };
1906        for update in proto::split_worktree_update(message) {
1907            session.peer.send(session.connection_id, update.clone())?;
1908        }
1909
1910        // Stream this worktree's diagnostics.
1911        for summary in worktree.diagnostic_summaries {
1912            session.peer.send(
1913                session.connection_id,
1914                proto::UpdateDiagnosticSummary {
1915                    project_id: project_id.to_proto(),
1916                    worktree_id: worktree.id,
1917                    summary: Some(summary),
1918                },
1919            )?;
1920        }
1921
1922        for settings_file in worktree.settings_files {
1923            session.peer.send(
1924                session.connection_id,
1925                proto::UpdateWorktreeSettings {
1926                    project_id: project_id.to_proto(),
1927                    worktree_id: worktree.id,
1928                    path: settings_file.path,
1929                    content: Some(settings_file.content),
1930                    kind: Some(settings_file.kind.to_proto() as i32),
1931                },
1932            )?;
1933        }
1934    }
1935
1936    for repository in mem::take(&mut project.repositories) {
1937        for update in split_repository_update(repository) {
1938            session.peer.send(session.connection_id, update)?;
1939        }
1940    }
1941
1942    for language_server in &project.language_servers {
1943        session.peer.send(
1944            session.connection_id,
1945            proto::UpdateLanguageServer {
1946                project_id: project_id.to_proto(),
1947                language_server_id: language_server.id,
1948                variant: Some(
1949                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1950                        proto::LspDiskBasedDiagnosticsUpdated {},
1951                    ),
1952                ),
1953            },
1954        )?;
1955    }
1956
1957    Ok(())
1958}
1959
1960/// Leave someone elses shared project.
1961async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1962    let sender_id = session.connection_id;
1963    let project_id = ProjectId::from_proto(request.project_id);
1964    let db = session.db().await;
1965
1966    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1967    tracing::info!(
1968        %project_id,
1969        "leave project"
1970    );
1971
1972    project_left(project, &session);
1973    if let Some(room) = room {
1974        room_updated(room, &session.peer);
1975    }
1976
1977    Ok(())
1978}
1979
1980/// Updates other participants with changes to the project
1981async fn update_project(
1982    request: proto::UpdateProject,
1983    response: Response<proto::UpdateProject>,
1984    session: Session,
1985) -> Result<()> {
1986    let project_id = ProjectId::from_proto(request.project_id);
1987    let (room, guest_connection_ids) = &*session
1988        .db()
1989        .await
1990        .update_project(project_id, session.connection_id, &request.worktrees)
1991        .await?;
1992    broadcast(
1993        Some(session.connection_id),
1994        guest_connection_ids.iter().copied(),
1995        |connection_id| {
1996            session
1997                .peer
1998                .forward_send(session.connection_id, connection_id, request.clone())
1999        },
2000    );
2001    if let Some(room) = room {
2002        room_updated(room, &session.peer);
2003    }
2004    response.send(proto::Ack {})?;
2005
2006    Ok(())
2007}
2008
2009/// Updates other participants with changes to the worktree
2010async fn update_worktree(
2011    request: proto::UpdateWorktree,
2012    response: Response<proto::UpdateWorktree>,
2013    session: Session,
2014) -> Result<()> {
2015    let guest_connection_ids = session
2016        .db()
2017        .await
2018        .update_worktree(&request, session.connection_id)
2019        .await?;
2020
2021    broadcast(
2022        Some(session.connection_id),
2023        guest_connection_ids.iter().copied(),
2024        |connection_id| {
2025            session
2026                .peer
2027                .forward_send(session.connection_id, connection_id, request.clone())
2028        },
2029    );
2030    response.send(proto::Ack {})?;
2031    Ok(())
2032}
2033
2034async fn update_repository(
2035    request: proto::UpdateRepository,
2036    response: Response<proto::UpdateRepository>,
2037    session: Session,
2038) -> Result<()> {
2039    let guest_connection_ids = session
2040        .db()
2041        .await
2042        .update_repository(&request, session.connection_id)
2043        .await?;
2044
2045    broadcast(
2046        Some(session.connection_id),
2047        guest_connection_ids.iter().copied(),
2048        |connection_id| {
2049            session
2050                .peer
2051                .forward_send(session.connection_id, connection_id, request.clone())
2052        },
2053    );
2054    response.send(proto::Ack {})?;
2055    Ok(())
2056}
2057
2058async fn remove_repository(
2059    request: proto::RemoveRepository,
2060    response: Response<proto::RemoveRepository>,
2061    session: Session,
2062) -> Result<()> {
2063    let guest_connection_ids = session
2064        .db()
2065        .await
2066        .remove_repository(&request, session.connection_id)
2067        .await?;
2068
2069    broadcast(
2070        Some(session.connection_id),
2071        guest_connection_ids.iter().copied(),
2072        |connection_id| {
2073            session
2074                .peer
2075                .forward_send(session.connection_id, connection_id, request.clone())
2076        },
2077    );
2078    response.send(proto::Ack {})?;
2079    Ok(())
2080}
2081
2082/// Updates other participants with changes to the diagnostics
2083async fn update_diagnostic_summary(
2084    message: proto::UpdateDiagnosticSummary,
2085    session: Session,
2086) -> Result<()> {
2087    let guest_connection_ids = session
2088        .db()
2089        .await
2090        .update_diagnostic_summary(&message, session.connection_id)
2091        .await?;
2092
2093    broadcast(
2094        Some(session.connection_id),
2095        guest_connection_ids.iter().copied(),
2096        |connection_id| {
2097            session
2098                .peer
2099                .forward_send(session.connection_id, connection_id, message.clone())
2100        },
2101    );
2102
2103    Ok(())
2104}
2105
2106/// Updates other participants with changes to the worktree settings
2107async fn update_worktree_settings(
2108    message: proto::UpdateWorktreeSettings,
2109    session: Session,
2110) -> Result<()> {
2111    let guest_connection_ids = session
2112        .db()
2113        .await
2114        .update_worktree_settings(&message, session.connection_id)
2115        .await?;
2116
2117    broadcast(
2118        Some(session.connection_id),
2119        guest_connection_ids.iter().copied(),
2120        |connection_id| {
2121            session
2122                .peer
2123                .forward_send(session.connection_id, connection_id, message.clone())
2124        },
2125    );
2126
2127    Ok(())
2128}
2129
2130/// Notify other participants that a language server has started.
2131async fn start_language_server(
2132    request: proto::StartLanguageServer,
2133    session: Session,
2134) -> Result<()> {
2135    let guest_connection_ids = session
2136        .db()
2137        .await
2138        .start_language_server(&request, session.connection_id)
2139        .await?;
2140
2141    broadcast(
2142        Some(session.connection_id),
2143        guest_connection_ids.iter().copied(),
2144        |connection_id| {
2145            session
2146                .peer
2147                .forward_send(session.connection_id, connection_id, request.clone())
2148        },
2149    );
2150    Ok(())
2151}
2152
2153/// Notify other participants that a language server has changed.
2154async fn update_language_server(
2155    request: proto::UpdateLanguageServer,
2156    session: Session,
2157) -> Result<()> {
2158    let project_id = ProjectId::from_proto(request.project_id);
2159    let project_connection_ids = session
2160        .db()
2161        .await
2162        .project_connection_ids(project_id, session.connection_id, true)
2163        .await?;
2164    broadcast(
2165        Some(session.connection_id),
2166        project_connection_ids.iter().copied(),
2167        |connection_id| {
2168            session
2169                .peer
2170                .forward_send(session.connection_id, connection_id, request.clone())
2171        },
2172    );
2173    Ok(())
2174}
2175
2176/// forward a project request to the host. These requests should be read only
2177/// as guests are allowed to send them.
2178async fn forward_read_only_project_request<T>(
2179    request: T,
2180    response: Response<T>,
2181    session: Session,
2182) -> Result<()>
2183where
2184    T: EntityMessage + RequestMessage,
2185{
2186    let project_id = ProjectId::from_proto(request.remote_entity_id());
2187    let host_connection_id = session
2188        .db()
2189        .await
2190        .host_for_read_only_project_request(project_id, session.connection_id)
2191        .await?;
2192    let payload = session
2193        .peer
2194        .forward_request(session.connection_id, host_connection_id, request)
2195        .await?;
2196    response.send(payload)?;
2197    Ok(())
2198}
2199
2200async fn forward_find_search_candidates_request(
2201    request: proto::FindSearchCandidates,
2202    response: Response<proto::FindSearchCandidates>,
2203    session: Session,
2204) -> Result<()> {
2205    let project_id = ProjectId::from_proto(request.remote_entity_id());
2206    let host_connection_id = session
2207        .db()
2208        .await
2209        .host_for_read_only_project_request(project_id, session.connection_id)
2210        .await?;
2211    let payload = session
2212        .peer
2213        .forward_request(session.connection_id, host_connection_id, request)
2214        .await?;
2215    response.send(payload)?;
2216    Ok(())
2217}
2218
2219/// forward a project request to the host. These requests are disallowed
2220/// for guests.
2221async fn forward_mutating_project_request<T>(
2222    request: T,
2223    response: Response<T>,
2224    session: Session,
2225) -> Result<()>
2226where
2227    T: EntityMessage + RequestMessage,
2228{
2229    let project_id = ProjectId::from_proto(request.remote_entity_id());
2230
2231    let host_connection_id = session
2232        .db()
2233        .await
2234        .host_for_mutating_project_request(project_id, session.connection_id)
2235        .await?;
2236    let payload = session
2237        .peer
2238        .forward_request(session.connection_id, host_connection_id, request)
2239        .await?;
2240    response.send(payload)?;
2241    Ok(())
2242}
2243
2244/// Notify other participants that a new buffer has been created
2245async fn create_buffer_for_peer(
2246    request: proto::CreateBufferForPeer,
2247    session: Session,
2248) -> Result<()> {
2249    session
2250        .db()
2251        .await
2252        .check_user_is_project_host(
2253            ProjectId::from_proto(request.project_id),
2254            session.connection_id,
2255        )
2256        .await?;
2257    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2258    session
2259        .peer
2260        .forward_send(session.connection_id, peer_id.into(), request)?;
2261    Ok(())
2262}
2263
2264/// Notify other participants that a buffer has been updated. This is
2265/// allowed for guests as long as the update is limited to selections.
2266async fn update_buffer(
2267    request: proto::UpdateBuffer,
2268    response: Response<proto::UpdateBuffer>,
2269    session: Session,
2270) -> Result<()> {
2271    let project_id = ProjectId::from_proto(request.project_id);
2272    let mut capability = Capability::ReadOnly;
2273
2274    for op in request.operations.iter() {
2275        match op.variant {
2276            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2277            Some(_) => capability = Capability::ReadWrite,
2278        }
2279    }
2280
2281    let host = {
2282        let guard = session
2283            .db()
2284            .await
2285            .connections_for_buffer_update(project_id, session.connection_id, capability)
2286            .await?;
2287
2288        let (host, guests) = &*guard;
2289
2290        broadcast(
2291            Some(session.connection_id),
2292            guests.clone(),
2293            |connection_id| {
2294                session
2295                    .peer
2296                    .forward_send(session.connection_id, connection_id, request.clone())
2297            },
2298        );
2299
2300        *host
2301    };
2302
2303    if host != session.connection_id {
2304        session
2305            .peer
2306            .forward_request(session.connection_id, host, request.clone())
2307            .await?;
2308    }
2309
2310    response.send(proto::Ack {})?;
2311    Ok(())
2312}
2313
2314async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2315    let project_id = ProjectId::from_proto(message.project_id);
2316
2317    let operation = message.operation.as_ref().context("invalid operation")?;
2318    let capability = match operation.variant.as_ref() {
2319        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2320            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2321                match buffer_op.variant {
2322                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2323                        Capability::ReadOnly
2324                    }
2325                    _ => Capability::ReadWrite,
2326                }
2327            } else {
2328                Capability::ReadWrite
2329            }
2330        }
2331        Some(_) => Capability::ReadWrite,
2332        None => Capability::ReadOnly,
2333    };
2334
2335    let guard = session
2336        .db()
2337        .await
2338        .connections_for_buffer_update(project_id, session.connection_id, capability)
2339        .await?;
2340
2341    let (host, guests) = &*guard;
2342
2343    broadcast(
2344        Some(session.connection_id),
2345        guests.iter().chain([host]).copied(),
2346        |connection_id| {
2347            session
2348                .peer
2349                .forward_send(session.connection_id, connection_id, message.clone())
2350        },
2351    );
2352
2353    Ok(())
2354}
2355
2356/// Notify other participants that a project has been updated.
2357async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2358    request: T,
2359    session: Session,
2360) -> Result<()> {
2361    let project_id = ProjectId::from_proto(request.remote_entity_id());
2362    let project_connection_ids = session
2363        .db()
2364        .await
2365        .project_connection_ids(project_id, session.connection_id, false)
2366        .await?;
2367
2368    broadcast(
2369        Some(session.connection_id),
2370        project_connection_ids.iter().copied(),
2371        |connection_id| {
2372            session
2373                .peer
2374                .forward_send(session.connection_id, connection_id, request.clone())
2375        },
2376    );
2377    Ok(())
2378}
2379
2380/// Start following another user in a call.
2381async fn follow(
2382    request: proto::Follow,
2383    response: Response<proto::Follow>,
2384    session: Session,
2385) -> Result<()> {
2386    let room_id = RoomId::from_proto(request.room_id);
2387    let project_id = request.project_id.map(ProjectId::from_proto);
2388    let leader_id = request
2389        .leader_id
2390        .ok_or_else(|| anyhow!("invalid leader id"))?
2391        .into();
2392    let follower_id = session.connection_id;
2393
2394    session
2395        .db()
2396        .await
2397        .check_room_participants(room_id, leader_id, session.connection_id)
2398        .await?;
2399
2400    let response_payload = session
2401        .peer
2402        .forward_request(session.connection_id, leader_id, request)
2403        .await?;
2404    response.send(response_payload)?;
2405
2406    if let Some(project_id) = project_id {
2407        let room = session
2408            .db()
2409            .await
2410            .follow(room_id, project_id, leader_id, follower_id)
2411            .await?;
2412        room_updated(&room, &session.peer);
2413    }
2414
2415    Ok(())
2416}
2417
2418/// Stop following another user in a call.
2419async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2420    let room_id = RoomId::from_proto(request.room_id);
2421    let project_id = request.project_id.map(ProjectId::from_proto);
2422    let leader_id = request
2423        .leader_id
2424        .ok_or_else(|| anyhow!("invalid leader id"))?
2425        .into();
2426    let follower_id = session.connection_id;
2427
2428    session
2429        .db()
2430        .await
2431        .check_room_participants(room_id, leader_id, session.connection_id)
2432        .await?;
2433
2434    session
2435        .peer
2436        .forward_send(session.connection_id, leader_id, request)?;
2437
2438    if let Some(project_id) = project_id {
2439        let room = session
2440            .db()
2441            .await
2442            .unfollow(room_id, project_id, leader_id, follower_id)
2443            .await?;
2444        room_updated(&room, &session.peer);
2445    }
2446
2447    Ok(())
2448}
2449
2450/// Notify everyone following you of your current location.
2451async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2452    let room_id = RoomId::from_proto(request.room_id);
2453    let database = session.db.lock().await;
2454
2455    let connection_ids = if let Some(project_id) = request.project_id {
2456        let project_id = ProjectId::from_proto(project_id);
2457        database
2458            .project_connection_ids(project_id, session.connection_id, true)
2459            .await?
2460    } else {
2461        database
2462            .room_connection_ids(room_id, session.connection_id)
2463            .await?
2464    };
2465
2466    // For now, don't send view update messages back to that view's current leader.
2467    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2468        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2469        _ => None,
2470    });
2471
2472    for connection_id in connection_ids.iter().cloned() {
2473        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2474            session
2475                .peer
2476                .forward_send(session.connection_id, connection_id, request.clone())?;
2477        }
2478    }
2479    Ok(())
2480}
2481
2482/// Get public data about users.
2483async fn get_users(
2484    request: proto::GetUsers,
2485    response: Response<proto::GetUsers>,
2486    session: Session,
2487) -> Result<()> {
2488    let user_ids = request
2489        .user_ids
2490        .into_iter()
2491        .map(UserId::from_proto)
2492        .collect();
2493    let users = session
2494        .db()
2495        .await
2496        .get_users_by_ids(user_ids)
2497        .await?
2498        .into_iter()
2499        .map(|user| proto::User {
2500            id: user.id.to_proto(),
2501            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2502            github_login: user.github_login,
2503            email: user.email_address,
2504            name: user.name,
2505        })
2506        .collect();
2507    response.send(proto::UsersResponse { users })?;
2508    Ok(())
2509}
2510
2511/// Search for users (to invite) buy Github login
2512async fn fuzzy_search_users(
2513    request: proto::FuzzySearchUsers,
2514    response: Response<proto::FuzzySearchUsers>,
2515    session: Session,
2516) -> Result<()> {
2517    let query = request.query;
2518    let users = match query.len() {
2519        0 => vec![],
2520        1 | 2 => session
2521            .db()
2522            .await
2523            .get_user_by_github_login(&query)
2524            .await?
2525            .into_iter()
2526            .collect(),
2527        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2528    };
2529    let users = users
2530        .into_iter()
2531        .filter(|user| user.id != session.user_id())
2532        .map(|user| proto::User {
2533            id: user.id.to_proto(),
2534            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2535            github_login: user.github_login,
2536            name: user.name,
2537            email: user.email_address,
2538        })
2539        .collect();
2540    response.send(proto::UsersResponse { users })?;
2541    Ok(())
2542}
2543
2544/// Send a contact request to another user.
2545async fn request_contact(
2546    request: proto::RequestContact,
2547    response: Response<proto::RequestContact>,
2548    session: Session,
2549) -> Result<()> {
2550    let requester_id = session.user_id();
2551    let responder_id = UserId::from_proto(request.responder_id);
2552    if requester_id == responder_id {
2553        return Err(anyhow!("cannot add yourself as a contact"))?;
2554    }
2555
2556    let notifications = session
2557        .db()
2558        .await
2559        .send_contact_request(requester_id, responder_id)
2560        .await?;
2561
2562    // Update outgoing contact requests of requester
2563    let mut update = proto::UpdateContacts::default();
2564    update.outgoing_requests.push(responder_id.to_proto());
2565    for connection_id in session
2566        .connection_pool()
2567        .await
2568        .user_connection_ids(requester_id)
2569    {
2570        session.peer.send(connection_id, update.clone())?;
2571    }
2572
2573    // Update incoming contact requests of responder
2574    let mut update = proto::UpdateContacts::default();
2575    update
2576        .incoming_requests
2577        .push(proto::IncomingContactRequest {
2578            requester_id: requester_id.to_proto(),
2579        });
2580    let connection_pool = session.connection_pool().await;
2581    for connection_id in connection_pool.user_connection_ids(responder_id) {
2582        session.peer.send(connection_id, update.clone())?;
2583    }
2584
2585    send_notifications(&connection_pool, &session.peer, notifications);
2586
2587    response.send(proto::Ack {})?;
2588    Ok(())
2589}
2590
2591/// Accept or decline a contact request
2592async fn respond_to_contact_request(
2593    request: proto::RespondToContactRequest,
2594    response: Response<proto::RespondToContactRequest>,
2595    session: Session,
2596) -> Result<()> {
2597    let responder_id = session.user_id();
2598    let requester_id = UserId::from_proto(request.requester_id);
2599    let db = session.db().await;
2600    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2601        db.dismiss_contact_notification(responder_id, requester_id)
2602            .await?;
2603    } else {
2604        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2605
2606        let notifications = db
2607            .respond_to_contact_request(responder_id, requester_id, accept)
2608            .await?;
2609        let requester_busy = db.is_user_busy(requester_id).await?;
2610        let responder_busy = db.is_user_busy(responder_id).await?;
2611
2612        let pool = session.connection_pool().await;
2613        // Update responder with new contact
2614        let mut update = proto::UpdateContacts::default();
2615        if accept {
2616            update
2617                .contacts
2618                .push(contact_for_user(requester_id, requester_busy, &pool));
2619        }
2620        update
2621            .remove_incoming_requests
2622            .push(requester_id.to_proto());
2623        for connection_id in pool.user_connection_ids(responder_id) {
2624            session.peer.send(connection_id, update.clone())?;
2625        }
2626
2627        // Update requester with new contact
2628        let mut update = proto::UpdateContacts::default();
2629        if accept {
2630            update
2631                .contacts
2632                .push(contact_for_user(responder_id, responder_busy, &pool));
2633        }
2634        update
2635            .remove_outgoing_requests
2636            .push(responder_id.to_proto());
2637
2638        for connection_id in pool.user_connection_ids(requester_id) {
2639            session.peer.send(connection_id, update.clone())?;
2640        }
2641
2642        send_notifications(&pool, &session.peer, notifications);
2643    }
2644
2645    response.send(proto::Ack {})?;
2646    Ok(())
2647}
2648
2649/// Remove a contact.
2650async fn remove_contact(
2651    request: proto::RemoveContact,
2652    response: Response<proto::RemoveContact>,
2653    session: Session,
2654) -> Result<()> {
2655    let requester_id = session.user_id();
2656    let responder_id = UserId::from_proto(request.user_id);
2657    let db = session.db().await;
2658    let (contact_accepted, deleted_notification_id) =
2659        db.remove_contact(requester_id, responder_id).await?;
2660
2661    let pool = session.connection_pool().await;
2662    // Update outgoing contact requests of requester
2663    let mut update = proto::UpdateContacts::default();
2664    if contact_accepted {
2665        update.remove_contacts.push(responder_id.to_proto());
2666    } else {
2667        update
2668            .remove_outgoing_requests
2669            .push(responder_id.to_proto());
2670    }
2671    for connection_id in pool.user_connection_ids(requester_id) {
2672        session.peer.send(connection_id, update.clone())?;
2673    }
2674
2675    // Update incoming contact requests of responder
2676    let mut update = proto::UpdateContacts::default();
2677    if contact_accepted {
2678        update.remove_contacts.push(requester_id.to_proto());
2679    } else {
2680        update
2681            .remove_incoming_requests
2682            .push(requester_id.to_proto());
2683    }
2684    for connection_id in pool.user_connection_ids(responder_id) {
2685        session.peer.send(connection_id, update.clone())?;
2686        if let Some(notification_id) = deleted_notification_id {
2687            session.peer.send(
2688                connection_id,
2689                proto::DeleteNotification {
2690                    notification_id: notification_id.to_proto(),
2691                },
2692            )?;
2693        }
2694    }
2695
2696    response.send(proto::Ack {})?;
2697    Ok(())
2698}
2699
2700fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2701    version.0.minor() < 139
2702}
2703
/// Sends an `UpdateUserPlan` message to this session's connection.
///
/// The message carries the current plan, the trial start time, whether
/// usage-based billing is enabled, and, when an LLM database is configured,
/// the current subscription period and the usage recorded within it.
///
/// NOTE(review): the plan and staff status come from `session` rather than
/// being looked up for `user_id`; presumably callers always pass the session's
/// own user — confirm.
///
/// Database errors are propagated. A failure to send the message is not; it is
/// passed to `trace_err` and the function still returns `Ok(())`.
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let feature_flags = db.get_user_flags(user_id).await?;
    let plan = session.current_plan(&db).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
    let billing_preferences = db.get_billing_preferences(user_id).await?;

    // Subscription period and usage can only be computed when an LLM database
    // is configured; otherwise both are reported as absent.
    let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
        let subscription = db.get_active_billing_subscription(user_id).await?;

        // The staff flag is passed through to `current_period`, which decides
        // how (or whether) it affects the computed period.
        let subscription_period = crate::db::billing_subscription::Model::current_period(
            subscription,
            session.is_staff(),
        );

        // Usage is only looked up when there is a period to scope it to.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    session
        .peer
        .send(
            session.connection_id,
            proto::UpdateUserPlan {
                plan: plan.into(),
                // The stored trial start is interpreted as UTC (`and_utc`) and
                // sent as Unix seconds.
                trial_started_at: billing_customer
                    .and_then(|billing_customer| billing_customer.trial_started_at)
                    .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
                // Staff are always reported as having usage-based billing
                // enabled; everyone else reflects their overage preference.
                is_usage_based_billing_enabled: if session.is_staff() {
                    Some(true)
                } else {
                    billing_preferences
                        .map(|preferences| preferences.model_request_overages_enabled)
                },
                subscription_period: subscription_period.map(|(started_at, ended_at)| {
                    proto::SubscriptionPeriod {
                        started_at: started_at.timestamp() as u64,
                        ended_at: ended_at.timestamp() as u64,
                    }
                }),
                usage: usage.map(|usage| {
                    // Map the protobuf plan onto the LLM client's plan type so
                    // its `*_limit()` methods can supply the limits.
                    let plan = match plan {
                        proto::Plan::Free => zed_llm_client::Plan::Free,
                        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
                    };

                    // Trial users with the extended-trial feature flag get a
                    // model-request limit of 1,000 instead of the plan default.
                    let model_requests_limit = match plan.model_requests_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            let limit = if plan == zed_llm_client::Plan::ZedProTrial
                                && feature_flags
                                    .iter()
                                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                            {
                                1_000
                            } else {
                                limit
                            };

                            zed_llm_client::UsageLimit::Limited(limit)
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            zed_llm_client::UsageLimit::Unlimited
                        }
                    };

                    // NOTE(review): the `as u32` casts below wrap or truncate if
                    // a count or limit is negative or exceeds `u32::MAX`.
                    proto::SubscriptionUsage {
                        model_requests_usage_amount: usage.model_requests as u32,
                        model_requests_usage_limit: Some(proto::UsageLimit {
                            variant: Some(match model_requests_limit {
                                zed_llm_client::UsageLimit::Limited(limit) => {
                                    proto::usage_limit::Variant::Limited(
                                        proto::usage_limit::Limited {
                                            limit: limit as u32,
                                        },
                                    )
                                }
                                zed_llm_client::UsageLimit::Unlimited => {
                                    proto::usage_limit::Variant::Unlimited(
                                        proto::usage_limit::Unlimited {},
                                    )
                                }
                            }),
                        }),
                        edit_predictions_usage_amount: usage.edit_predictions as u32,
                        edit_predictions_usage_limit: Some(proto::UsageLimit {
                            variant: Some(match plan.edit_predictions_limit() {
                                zed_llm_client::UsageLimit::Limited(limit) => {
                                    proto::usage_limit::Variant::Limited(
                                        proto::usage_limit::Limited {
                                            limit: limit as u32,
                                        },
                                    )
                                }
                                zed_llm_client::UsageLimit::Unlimited => {
                                    proto::usage_limit::Variant::Unlimited(
                                        proto::usage_limit::Unlimited {},
                                    )
                                }
                            }),
                        }),
                    }
                }),
            },
        )
        // Sending is best-effort: the error is not propagated to the caller.
        .trace_err();

    Ok(())
}
2823
2824async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2825    subscribe_user_to_channels(session.user_id(), &session).await?;
2826    Ok(())
2827}
2828
2829async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2830    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2831    let mut pool = session.connection_pool().await;
2832    for membership in &channels_for_user.channel_memberships {
2833        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2834    }
2835    session.peer.send(
2836        session.connection_id,
2837        build_update_user_channels(&channels_for_user),
2838    )?;
2839    session.peer.send(
2840        session.connection_id,
2841        build_channels_update(channels_for_user),
2842    )?;
2843    Ok(())
2844}
2845
2846/// Creates a new channel.
2847async fn create_channel(
2848    request: proto::CreateChannel,
2849    response: Response<proto::CreateChannel>,
2850    session: Session,
2851) -> Result<()> {
2852    let db = session.db().await;
2853
2854    let parent_id = request.parent_id.map(ChannelId::from_proto);
2855    let (channel, membership) = db
2856        .create_channel(&request.name, parent_id, session.user_id())
2857        .await?;
2858
2859    let root_id = channel.root_id();
2860    let channel = Channel::from_model(channel);
2861
2862    response.send(proto::CreateChannelResponse {
2863        channel: Some(channel.to_proto()),
2864        parent_id: request.parent_id,
2865    })?;
2866
2867    let mut connection_pool = session.connection_pool().await;
2868    if let Some(membership) = membership {
2869        connection_pool.subscribe_to_channel(
2870            membership.user_id,
2871            membership.channel_id,
2872            membership.role,
2873        );
2874        let update = proto::UpdateUserChannels {
2875            channel_memberships: vec![proto::ChannelMembership {
2876                channel_id: membership.channel_id.to_proto(),
2877                role: membership.role.into(),
2878            }],
2879            ..Default::default()
2880        };
2881        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2882            session.peer.send(connection_id, update.clone())?;
2883        }
2884    }
2885
2886    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2887        if !role.can_see_channel(channel.visibility) {
2888            continue;
2889        }
2890
2891        let update = proto::UpdateChannels {
2892            channels: vec![channel.to_proto()],
2893            ..Default::default()
2894        };
2895        session.peer.send(connection_id, update.clone())?;
2896    }
2897
2898    Ok(())
2899}
2900
2901/// Delete a channel
2902async fn delete_channel(
2903    request: proto::DeleteChannel,
2904    response: Response<proto::DeleteChannel>,
2905    session: Session,
2906) -> Result<()> {
2907    let db = session.db().await;
2908
2909    let channel_id = request.channel_id;
2910    let (root_channel, removed_channels) = db
2911        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2912        .await?;
2913    response.send(proto::Ack {})?;
2914
2915    // Notify members of removed channels
2916    let mut update = proto::UpdateChannels::default();
2917    update
2918        .delete_channels
2919        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2920
2921    let connection_pool = session.connection_pool().await;
2922    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2923        session.peer.send(connection_id, update.clone())?;
2924    }
2925
2926    Ok(())
2927}
2928
2929/// Invite someone to join a channel.
2930async fn invite_channel_member(
2931    request: proto::InviteChannelMember,
2932    response: Response<proto::InviteChannelMember>,
2933    session: Session,
2934) -> Result<()> {
2935    let db = session.db().await;
2936    let channel_id = ChannelId::from_proto(request.channel_id);
2937    let invitee_id = UserId::from_proto(request.user_id);
2938    let InviteMemberResult {
2939        channel,
2940        notifications,
2941    } = db
2942        .invite_channel_member(
2943            channel_id,
2944            invitee_id,
2945            session.user_id(),
2946            request.role().into(),
2947        )
2948        .await?;
2949
2950    let update = proto::UpdateChannels {
2951        channel_invitations: vec![channel.to_proto()],
2952        ..Default::default()
2953    };
2954
2955    let connection_pool = session.connection_pool().await;
2956    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2957        session.peer.send(connection_id, update.clone())?;
2958    }
2959
2960    send_notifications(&connection_pool, &session.peer, notifications);
2961
2962    response.send(proto::Ack {})?;
2963    Ok(())
2964}
2965
2966/// remove someone from a channel
2967async fn remove_channel_member(
2968    request: proto::RemoveChannelMember,
2969    response: Response<proto::RemoveChannelMember>,
2970    session: Session,
2971) -> Result<()> {
2972    let db = session.db().await;
2973    let channel_id = ChannelId::from_proto(request.channel_id);
2974    let member_id = UserId::from_proto(request.user_id);
2975
2976    let RemoveChannelMemberResult {
2977        membership_update,
2978        notification_id,
2979    } = db
2980        .remove_channel_member(channel_id, member_id, session.user_id())
2981        .await?;
2982
2983    let mut connection_pool = session.connection_pool().await;
2984    notify_membership_updated(
2985        &mut connection_pool,
2986        membership_update,
2987        member_id,
2988        &session.peer,
2989    );
2990    for connection_id in connection_pool.user_connection_ids(member_id) {
2991        if let Some(notification_id) = notification_id {
2992            session
2993                .peer
2994                .send(
2995                    connection_id,
2996                    proto::DeleteNotification {
2997                        notification_id: notification_id.to_proto(),
2998                    },
2999                )
3000                .trace_err();
3001        }
3002    }
3003
3004    response.send(proto::Ack {})?;
3005    Ok(())
3006}
3007
3008/// Toggle the channel between public and private.
3009/// Care is taken to maintain the invariant that public channels only descend from public channels,
3010/// (though members-only channels can appear at any point in the hierarchy).
3011async fn set_channel_visibility(
3012    request: proto::SetChannelVisibility,
3013    response: Response<proto::SetChannelVisibility>,
3014    session: Session,
3015) -> Result<()> {
3016    let db = session.db().await;
3017    let channel_id = ChannelId::from_proto(request.channel_id);
3018    let visibility = request.visibility().into();
3019
3020    let channel_model = db
3021        .set_channel_visibility(channel_id, visibility, session.user_id())
3022        .await?;
3023    let root_id = channel_model.root_id();
3024    let channel = Channel::from_model(channel_model);
3025
3026    let mut connection_pool = session.connection_pool().await;
3027    for (user_id, role) in connection_pool
3028        .channel_user_ids(root_id)
3029        .collect::<Vec<_>>()
3030        .into_iter()
3031    {
3032        let update = if role.can_see_channel(channel.visibility) {
3033            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3034            proto::UpdateChannels {
3035                channels: vec![channel.to_proto()],
3036                ..Default::default()
3037            }
3038        } else {
3039            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3040            proto::UpdateChannels {
3041                delete_channels: vec![channel.id.to_proto()],
3042                ..Default::default()
3043            }
3044        };
3045
3046        for connection_id in connection_pool.user_connection_ids(user_id) {
3047            session.peer.send(connection_id, update.clone())?;
3048        }
3049    }
3050
3051    response.send(proto::Ack {})?;
3052    Ok(())
3053}
3054
3055/// Alter the role for a user in the channel.
3056async fn set_channel_member_role(
3057    request: proto::SetChannelMemberRole,
3058    response: Response<proto::SetChannelMemberRole>,
3059    session: Session,
3060) -> Result<()> {
3061    let db = session.db().await;
3062    let channel_id = ChannelId::from_proto(request.channel_id);
3063    let member_id = UserId::from_proto(request.user_id);
3064    let result = db
3065        .set_channel_member_role(
3066            channel_id,
3067            session.user_id(),
3068            member_id,
3069            request.role().into(),
3070        )
3071        .await?;
3072
3073    match result {
3074        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3075            let mut connection_pool = session.connection_pool().await;
3076            notify_membership_updated(
3077                &mut connection_pool,
3078                membership_update,
3079                member_id,
3080                &session.peer,
3081            )
3082        }
3083        db::SetMemberRoleResult::InviteUpdated(channel) => {
3084            let update = proto::UpdateChannels {
3085                channel_invitations: vec![channel.to_proto()],
3086                ..Default::default()
3087            };
3088
3089            for connection_id in session
3090                .connection_pool()
3091                .await
3092                .user_connection_ids(member_id)
3093            {
3094                session.peer.send(connection_id, update.clone())?;
3095            }
3096        }
3097    }
3098
3099    response.send(proto::Ack {})?;
3100    Ok(())
3101}
3102
3103/// Change the name of a channel
3104async fn rename_channel(
3105    request: proto::RenameChannel,
3106    response: Response<proto::RenameChannel>,
3107    session: Session,
3108) -> Result<()> {
3109    let db = session.db().await;
3110    let channel_id = ChannelId::from_proto(request.channel_id);
3111    let channel_model = db
3112        .rename_channel(channel_id, session.user_id(), &request.name)
3113        .await?;
3114    let root_id = channel_model.root_id();
3115    let channel = Channel::from_model(channel_model);
3116
3117    response.send(proto::RenameChannelResponse {
3118        channel: Some(channel.to_proto()),
3119    })?;
3120
3121    let connection_pool = session.connection_pool().await;
3122    let update = proto::UpdateChannels {
3123        channels: vec![channel.to_proto()],
3124        ..Default::default()
3125    };
3126    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3127        if role.can_see_channel(channel.visibility) {
3128            session.peer.send(connection_id, update.clone())?;
3129        }
3130    }
3131
3132    Ok(())
3133}
3134
3135/// Move a channel to a new parent.
3136async fn move_channel(
3137    request: proto::MoveChannel,
3138    response: Response<proto::MoveChannel>,
3139    session: Session,
3140) -> Result<()> {
3141    let channel_id = ChannelId::from_proto(request.channel_id);
3142    let to = ChannelId::from_proto(request.to);
3143
3144    let (root_id, channels) = session
3145        .db()
3146        .await
3147        .move_channel(channel_id, to, session.user_id())
3148        .await?;
3149
3150    let connection_pool = session.connection_pool().await;
3151    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3152        let channels = channels
3153            .iter()
3154            .filter_map(|channel| {
3155                if role.can_see_channel(channel.visibility) {
3156                    Some(channel.to_proto())
3157                } else {
3158                    None
3159                }
3160            })
3161            .collect::<Vec<_>>();
3162        if channels.is_empty() {
3163            continue;
3164        }
3165
3166        let update = proto::UpdateChannels {
3167            channels,
3168            ..Default::default()
3169        };
3170
3171        session.peer.send(connection_id, update.clone())?;
3172    }
3173
3174    response.send(Ack {})?;
3175    Ok(())
3176}
3177
3178/// Get the list of channel members
3179async fn get_channel_members(
3180    request: proto::GetChannelMembers,
3181    response: Response<proto::GetChannelMembers>,
3182    session: Session,
3183) -> Result<()> {
3184    let db = session.db().await;
3185    let channel_id = ChannelId::from_proto(request.channel_id);
3186    let limit = if request.limit == 0 {
3187        u16::MAX as u64
3188    } else {
3189        request.limit
3190    };
3191    let (members, users) = db
3192        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3193        .await?;
3194    response.send(proto::GetChannelMembersResponse { members, users })?;
3195    Ok(())
3196}
3197
3198/// Accept or decline a channel invitation.
3199async fn respond_to_channel_invite(
3200    request: proto::RespondToChannelInvite,
3201    response: Response<proto::RespondToChannelInvite>,
3202    session: Session,
3203) -> Result<()> {
3204    let db = session.db().await;
3205    let channel_id = ChannelId::from_proto(request.channel_id);
3206    let RespondToChannelInvite {
3207        membership_update,
3208        notifications,
3209    } = db
3210        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3211        .await?;
3212
3213    let mut connection_pool = session.connection_pool().await;
3214    if let Some(membership_update) = membership_update {
3215        notify_membership_updated(
3216            &mut connection_pool,
3217            membership_update,
3218            session.user_id(),
3219            &session.peer,
3220        );
3221    } else {
3222        let update = proto::UpdateChannels {
3223            remove_channel_invitations: vec![channel_id.to_proto()],
3224            ..Default::default()
3225        };
3226
3227        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3228            session.peer.send(connection_id, update.clone())?;
3229        }
3230    };
3231
3232    send_notifications(&connection_pool, &session.peer, notifications);
3233
3234    response.send(proto::Ack {})?;
3235
3236    Ok(())
3237}
3238
3239/// Join the channels' room
3240async fn join_channel(
3241    request: proto::JoinChannel,
3242    response: Response<proto::JoinChannel>,
3243    session: Session,
3244) -> Result<()> {
3245    let channel_id = ChannelId::from_proto(request.channel_id);
3246    join_channel_internal(channel_id, Box::new(response), session).await
3247}
3248
/// A response handle that can answer with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply through
/// either one.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type path: presumably resolves to `Response`'s inherent
        // `send` (inherent items take precedence), not this trait method,
        // which would recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type path: presumably resolves to `Response`'s inherent
        // `send`, not this trait method, which would recurse.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3262
/// Join the room attached to `channel_id` on behalf of the session's user.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, that
///    connection leaves its room first (see the inline comment below).
/// 2. The user joins the channel's room in the database. This yields the joined
///    room, an optional membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, a token is minted. Guests get a guest
///    token with `can_publish = false`; other roles get a room token with
///    `can_publish = true`. A token error is turned into `None` by `trace_err`,
///    so the reply then carries no LiveKit connection info.
/// 4. The client is answered, any membership update is pushed, and the room
///    state is broadcast to its participants.
/// 5. The channel's participant list is broadcast and the user's contacts are
///    refreshed.
///
/// `response` is generic so both `JoinChannel` and `JoinRoom` replies can be used.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room and re-acquire it
            // afterwards (presumably `leave_room_for_session` needs it itself).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may only listen; every other role may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // This pool guard is dropped at the end of the block, before
        // `channel_updated` below acquires the pool again.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The reply has already been sent at this point; a missing channel is still
    // surfaced as an error to the caller of this handler.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3358
3359/// Start editing the channel notes
3360async fn join_channel_buffer(
3361    request: proto::JoinChannelBuffer,
3362    response: Response<proto::JoinChannelBuffer>,
3363    session: Session,
3364) -> Result<()> {
3365    let db = session.db().await;
3366    let channel_id = ChannelId::from_proto(request.channel_id);
3367
3368    let open_response = db
3369        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3370        .await?;
3371
3372    let collaborators = open_response.collaborators.clone();
3373    response.send(open_response)?;
3374
3375    let update = UpdateChannelBufferCollaborators {
3376        channel_id: channel_id.to_proto(),
3377        collaborators: collaborators.clone(),
3378    };
3379    channel_buffer_updated(
3380        session.connection_id,
3381        collaborators
3382            .iter()
3383            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3384        &update,
3385        &session.peer,
3386    );
3387
3388    Ok(())
3389}
3390
3391/// Edit the channel notes
3392async fn update_channel_buffer(
3393    request: proto::UpdateChannelBuffer,
3394    session: Session,
3395) -> Result<()> {
3396    let db = session.db().await;
3397    let channel_id = ChannelId::from_proto(request.channel_id);
3398
3399    let (collaborators, epoch, version) = db
3400        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3401        .await?;
3402
3403    channel_buffer_updated(
3404        session.connection_id,
3405        collaborators.clone(),
3406        &proto::UpdateChannelBuffer {
3407            channel_id: channel_id.to_proto(),
3408            operations: request.operations,
3409        },
3410        &session.peer,
3411    );
3412
3413    let pool = &*session.connection_pool().await;
3414
3415    let non_collaborators =
3416        pool.channel_connection_ids(channel_id)
3417            .filter_map(|(connection_id, _)| {
3418                if collaborators.contains(&connection_id) {
3419                    None
3420                } else {
3421                    Some(connection_id)
3422                }
3423            });
3424
3425    broadcast(None, non_collaborators, |peer_id| {
3426        session.peer.send(
3427            peer_id,
3428            proto::UpdateChannels {
3429                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3430                    channel_id: channel_id.to_proto(),
3431                    epoch: epoch as u64,
3432                    version: version.clone(),
3433                }],
3434                ..Default::default()
3435            },
3436        )
3437    });
3438
3439    Ok(())
3440}
3441
3442/// Rejoin the channel notes after a connection blip
3443async fn rejoin_channel_buffers(
3444    request: proto::RejoinChannelBuffers,
3445    response: Response<proto::RejoinChannelBuffers>,
3446    session: Session,
3447) -> Result<()> {
3448    let db = session.db().await;
3449    let buffers = db
3450        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3451        .await?;
3452
3453    for rejoined_buffer in &buffers {
3454        let collaborators_to_notify = rejoined_buffer
3455            .buffer
3456            .collaborators
3457            .iter()
3458            .filter_map(|c| Some(c.peer_id?.into()));
3459        channel_buffer_updated(
3460            session.connection_id,
3461            collaborators_to_notify,
3462            &proto::UpdateChannelBufferCollaborators {
3463                channel_id: rejoined_buffer.buffer.channel_id,
3464                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3465            },
3466            &session.peer,
3467        );
3468    }
3469
3470    response.send(proto::RejoinChannelBuffersResponse {
3471        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3472    })?;
3473
3474    Ok(())
3475}
3476
3477/// Stop editing the channel notes
3478async fn leave_channel_buffer(
3479    request: proto::LeaveChannelBuffer,
3480    response: Response<proto::LeaveChannelBuffer>,
3481    session: Session,
3482) -> Result<()> {
3483    let db = session.db().await;
3484    let channel_id = ChannelId::from_proto(request.channel_id);
3485
3486    let left_buffer = db
3487        .leave_channel_buffer(channel_id, session.connection_id)
3488        .await?;
3489
3490    response.send(Ack {})?;
3491
3492    channel_buffer_updated(
3493        session.connection_id,
3494        left_buffer.connections,
3495        &proto::UpdateChannelBufferCollaborators {
3496            channel_id: channel_id.to_proto(),
3497            collaborators: left_buffer.collaborators,
3498        },
3499        &session.peer,
3500    );
3501
3502    Ok(())
3503}
3504
3505fn channel_buffer_updated<T: EnvelopedMessage>(
3506    sender_id: ConnectionId,
3507    collaborators: impl IntoIterator<Item = ConnectionId>,
3508    message: &T,
3509    peer: &Peer,
3510) {
3511    broadcast(Some(sender_id), collaborators, |peer_id| {
3512        peer.send(peer_id, message.clone())
3513    });
3514}
3515
3516fn send_notifications(
3517    connection_pool: &ConnectionPool,
3518    peer: &Peer,
3519    notifications: db::NotificationBatch,
3520) {
3521    for (user_id, notification) in notifications {
3522        for connection_id in connection_pool.user_connection_ids(user_id) {
3523            if let Err(error) = peer.send(
3524                connection_id,
3525                proto::AddNotification {
3526                    notification: Some(notification.clone()),
3527                },
3528            ) {
3529                tracing::error!(
3530                    "failed to send notification to {:?} {}",
3531                    connection_id,
3532                    error
3533                );
3534            }
3535        }
3536    }
3537}
3538
3539/// Send a message to the channel
3540async fn send_channel_message(
3541    request: proto::SendChannelMessage,
3542    response: Response<proto::SendChannelMessage>,
3543    session: Session,
3544) -> Result<()> {
3545    // Validate the message body.
3546    let body = request.body.trim().to_string();
3547    if body.len() > MAX_MESSAGE_LEN {
3548        return Err(anyhow!("message is too long"))?;
3549    }
3550    if body.is_empty() {
3551        return Err(anyhow!("message can't be blank"))?;
3552    }
3553
3554    // TODO: adjust mentions if body is trimmed
3555
3556    let timestamp = OffsetDateTime::now_utc();
3557    let nonce = request
3558        .nonce
3559        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3560
3561    let channel_id = ChannelId::from_proto(request.channel_id);
3562    let CreatedChannelMessage {
3563        message_id,
3564        participant_connection_ids,
3565        notifications,
3566    } = session
3567        .db()
3568        .await
3569        .create_channel_message(
3570            channel_id,
3571            session.user_id(),
3572            &body,
3573            &request.mentions,
3574            timestamp,
3575            nonce.clone().into(),
3576            request.reply_to_message_id.map(MessageId::from_proto),
3577        )
3578        .await?;
3579
3580    let message = proto::ChannelMessage {
3581        sender_id: session.user_id().to_proto(),
3582        id: message_id.to_proto(),
3583        body,
3584        mentions: request.mentions,
3585        timestamp: timestamp.unix_timestamp() as u64,
3586        nonce: Some(nonce),
3587        reply_to_message_id: request.reply_to_message_id,
3588        edited_at: None,
3589    };
3590    broadcast(
3591        Some(session.connection_id),
3592        participant_connection_ids.clone(),
3593        |connection| {
3594            session.peer.send(
3595                connection,
3596                proto::ChannelMessageSent {
3597                    channel_id: channel_id.to_proto(),
3598                    message: Some(message.clone()),
3599                },
3600            )
3601        },
3602    );
3603    response.send(proto::SendChannelMessageResponse {
3604        message: Some(message),
3605    })?;
3606
3607    let pool = &*session.connection_pool().await;
3608    let non_participants =
3609        pool.channel_connection_ids(channel_id)
3610            .filter_map(|(connection_id, _)| {
3611                if participant_connection_ids.contains(&connection_id) {
3612                    None
3613                } else {
3614                    Some(connection_id)
3615                }
3616            });
3617    broadcast(None, non_participants, |peer_id| {
3618        session.peer.send(
3619            peer_id,
3620            proto::UpdateChannels {
3621                latest_channel_message_ids: vec![proto::ChannelMessageId {
3622                    channel_id: channel_id.to_proto(),
3623                    message_id: message_id.to_proto(),
3624                }],
3625                ..Default::default()
3626            },
3627        )
3628    });
3629    send_notifications(pool, &session.peer, notifications);
3630
3631    Ok(())
3632}
3633
3634/// Delete a channel message
3635async fn remove_channel_message(
3636    request: proto::RemoveChannelMessage,
3637    response: Response<proto::RemoveChannelMessage>,
3638    session: Session,
3639) -> Result<()> {
3640    let channel_id = ChannelId::from_proto(request.channel_id);
3641    let message_id = MessageId::from_proto(request.message_id);
3642    let (connection_ids, existing_notification_ids) = session
3643        .db()
3644        .await
3645        .remove_channel_message(channel_id, message_id, session.user_id())
3646        .await?;
3647
3648    broadcast(
3649        Some(session.connection_id),
3650        connection_ids,
3651        move |connection| {
3652            session.peer.send(connection, request.clone())?;
3653
3654            for notification_id in &existing_notification_ids {
3655                session.peer.send(
3656                    connection,
3657                    proto::DeleteNotification {
3658                        notification_id: (*notification_id).to_proto(),
3659                    },
3660                )?;
3661            }
3662
3663            Ok(())
3664        },
3665    );
3666    response.send(proto::Ack {})?;
3667    Ok(())
3668}
3669
3670async fn update_channel_message(
3671    request: proto::UpdateChannelMessage,
3672    response: Response<proto::UpdateChannelMessage>,
3673    session: Session,
3674) -> Result<()> {
3675    let channel_id = ChannelId::from_proto(request.channel_id);
3676    let message_id = MessageId::from_proto(request.message_id);
3677    let updated_at = OffsetDateTime::now_utc();
3678    let UpdatedChannelMessage {
3679        message_id,
3680        participant_connection_ids,
3681        notifications,
3682        reply_to_message_id,
3683        timestamp,
3684        deleted_mention_notification_ids,
3685        updated_mention_notifications,
3686    } = session
3687        .db()
3688        .await
3689        .update_channel_message(
3690            channel_id,
3691            message_id,
3692            session.user_id(),
3693            request.body.as_str(),
3694            &request.mentions,
3695            updated_at,
3696        )
3697        .await?;
3698
3699    let nonce = request
3700        .nonce
3701        .clone()
3702        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3703
3704    let message = proto::ChannelMessage {
3705        sender_id: session.user_id().to_proto(),
3706        id: message_id.to_proto(),
3707        body: request.body.clone(),
3708        mentions: request.mentions.clone(),
3709        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3710        nonce: Some(nonce),
3711        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3712        edited_at: Some(updated_at.unix_timestamp() as u64),
3713    };
3714
3715    response.send(proto::Ack {})?;
3716
3717    let pool = &*session.connection_pool().await;
3718    broadcast(
3719        Some(session.connection_id),
3720        participant_connection_ids,
3721        |connection| {
3722            session.peer.send(
3723                connection,
3724                proto::ChannelMessageUpdate {
3725                    channel_id: channel_id.to_proto(),
3726                    message: Some(message.clone()),
3727                },
3728            )?;
3729
3730            for notification_id in &deleted_mention_notification_ids {
3731                session.peer.send(
3732                    connection,
3733                    proto::DeleteNotification {
3734                        notification_id: (*notification_id).to_proto(),
3735                    },
3736                )?;
3737            }
3738
3739            for notification in &updated_mention_notifications {
3740                session.peer.send(
3741                    connection,
3742                    proto::UpdateNotification {
3743                        notification: Some(notification.clone()),
3744                    },
3745                )?;
3746            }
3747
3748            Ok(())
3749        },
3750    );
3751
3752    send_notifications(pool, &session.peer, notifications);
3753
3754    Ok(())
3755}
3756
3757/// Mark a channel message as read
3758async fn acknowledge_channel_message(
3759    request: proto::AckChannelMessage,
3760    session: Session,
3761) -> Result<()> {
3762    let channel_id = ChannelId::from_proto(request.channel_id);
3763    let message_id = MessageId::from_proto(request.message_id);
3764    let notifications = session
3765        .db()
3766        .await
3767        .observe_channel_message(channel_id, session.user_id(), message_id)
3768        .await?;
3769    send_notifications(
3770        &*session.connection_pool().await,
3771        &session.peer,
3772        notifications,
3773    );
3774    Ok(())
3775}
3776
3777/// Mark a buffer version as synced
3778async fn acknowledge_buffer_version(
3779    request: proto::AckBufferOperation,
3780    session: Session,
3781) -> Result<()> {
3782    let buffer_id = BufferId::from_proto(request.buffer_id);
3783    session
3784        .db()
3785        .await
3786        .observe_buffer_version(
3787            buffer_id,
3788            session.user_id(),
3789            request.epoch as i32,
3790            &request.version,
3791        )
3792        .await?;
3793    Ok(())
3794}
3795
3796/// Get a Supermaven API key for the user
3797async fn get_supermaven_api_key(
3798    _request: proto::GetSupermavenApiKey,
3799    response: Response<proto::GetSupermavenApiKey>,
3800    session: Session,
3801) -> Result<()> {
3802    let user_id: String = session.user_id().to_string();
3803    if !session.is_staff() {
3804        return Err(anyhow!("supermaven not enabled for this account"))?;
3805    }
3806
3807    let email = session
3808        .email()
3809        .ok_or_else(|| anyhow!("user must have an email"))?;
3810
3811    let supermaven_admin_api = session
3812        .supermaven_client
3813        .as_ref()
3814        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3815
3816    let result = supermaven_admin_api
3817        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3818        .await?;
3819
3820    response.send(proto::GetSupermavenApiKeyResponse {
3821        api_key: result.api_key,
3822    })?;
3823
3824    Ok(())
3825}
3826
3827/// Start receiving chat updates for a channel
3828async fn join_channel_chat(
3829    request: proto::JoinChannelChat,
3830    response: Response<proto::JoinChannelChat>,
3831    session: Session,
3832) -> Result<()> {
3833    let channel_id = ChannelId::from_proto(request.channel_id);
3834
3835    let db = session.db().await;
3836    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3837        .await?;
3838    let messages = db
3839        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3840        .await?;
3841    response.send(proto::JoinChannelChatResponse {
3842        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3843        messages,
3844    })?;
3845    Ok(())
3846}
3847
3848/// Stop receiving chat updates for a channel
3849async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3850    let channel_id = ChannelId::from_proto(request.channel_id);
3851    session
3852        .db()
3853        .await
3854        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3855        .await?;
3856    Ok(())
3857}
3858
3859/// Retrieve the chat history for a channel
3860async fn get_channel_messages(
3861    request: proto::GetChannelMessages,
3862    response: Response<proto::GetChannelMessages>,
3863    session: Session,
3864) -> Result<()> {
3865    let channel_id = ChannelId::from_proto(request.channel_id);
3866    let messages = session
3867        .db()
3868        .await
3869        .get_channel_messages(
3870            channel_id,
3871            session.user_id(),
3872            MESSAGE_COUNT_PER_PAGE,
3873            Some(MessageId::from_proto(request.before_message_id)),
3874        )
3875        .await?;
3876    response.send(proto::GetChannelMessagesResponse {
3877        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3878        messages,
3879    })?;
3880    Ok(())
3881}
3882
3883/// Retrieve specific chat messages
3884async fn get_channel_messages_by_id(
3885    request: proto::GetChannelMessagesById,
3886    response: Response<proto::GetChannelMessagesById>,
3887    session: Session,
3888) -> Result<()> {
3889    let message_ids = request
3890        .message_ids
3891        .iter()
3892        .map(|id| MessageId::from_proto(*id))
3893        .collect::<Vec<_>>();
3894    let messages = session
3895        .db()
3896        .await
3897        .get_channel_messages_by_id(session.user_id(), &message_ids)
3898        .await?;
3899    response.send(proto::GetChannelMessagesResponse {
3900        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3901        messages,
3902    })?;
3903    Ok(())
3904}
3905
3906/// Retrieve the current users notifications
3907async fn get_notifications(
3908    request: proto::GetNotifications,
3909    response: Response<proto::GetNotifications>,
3910    session: Session,
3911) -> Result<()> {
3912    let notifications = session
3913        .db()
3914        .await
3915        .get_notifications(
3916            session.user_id(),
3917            NOTIFICATION_COUNT_PER_PAGE,
3918            request.before_id.map(db::NotificationId::from_proto),
3919        )
3920        .await?;
3921    response.send(proto::GetNotificationsResponse {
3922        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3923        notifications,
3924    })?;
3925    Ok(())
3926}
3927
3928/// Mark notifications as read
3929async fn mark_notification_as_read(
3930    request: proto::MarkNotificationRead,
3931    response: Response<proto::MarkNotificationRead>,
3932    session: Session,
3933) -> Result<()> {
3934    let database = &session.db().await;
3935    let notifications = database
3936        .mark_notification_as_read_by_id(
3937            session.user_id(),
3938            NotificationId::from_proto(request.notification_id),
3939        )
3940        .await?;
3941    send_notifications(
3942        &*session.connection_pool().await,
3943        &session.peer,
3944        notifications,
3945    );
3946    response.send(proto::Ack {})?;
3947    Ok(())
3948}
3949
3950/// Get the current users information
3951async fn get_private_user_info(
3952    _request: proto::GetPrivateUserInfo,
3953    response: Response<proto::GetPrivateUserInfo>,
3954    session: Session,
3955) -> Result<()> {
3956    let db = session.db().await;
3957
3958    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3959    let user = db
3960        .get_user_by_id(session.user_id())
3961        .await?
3962        .ok_or_else(|| anyhow!("user not found"))?;
3963    let flags = db.get_user_flags(session.user_id()).await?;
3964
3965    response.send(proto::GetPrivateUserInfoResponse {
3966        metrics_id,
3967        staff: user.admin,
3968        flags,
3969        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3970    })?;
3971    Ok(())
3972}
3973
3974/// Accept the terms of service (tos) on behalf of the current user
3975async fn accept_terms_of_service(
3976    _request: proto::AcceptTermsOfService,
3977    response: Response<proto::AcceptTermsOfService>,
3978    session: Session,
3979) -> Result<()> {
3980    let db = session.db().await;
3981
3982    let accepted_tos_at = Utc::now();
3983    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3984        .await?;
3985
3986    response.send(proto::AcceptTermsOfServiceResponse {
3987        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3988    })?;
3989    Ok(())
3990}
3991
/// The minimum account age an account must have in order to use the LLM service.
///
/// NOTE(review): `get_llm_api_token` does not reference this constant directly.
/// Presumably it is enforced elsewhere (e.g. by the LLM service or token claims);
/// confirm.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3994
3995async fn get_llm_api_token(
3996    _request: proto::GetLlmToken,
3997    response: Response<proto::GetLlmToken>,
3998    session: Session,
3999) -> Result<()> {
4000    let db = session.db().await;
4001
4002    let flags = db.get_user_flags(session.user_id()).await?;
4003    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4004
4005    if !session.is_staff() && !has_language_models_feature_flag {
4006        Err(anyhow!("permission denied"))?
4007    }
4008
4009    let user_id = session.user_id();
4010    let user = db
4011        .get_user_by_id(user_id)
4012        .await?
4013        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4014
4015    if user.accepted_tos_at.is_none() {
4016        Err(anyhow!("terms of service not accepted"))?
4017    }
4018
4019    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4020    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4021    let billing_preferences = db.get_billing_preferences(user.id).await?;
4022
4023    let token = LlmTokenClaims::create(
4024        &user,
4025        session.is_staff(),
4026        billing_preferences,
4027        &flags,
4028        has_legacy_llm_subscription,
4029        billing_subscription,
4030        session.system_id.clone(),
4031        &session.app_state.config,
4032    )?;
4033    response.send(proto::GetLlmTokenResponse { token })?;
4034    Ok(())
4035}
4036
4037fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4038    let message = match message {
4039        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4040        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4041        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4042        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4043        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4044            code: frame.code.into(),
4045            reason: frame.reason.as_str().to_owned().into(),
4046        })),
4047        // We should never receive a frame while reading the message, according
4048        // to the `tungstenite` maintainers:
4049        //
4050        // > It cannot occur when you read messages from the WebSocket, but it
4051        // > can be used when you want to send the raw frames (e.g. you want to
4052        // > send the frames to the WebSocket without composing the full message first).
4053        // >
4054        // > — https://github.com/snapview/tungstenite-rs/issues/268
4055        TungsteniteMessage::Frame(_) => {
4056            bail!("received an unexpected frame while reading the message")
4057        }
4058    };
4059
4060    Ok(message)
4061}
4062
4063fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4064    match message {
4065        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4066        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4067        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4068        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4069        AxumMessage::Close(frame) => {
4070            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4071                code: frame.code.into(),
4072                reason: frame.reason.as_ref().into(),
4073            }))
4074        }
4075    }
4076}
4077
4078fn notify_membership_updated(
4079    connection_pool: &mut ConnectionPool,
4080    result: MembershipUpdated,
4081    user_id: UserId,
4082    peer: &Peer,
4083) {
4084    for membership in &result.new_channels.channel_memberships {
4085        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4086    }
4087    for channel_id in &result.removed_channels {
4088        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4089    }
4090
4091    let user_channels_update = proto::UpdateUserChannels {
4092        channel_memberships: result
4093            .new_channels
4094            .channel_memberships
4095            .iter()
4096            .map(|cm| proto::ChannelMembership {
4097                channel_id: cm.channel_id.to_proto(),
4098                role: cm.role.into(),
4099            })
4100            .collect(),
4101        ..Default::default()
4102    };
4103
4104    let mut update = build_channels_update(result.new_channels);
4105    update.delete_channels = result
4106        .removed_channels
4107        .into_iter()
4108        .map(|id| id.to_proto())
4109        .collect();
4110    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4111
4112    for connection_id in connection_pool.user_connection_ids(user_id) {
4113        peer.send(connection_id, user_channels_update.clone())
4114            .trace_err();
4115        peer.send(connection_id, update.clone()).trace_err();
4116    }
4117}
4118
4119fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4120    proto::UpdateUserChannels {
4121        channel_memberships: channels
4122            .channel_memberships
4123            .iter()
4124            .map(|m| proto::ChannelMembership {
4125                channel_id: m.channel_id.to_proto(),
4126                role: m.role.into(),
4127            })
4128            .collect(),
4129        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4130        observed_channel_message_id: channels.observed_channel_messages.clone(),
4131    }
4132}
4133
4134fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4135    let mut update = proto::UpdateChannels::default();
4136
4137    for channel in channels.channels {
4138        update.channels.push(channel.to_proto());
4139    }
4140
4141    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4142    update.latest_channel_message_ids = channels.latest_channel_messages;
4143
4144    for (channel_id, participants) in channels.channel_participants {
4145        update
4146            .channel_participants
4147            .push(proto::ChannelParticipants {
4148                channel_id: channel_id.to_proto(),
4149                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4150            });
4151    }
4152
4153    for channel in channels.invited_channels {
4154        update.channel_invitations.push(channel.to_proto());
4155    }
4156
4157    update
4158}
4159
4160fn build_initial_contacts_update(
4161    contacts: Vec<db::Contact>,
4162    pool: &ConnectionPool,
4163) -> proto::UpdateContacts {
4164    let mut update = proto::UpdateContacts::default();
4165
4166    for contact in contacts {
4167        match contact {
4168            db::Contact::Accepted { user_id, busy } => {
4169                update.contacts.push(contact_for_user(user_id, busy, pool));
4170            }
4171            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4172            db::Contact::Incoming { user_id } => {
4173                update
4174                    .incoming_requests
4175                    .push(proto::IncomingContactRequest {
4176                        requester_id: user_id.to_proto(),
4177                    })
4178            }
4179        }
4180    }
4181
4182    update
4183}
4184
4185fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4186    proto::Contact {
4187        user_id: user_id.to_proto(),
4188        online: pool.is_user_online(user_id),
4189        busy,
4190    }
4191}
4192
4193fn room_updated(room: &proto::Room, peer: &Peer) {
4194    broadcast(
4195        None,
4196        room.participants
4197            .iter()
4198            .filter_map(|participant| Some(participant.peer_id?.into())),
4199        |peer_id| {
4200            peer.send(
4201                peer_id,
4202                proto::RoomUpdated {
4203                    room: Some(room.clone()),
4204                },
4205            )
4206        },
4207    );
4208}
4209
4210fn channel_updated(
4211    channel: &db::channel::Model,
4212    room: &proto::Room,
4213    peer: &Peer,
4214    pool: &ConnectionPool,
4215) {
4216    let participants = room
4217        .participants
4218        .iter()
4219        .map(|p| p.user_id)
4220        .collect::<Vec<_>>();
4221
4222    broadcast(
4223        None,
4224        pool.channel_connection_ids(channel.root_id())
4225            .filter_map(|(channel_id, role)| {
4226                role.can_see_channel(channel.visibility)
4227                    .then_some(channel_id)
4228            }),
4229        |peer_id| {
4230            peer.send(
4231                peer_id,
4232                proto::UpdateChannels {
4233                    channel_participants: vec![proto::ChannelParticipants {
4234                        channel_id: channel.id.to_proto(),
4235                        participant_user_ids: participants.clone(),
4236                    }],
4237                    ..Default::default()
4238                },
4239            )
4240        },
4241    );
4242}
4243
4244async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4245    let db = session.db().await;
4246
4247    let contacts = db.get_contacts(user_id).await?;
4248    let busy = db.is_user_busy(user_id).await?;
4249
4250    let pool = session.connection_pool().await;
4251    let updated_contact = contact_for_user(user_id, busy, &pool);
4252    for contact in contacts {
4253        if let db::Contact::Accepted {
4254            user_id: contact_user_id,
4255            ..
4256        } = contact
4257        {
4258            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4259                session
4260                    .peer
4261                    .send(
4262                        contact_conn_id,
4263                        proto::UpdateContacts {
4264                            contacts: vec![updated_contact.clone()],
4265                            remove_contacts: Default::default(),
4266                            incoming_requests: Default::default(),
4267                            remove_incoming_requests: Default::default(),
4268                            outgoing_requests: Default::default(),
4269                            remove_outgoing_requests: Default::default(),
4270                        },
4271                    )
4272                    .trace_err();
4273            }
4274        }
4275    }
4276    Ok(())
4277}
4278
/// Removes `connection_id` from whatever room it is in and notifies everyone affected.
///
/// Notifications, in order:
/// - the room's projects, via `project_left`;
/// - the remaining room participants, via `room_updated`;
/// - the channel, if the room belongs to one;
/// - the targets of any calls that were canceled;
/// - the affected users' contacts.
///
/// Finally, it removes this user from the LiveKit room, and deletes that
/// room when the database reports the room as deleted.
///
/// Returns `Ok(())` immediately if `leave_room` reports that the connection was not in a room.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once inside the `if let` below. The `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` before the room itself is
        // taken. The `room` broadcast by `room_updated` below therefore carries that
        // field reset to its default value.
        // NOTE(review): presumably intentional (clients shouldn't rely on it here) — confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` runs below;
    // that function acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // An error here returns via `?`. It skips the remaining contact updates
    // and the LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are only logged.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4352
4353async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4354    let left_channel_buffers = session
4355        .db()
4356        .await
4357        .leave_channel_buffers(session.connection_id)
4358        .await?;
4359
4360    for left_buffer in left_channel_buffers {
4361        channel_buffer_updated(
4362            session.connection_id,
4363            left_buffer.connections,
4364            &proto::UpdateChannelBufferCollaborators {
4365                channel_id: left_buffer.channel_id.to_proto(),
4366                collaborators: left_buffer.collaborators,
4367            },
4368            &session.peer,
4369        );
4370    }
4371
4372    Ok(())
4373}
4374
4375fn project_left(project: &db::LeftProject, session: &Session) {
4376    for connection_id in &project.connection_ids {
4377        if project.should_unshare {
4378            session
4379                .peer
4380                .send(
4381                    *connection_id,
4382                    proto::UnshareProject {
4383                        project_id: project.id.to_proto(),
4384                    },
4385                )
4386                .trace_err();
4387        } else {
4388            session
4389                .peer
4390                .send(
4391                    *connection_id,
4392                    proto::RemoveProjectCollaborator {
4393                        project_id: project.id.to_proto(),
4394                        peer_id: Some(session.connection_id.into()),
4395                    },
4396                )
4397                .trace_err();
4398        }
4399    }
4400}
4401
4402pub trait ResultExt {
4403    type Ok;
4404
4405    fn trace_err(self) -> Option<Self::Ok>;
4406}
4407
4408impl<T, E> ResultExt for Result<T, E>
4409where
4410    E: std::fmt::Debug,
4411{
4412    type Ok = T;
4413
4414    #[track_caller]
4415    fn trace_err(self) -> Option<T> {
4416        match self {
4417            Ok(value) => Some(value),
4418            Err(error) => {
4419                tracing::error!("{:?}", error);
4420                None
4421            }
4422        }
4423    }
4424}