rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::db::LlmDatabase;
   6use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
   7use crate::{
   8    AppState, Error, Result, auth,
   9    db::{
  10        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  11        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  12        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  13        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  14    },
  15    executor::Executor,
  16};
  17use anyhow::{Context as _, anyhow, bail};
  18use async_tungstenite::tungstenite::{
  19    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  20};
  21use axum::{
  22    Extension, Router, TypedHeader,
  23    body::Body,
  24    extract::{
  25        ConnectInfo, WebSocketUpgrade,
  26        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  27    },
  28    headers::{Header, HeaderName},
  29    http::StatusCode,
  30    middleware,
  31    response::IntoResponse,
  32    routing::get,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use reqwest_client::ReqwestClient;
  39use rpc::proto::split_repository_update;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  44    stream::FuturesUnordered,
  45};
  46use prometheus::{IntGauge, register_int_gauge};
  47use rpc::{
  48    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        Arc, OnceLock,
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{Semaphore, watch};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    Instrument,
  75    field::{self},
  76    info_span, instrument,
  77};
  78
/// Reconnection grace period (30s).
// NOTE(review): presumably how long a dropped client may take to reconnect before its
// session state is released — confirm against the uses of this constant.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before clearing stale rooms and channel-buffer
/// collaborators.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up their old resources; 15s leaves a margin over that window.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are consumed outside this view; presumably page sizes
// for channel-message and notification fetches and a cap on channel message length —
// confirm the unit (bytes vs. chars) of MAX_MESSAGE_LEN at its use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers` (keyed by message `TypeId`):
/// receives a boxed envelope plus the connection's `Session` and returns a boxed future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection context handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection is authenticated as (possibly an impersonation).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; use `Session::db` to lock it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; use `Session::connection_pool` to lock it.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): the `allow(unused)` suggests nothing reads this field yet; it is kept
    // so it is available to handlers later — confirm before removing.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    // Held but never read directly (leading underscore).
    _executor: Executor,
}
 142
 143impl Session {
 144    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 145        #[cfg(test)]
 146        tokio::task::yield_now().await;
 147        let guard = self.db.lock().await;
 148        #[cfg(test)]
 149        tokio::task::yield_now().await;
 150        guard
 151    }
 152
 153    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.connection_pool.lock();
 157        ConnectionPoolGuard {
 158            guard,
 159            _not_send: PhantomData,
 160        }
 161    }
 162
 163    fn is_staff(&self) -> bool {
 164        match &self.principal {
 165            Principal::User(user) => user.admin,
 166            Principal::Impersonated { .. } => true,
 167        }
 168    }
 169
 170    fn user_id(&self) -> UserId {
 171        match &self.principal {
 172            Principal::User(user) => user.id,
 173            Principal::Impersonated { user, .. } => user.id,
 174        }
 175    }
 176
 177    pub fn email(&self) -> Option<String> {
 178        match &self.principal {
 179            Principal::User(user) => user.email_address.clone(),
 180            Principal::Impersonated { user, .. } => user.email_address.clone(),
 181        }
 182    }
 183}
 184
 185impl Debug for Session {
 186    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 187        let mut result = f.debug_struct("Session");
 188        match &self.principal {
 189            Principal::User(user) => {
 190                result.field("user", &user.github_login);
 191            }
 192            Principal::Impersonated { user, admin } => {
 193                result.field("user", &user.github_login);
 194                result.field("impersonator", &admin.github_login);
 195            }
 196        }
 197        result.field("connection_id", &self.connection_id).finish()
 198    }
 199}
 200
 201struct DbHandle(Arc<Database>);
 202
 203impl Deref for DbHandle {
 204    type Target = Database;
 205
 206    fn deref(&self) -> &Self::Target {
 207        self.0.as_ref()
 208    }
 209}
 210
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    /// Current server id; the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sets it to `true`, the test-only `reset` back to `false`.
    teardown: watch::Sender<bool>,
}
 219
/// Lock guard over the connection pool (returned e.g. by `Session::connection_pool`).
///
/// `_not_send` holds a `PhantomData<Rc<()>>`, making the guard `!Send` so it cannot be kept
/// alive across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 224
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized via `serialize_deref`, i.e. as whatever the guard derefs to
    // (presumably the `ConnectionPool` itself) rather than as the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 231
 232pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 233where
 234    S: Serializer,
 235    T: Deref<Target = U>,
 236    U: Serialize,
 237{
 238    Serialize::serialize(value.deref(), serializer)
 239}
 240
 241impl Server {
 242    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 243        let mut server = Self {
 244            id: parking_lot::Mutex::new(id),
 245            peer: Peer::new(id.0 as u32),
 246            app_state: app_state.clone(),
 247            connection_pool: Default::default(),
 248            handlers: Default::default(),
 249            teardown: watch::channel(false).0,
 250        };
 251
 252        server
 253            .add_request_handler(ping)
 254            .add_request_handler(create_room)
 255            .add_request_handler(join_room)
 256            .add_request_handler(rejoin_room)
 257            .add_request_handler(leave_room)
 258            .add_request_handler(set_room_participant_role)
 259            .add_request_handler(call)
 260            .add_request_handler(cancel_call)
 261            .add_message_handler(decline_call)
 262            .add_request_handler(update_participant_location)
 263            .add_request_handler(share_project)
 264            .add_message_handler(unshare_project)
 265            .add_request_handler(join_project)
 266            .add_message_handler(leave_project)
 267            .add_request_handler(update_project)
 268            .add_request_handler(update_worktree)
 269            .add_request_handler(update_repository)
 270            .add_request_handler(remove_repository)
 271            .add_message_handler(start_language_server)
 272            .add_message_handler(update_language_server)
 273            .add_message_handler(update_diagnostic_summary)
 274            .add_message_handler(update_worktree_settings)
 275            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 276            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 277            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 278            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 279            .add_request_handler(forward_find_search_candidates_request)
 280            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 281            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 282            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 283            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 284            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 285            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 286            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 287            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 288            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 289            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 290            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 291            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 292            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 293            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 294            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 295            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 296            .add_request_handler(
 297                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 298            )
 299            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 300            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 301            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 302            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 303            .add_request_handler(
 304                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 305            )
 306            .add_request_handler(
 307                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 308            )
 309            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 310            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 311            .add_request_handler(
 312                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 313            )
 314            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 319            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 320            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 321            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 322            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 323            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 324            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 325            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 326            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 327            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 328            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 329            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 330            .add_request_handler(
 331                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 332            )
 333            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 334            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 335            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 336            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 337            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 338            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 339            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 340            .add_message_handler(create_buffer_for_peer)
 341            .add_request_handler(update_buffer)
 342            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 343            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 344            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 345            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 346            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 347            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 348            .add_request_handler(get_users)
 349            .add_request_handler(fuzzy_search_users)
 350            .add_request_handler(request_contact)
 351            .add_request_handler(remove_contact)
 352            .add_request_handler(respond_to_contact_request)
 353            .add_message_handler(subscribe_to_channels)
 354            .add_request_handler(create_channel)
 355            .add_request_handler(delete_channel)
 356            .add_request_handler(invite_channel_member)
 357            .add_request_handler(remove_channel_member)
 358            .add_request_handler(set_channel_member_role)
 359            .add_request_handler(set_channel_visibility)
 360            .add_request_handler(rename_channel)
 361            .add_request_handler(join_channel_buffer)
 362            .add_request_handler(leave_channel_buffer)
 363            .add_message_handler(update_channel_buffer)
 364            .add_request_handler(rejoin_channel_buffers)
 365            .add_request_handler(get_channel_members)
 366            .add_request_handler(respond_to_channel_invite)
 367            .add_request_handler(join_channel)
 368            .add_request_handler(join_channel_chat)
 369            .add_message_handler(leave_channel_chat)
 370            .add_request_handler(send_channel_message)
 371            .add_request_handler(remove_channel_message)
 372            .add_request_handler(update_channel_message)
 373            .add_request_handler(get_channel_messages)
 374            .add_request_handler(get_channel_messages_by_id)
 375            .add_request_handler(get_notifications)
 376            .add_request_handler(mark_notification_as_read)
 377            .add_request_handler(move_channel)
 378            .add_request_handler(follow)
 379            .add_message_handler(unfollow)
 380            .add_message_handler(update_followers)
 381            .add_request_handler(get_private_user_info)
 382            .add_request_handler(get_llm_api_token)
 383            .add_request_handler(accept_terms_of_service)
 384            .add_message_handler(acknowledge_channel_message)
 385            .add_message_handler(acknowledge_buffer_version)
 386            .add_request_handler(get_supermaven_api_key)
 387            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 388            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 389            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 390            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 391            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 392            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 393            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 394            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 395            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 396            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 397            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 398            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 399            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 400            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 402            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 403            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 404            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 405            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 406            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 407            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 408            .add_message_handler(update_context);
 409
 410        Arc::new(server)
 411    }
 412
    /// Starts background cleanup for this server and returns immediately.
    ///
    /// Spawns a detached task that waits for `CLEANUP_TIMEOUT` and then asks the database
    /// for the stale room and channel ids for this server id in the current environment.
    /// For each stale channel it clears stale channel-buffer collaborators and sends the
    /// refreshed collaborator list to each remaining connection. For each stale room it:
    /// - clears stale participants and broadcasts the updated room (and its channel, if any);
    /// - sends `CallCanceled` to every connection of users whose calls were canceled;
    /// - pushes an updated contact entry to the accepted contacts of each affected user;
    /// - deletes the LiveKit room once no participants remain.
    ///
    /// Finally it calls `delete_stale_servers`, even if the stale-id lookup failed. Every
    /// failure inside the task is logged via `trace_err` and skipped, so this function
    /// always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give terminating servers time to shut down before touching their resources.
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the refreshed
                    // collaborator list to the connections still attached to it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below; they stay empty/false if
                        // clearing the room's stale participants failed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only tear down the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to all connections of
                        // their accepted contacts. The pool is locked only after the DB awaits.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs regardless of whether the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 558
 559    pub fn teardown(&self) {
 560        self.peer.teardown();
 561        self.connection_pool.lock().reset();
 562        let _ = self.teardown.send(true);
 563    }
 564
 565    #[cfg(test)]
 566    pub fn reset(&self, id: ServerId) {
 567        self.teardown();
 568        *self.id.lock() = id;
 569        self.peer.reset(id.0 as u32);
 570        let _ = self.teardown.send(false);
 571    }
 572
 573    #[cfg(test)]
 574    pub fn id(&self) -> ServerId {
 575        *self.id.lock()
 576    }
 577
 578    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 579    where
 580        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 581        Fut: 'static + Send + Future<Output = Result<()>>,
 582        M: EnvelopedMessage,
 583    {
 584        let prev_handler = self.handlers.insert(
 585            TypeId::of::<M>(),
 586            Box::new(move |envelope, session| {
 587                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 588                let received_at = envelope.received_at;
 589                tracing::info!("message received");
 590                let start_time = Instant::now();
 591                let future = (handler)(*envelope, session);
 592                async move {
 593                    let result = future.await;
 594                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 595                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 596                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 597                    let payload_type = M::NAME;
 598
 599                    match result {
 600                        Err(error) => {
 601                            tracing::error!(
 602                                ?error,
 603                                total_duration_ms,
 604                                processing_duration_ms,
 605                                queue_duration_ms,
 606                                payload_type,
 607                                "error handling message"
 608                            )
 609                        }
 610                        Ok(()) => tracing::info!(
 611                            total_duration_ms,
 612                            processing_duration_ms,
 613                            queue_duration_ms,
 614                            "finished handling message"
 615                        ),
 616                    }
 617                }
 618                .boxed()
 619            }),
 620        );
 621        if prev_handler.is_some() {
 622            panic!("registered a handler for the same message twice");
 623        }
 624        self
 625    }
 626
 627    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 628    where
 629        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 630        Fut: 'static + Send + Future<Output = Result<()>>,
 631        M: EnvelopedMessage,
 632    {
 633        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 634        self
 635    }
 636
 637    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 638    where
 639        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 640        Fut: Send + Future<Output = Result<()>>,
 641        M: RequestMessage,
 642    {
 643        let handler = Arc::new(handler);
 644        self.add_handler(move |envelope, session| {
 645            let receipt = envelope.receipt();
 646            let handler = handler.clone();
 647            async move {
 648                let peer = session.peer.clone();
 649                let responded = Arc::new(AtomicBool::default());
 650                let response = Response {
 651                    peer: peer.clone(),
 652                    responded: responded.clone(),
 653                    receipt,
 654                };
 655                match (handler)(envelope.payload, response, session).await {
 656                    Ok(()) => {
 657                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 658                            Ok(())
 659                        } else {
 660                            Err(anyhow!("handler did not send a response"))?
 661                        }
 662                    }
 663                    Err(error) => {
 664                        let proto_err = match &error {
 665                            Error::Internal(err) => err.to_proto(),
 666                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 667                        };
 668                        peer.respond_with_error(receipt, proto_err)?;
 669                        Err(error)
 670                    }
 671                }
 672            }
 673        })
 674    }
 675
    /// Serves one client connection from handshake to sign-out.
    ///
    /// The returned future owns a clone of the server `Arc` and everything else
    /// it needs, so it can run independently of `self`. When polled it:
    /// 1. returns immediately (logging an error) if teardown was already signalled;
    /// 2. registers `connection` with the peer and records the assigned
    ///    connection id on the connection span;
    /// 3. builds the per-connection [`Session`] (with an optional Supermaven
    ///    admin client) and runs `send_initial_client_update`, which also hands
    ///    the connection id to `send_connection_id` when one is given;
    /// 4. dispatches incoming messages to the registered handlers until the
    ///    client closes the stream or the connection I/O future completes;
    /// 5. then calls `connection_lost` to sign the session out.
    ///
    /// If teardown is signalled while the message loop is running, the future
    /// returns without calling `connection_lost`. Failing to create the HTTP
    /// client or to send the initial update is logged and also ends it early.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are recorded later: `connection_id` once the
        // peer assigns it, `geoip_country_code` just below, and (presumably) the
        // user fields by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly; the current value is checked first thing when the
        // future runs, and changes are watched by the message loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Per-connection HTTP client; here it only backs the optional
            // Supermaven admin client built below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection: a permit is
            // acquired before the next message is read and released when that
            // message's handler finishes, so reading pauses once the limit is hit.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, connection I/O, progress on
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is built;
                            // `.instrument(span)` re-attaches it whenever that future is polled.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only after the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached; foreground handlers are
                                // driven by the `foreground_message_handlers` branch above.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 823
 824    async fn send_initial_client_update(
 825        &self,
 826        connection_id: ConnectionId,
 827        principal: &Principal,
 828        zed_version: ZedVersion,
 829        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 830        session: &Session,
 831    ) -> Result<()> {
 832        self.peer.send(
 833            connection_id,
 834            proto::Hello {
 835                peer_id: Some(connection_id.into()),
 836            },
 837        )?;
 838        tracing::info!("sent hello message");
 839        if let Some(send_connection_id) = send_connection_id.take() {
 840            let _ = send_connection_id.send(connection_id);
 841        }
 842
 843        match principal {
 844            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 845                if !user.connected_once {
 846                    self.peer.send(connection_id, proto::ShowContacts {})?;
 847                    self.app_state
 848                        .db
 849                        .set_user_connected_once(user.id, true)
 850                        .await?;
 851                }
 852
 853                update_user_plan(user.id, session).await?;
 854
 855                let contacts = self.app_state.db.get_contacts(user.id).await?;
 856
 857                {
 858                    let mut pool = self.connection_pool.lock();
 859                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 860                    self.peer.send(
 861                        connection_id,
 862                        build_initial_contacts_update(contacts, &pool),
 863                    )?;
 864                }
 865
 866                if should_auto_subscribe_to_channels(zed_version) {
 867                    subscribe_user_to_channels(user.id, session).await?;
 868                }
 869
 870                if let Some(incoming_call) =
 871                    self.app_state.db.incoming_call_for_user(user.id).await?
 872                {
 873                    self.peer.send(connection_id, incoming_call)?;
 874                }
 875
 876                update_user_contacts(user.id, session).await?;
 877            }
 878        }
 879
 880        Ok(())
 881    }
 882
 883    pub async fn invite_code_redeemed(
 884        self: &Arc<Self>,
 885        inviter_id: UserId,
 886        invitee_id: UserId,
 887    ) -> Result<()> {
 888        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 889            if let Some(code) = &user.invite_code {
 890                let pool = self.connection_pool.lock();
 891                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 892                for connection_id in pool.user_connection_ids(inviter_id) {
 893                    self.peer.send(
 894                        connection_id,
 895                        proto::UpdateContacts {
 896                            contacts: vec![invitee_contact.clone()],
 897                            ..Default::default()
 898                        },
 899                    )?;
 900                    self.peer.send(
 901                        connection_id,
 902                        proto::UpdateInviteInfo {
 903                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 904                            count: user.invite_count as u32,
 905                        },
 906                    )?;
 907                }
 908            }
 909        }
 910        Ok(())
 911    }
 912
 913    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 914        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 915            if let Some(invite_code) = &user.invite_code {
 916                let pool = self.connection_pool.lock();
 917                for connection_id in pool.user_connection_ids(user_id) {
 918                    self.peer.send(
 919                        connection_id,
 920                        proto::UpdateInviteInfo {
 921                            url: format!(
 922                                "{}{}",
 923                                self.app_state.config.invite_link_prefix, invite_code
 924                            ),
 925                            count: user.invite_count as u32,
 926                        },
 927                    )?;
 928                }
 929            }
 930        }
 931        Ok(())
 932    }
 933
 934    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 935        let user = self
 936            .app_state
 937            .db
 938            .get_user_by_id(user_id)
 939            .await?
 940            .ok_or_else(|| anyhow!("user not found"))?;
 941
 942        let update_user_plan = make_update_user_plan_message(
 943            &self.app_state.db,
 944            self.app_state.llm_db.clone(),
 945            user_id,
 946            user.admin,
 947        )
 948        .await?;
 949
 950        let pool = self.connection_pool.lock();
 951        for connection_id in pool.user_connection_ids(user_id) {
 952            self.peer
 953                .send(connection_id, update_user_plan.clone())
 954                .trace_err();
 955        }
 956
 957        Ok(())
 958    }
 959
 960    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 961        let pool = self.connection_pool.lock();
 962        for connection_id in pool.user_connection_ids(user_id) {
 963            self.peer
 964                .send(connection_id, proto::RefreshLlmToken {})
 965                .trace_err();
 966        }
 967    }
 968
 969    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 970        ServerSnapshot {
 971            connection_pool: ConnectionPoolGuard {
 972                guard: self.connection_pool.lock(),
 973                _not_send: PhantomData,
 974            },
 975            peer: &self.peer,
 976        }
 977    }
 978}
 979
 980impl Deref for ConnectionPoolGuard<'_> {
 981    type Target = ConnectionPool;
 982
 983    fn deref(&self) -> &Self::Target {
 984        &self.guard
 985    }
 986}
 987
 988impl DerefMut for ConnectionPoolGuard<'_> {
 989    fn deref_mut(&mut self) -> &mut Self::Target {
 990        &mut self.guard
 991    }
 992}
 993
/// Runs the pool's `check_invariants` whenever the guard is released, in test
/// builds only.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The call is compiled out of non-test builds, so releasing the guard
        // does no extra work in production.
        #[cfg(test)]
        self.check_invariants();
    }
}
1000
1001fn broadcast<F>(
1002    sender_id: Option<ConnectionId>,
1003    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004    mut f: F,
1005) where
1006    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008    for receiver_id in receiver_ids {
1009        if Some(receiver_id) != sender_id {
1010            if let Err(error) = f(receiver_id) {
1011                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012            }
1013        }
1014    }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020    fn name() -> &'static HeaderName {
1021        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023    }
1024
1025    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026    where
1027        Self: Sized,
1028        I: Iterator<Item = &'i axum::http::HeaderValue>,
1029    {
1030        let version = values
1031            .next()
1032            .ok_or_else(axum::headers::Error::invalid)?
1033            .to_str()
1034            .map_err(|_| axum::headers::Error::invalid())?
1035            .parse()
1036            .map_err(|_| axum::headers::Error::invalid())?;
1037        Ok(Self(version))
1038    }
1039
1040    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041        values.extend([self.0.to_string().parse().unwrap()]);
1042    }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047    fn name() -> &'static HeaderName {
1048        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050    }
1051
1052    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053    where
1054        Self: Sized,
1055        I: Iterator<Item = &'i axum::http::HeaderValue>,
1056    {
1057        let version = values
1058            .next()
1059            .ok_or_else(axum::headers::Error::invalid)?
1060            .to_str()
1061            .map_err(|_| axum::headers::Error::invalid())?
1062            .parse()
1063            .map_err(|_| axum::headers::Error::invalid())?;
1064        Ok(Self(version))
1065    }
1066
1067    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068        values.extend([self.0.to_string().parse().unwrap()]);
1069    }
1070}
1071
1072pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1073    Router::new()
1074        .route("/rpc", get(handle_websocket_request))
1075        .layer(
1076            ServiceBuilder::new()
1077                .layer(Extension(server.app_state.clone()))
1078                .layer(middleware::from_fn(auth::validate_header)),
1079        )
1080        .route("/metrics", get(handle_metrics))
1081        .layer(Extension(server))
1082}
1083
1084pub async fn handle_websocket_request(
1085    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1086    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1087    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1088    Extension(server): Extension<Arc<Server>>,
1089    Extension(principal): Extension<Principal>,
1090    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1091    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1092    ws: WebSocketUpgrade,
1093) -> axum::response::Response {
1094    if protocol_version != rpc::PROTOCOL_VERSION {
1095        return (
1096            StatusCode::UPGRADE_REQUIRED,
1097            "client must be upgraded".to_string(),
1098        )
1099            .into_response();
1100    }
1101
1102    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1103        return (
1104            StatusCode::UPGRADE_REQUIRED,
1105            "no version header found".to_string(),
1106        )
1107            .into_response();
1108    };
1109
1110    if !version.can_collaborate() {
1111        return (
1112            StatusCode::UPGRADE_REQUIRED,
1113            "client must be upgraded".to_string(),
1114        )
1115            .into_response();
1116    }
1117
1118    let socket_address = socket_address.to_string();
1119    ws.on_upgrade(move |socket| {
1120        let socket = socket
1121            .map_ok(to_tungstenite_message)
1122            .err_into()
1123            .with(|message| async move { to_axum_message(message) });
1124        let connection = Connection::new(Box::pin(socket));
1125        async move {
1126            server
1127                .handle_connection(
1128                    connection,
1129                    socket_address,
1130                    principal,
1131                    version,
1132                    country_code_header.map(|header| header.to_string()),
1133                    system_id_header.map(|header| header.to_string()),
1134                    None,
1135                    Executor::Production,
1136                )
1137                .await;
1138        }
1139    })
1140}
1141
1142pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1143    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1144    let connections_metric = CONNECTIONS_METRIC
1145        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1146
1147    let connections = server
1148        .connection_pool
1149        .lock()
1150        .connections()
1151        .filter(|connection| !connection.admin)
1152        .count();
1153    connections_metric.set(connections as _);
1154
1155    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1156    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1157        register_int_gauge!(
1158            "shared_projects",
1159            "number of open projects with one or more guests"
1160        )
1161        .unwrap()
1162    });
1163
1164    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1165    shared_projects_metric.set(shared_projects as _);
1166
1167    let encoder = prometheus::TextEncoder::new();
1168    let metric_families = prometheus::gather();
1169    let encoded_metrics = encoder
1170        .encode_to_string(&metric_families)
1171        .map_err(|err| anyhow!("{}", err))?;
1172    Ok(encoded_metrics)
1173}
1174
1175#[instrument(err, skip(executor))]
1176async fn connection_lost(
1177    session: Session,
1178    mut teardown: watch::Receiver<bool>,
1179    executor: Executor,
1180) -> Result<()> {
1181    session.peer.disconnect(session.connection_id);
1182    session
1183        .connection_pool()
1184        .await
1185        .remove_connection(session.connection_id)?;
1186
1187    session
1188        .db()
1189        .await
1190        .connection_lost(session.connection_id)
1191        .await
1192        .trace_err();
1193
1194    futures::select_biased! {
1195        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1196
1197            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1198            leave_room_for_session(&session, session.connection_id).await.trace_err();
1199            leave_channel_buffers_for_session(&session)
1200                .await
1201                .trace_err();
1202
1203            if !session
1204                .connection_pool()
1205                .await
1206                .is_user_online(session.user_id())
1207            {
1208                let db = session.db().await;
1209                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1210                    room_updated(&room, &session.peer);
1211                }
1212            }
1213
1214            update_user_contacts(session.user_id(), &session).await?;
1215        },
1216        _ = teardown.changed().fuse() => {}
1217    }
1218
1219    Ok(())
1220}
1221
1222/// Acknowledges a ping from a client, used to keep the connection alive.
1223async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1224    response.send(proto::Ack {})?;
1225    Ok(())
1226}
1227
1228/// Creates a new room for calling (outside of channels)
1229async fn create_room(
1230    _request: proto::CreateRoom,
1231    response: Response<proto::CreateRoom>,
1232    session: Session,
1233) -> Result<()> {
1234    let livekit_room = nanoid::nanoid!(30);
1235
1236    let live_kit_connection_info = util::maybe!(async {
1237        let live_kit = session.app_state.livekit_client.as_ref();
1238        let live_kit = live_kit?;
1239        let user_id = session.user_id().to_string();
1240
1241        let token = live_kit
1242            .room_token(&livekit_room, &user_id.to_string())
1243            .trace_err()?;
1244
1245        Some(proto::LiveKitConnectionInfo {
1246            server_url: live_kit.url().into(),
1247            token,
1248            can_publish: true,
1249        })
1250    })
1251    .await;
1252
1253    let room = session
1254        .db()
1255        .await
1256        .create_room(session.user_id(), session.connection_id, &livekit_room)
1257        .await?;
1258
1259    response.send(proto::CreateRoomResponse {
1260        room: Some(room.clone()),
1261        live_kit_connection_info,
1262    })?;
1263
1264    update_user_contacts(session.user_id(), &session).await?;
1265    Ok(())
1266}
1267
1268/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1269async fn join_room(
1270    request: proto::JoinRoom,
1271    response: Response<proto::JoinRoom>,
1272    session: Session,
1273) -> Result<()> {
1274    let room_id = RoomId::from_proto(request.id);
1275
1276    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1277
1278    if let Some(channel_id) = channel_id {
1279        return join_channel_internal(channel_id, Box::new(response), session).await;
1280    }
1281
1282    let joined_room = {
1283        let room = session
1284            .db()
1285            .await
1286            .join_room(room_id, session.user_id(), session.connection_id)
1287            .await?;
1288        room_updated(&room.room, &session.peer);
1289        room.into_inner()
1290    };
1291
1292    for connection_id in session
1293        .connection_pool()
1294        .await
1295        .user_connection_ids(session.user_id())
1296    {
1297        session
1298            .peer
1299            .send(
1300                connection_id,
1301                proto::CallCanceled {
1302                    room_id: room_id.to_proto(),
1303                },
1304            )
1305            .trace_err();
1306    }
1307
1308    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1309    {
1310        live_kit
1311            .room_token(
1312                &joined_room.room.livekit_room,
1313                &session.user_id().to_string(),
1314            )
1315            .trace_err()
1316            .map(|token| proto::LiveKitConnectionInfo {
1317                server_url: live_kit.url().into(),
1318                token,
1319                can_publish: true,
1320            })
1321    } else {
1322        None
1323    };
1324
1325    response.send(proto::JoinRoomResponse {
1326        room: Some(joined_room.room),
1327        channel_id: None,
1328        live_kit_connection_info,
1329    })?;
1330
1331    update_user_contacts(session.user_id(), &session).await?;
1332    Ok(())
1333}
1334
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Calls the database's `rejoin_room` with the request, the user, and the
///    current connection id.
/// 2. Replies with the room and the reshared and rejoined projects.
/// 3. Broadcasts the room via `room_updated`.
/// 4. For each reshared project, tells every collaborator that the caller's
///    peer id changed from the project's `old_connection_id` to the current
///    connection (send failures only logged). It then forwards an
///    `UpdateProject` with the project's worktrees to every collaborator
///    except the caller.
/// 5. Runs `notify_rejoined_projects` for the rejoined projects.
/// 6. After the block that holds the database result, sends a channel update
///    if the room belongs to a channel, then updates the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled in at the end of the block below, once the database result has
    // been unwrapped with `into_inner()`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Skips the caller's own connection (`broadcast` filters `sender_id`).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Drains worktree/repository state out of `rejoined_projects` while
        // streaming it to the caller.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1426
/// Announces a rejoining connection to each project's collaborators, then
/// replays each rejoined project's state back to the rejoining connection.
///
/// For every project, each collaborator is sent an `UpdateProjectCollaborator`
/// that maps `project.old_connection_id` to this session's connection id;
/// failures of those sends are only traced. Afterwards the rejoining
/// connection is sent, per project: worktree entry updates (passed through
/// `proto::split_worktree_update`), diagnostic summaries and settings files
/// for each worktree, then repository updates and repository removals.
///
/// `worktrees`, `updated_repositories` and `removed_repositories` are moved
/// out of each `RejoinedProject` with `mem::take`, so they are empty on return.
///
/// # Errors
///
/// Returns the first error from a send to the rejoining connection itself;
/// messages after that point are not sent.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Phase 1: tell collaborators about the peer-id change (best-effort).
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Phase 2: replay project state to the rejoining connection. Unlike phase
    // 1, any send failure here aborts the whole function via `?`.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Flagged as the last update only when the latest completed
                // scan is the current scan.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update may be split into several messages; send each in order.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Repository updates, each possibly split into several messages.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        // Repositories that were removed while this connection was away.
        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1511
1512/// leave room disconnects from the room.
1513async fn leave_room(
1514    _: proto::LeaveRoom,
1515    response: Response<proto::LeaveRoom>,
1516    session: Session,
1517) -> Result<()> {
1518    leave_room_for_session(&session, session.connection_id).await?;
1519    response.send(proto::Ack {})?;
1520    Ok(())
1521}
1522
1523/// Updates the permissions of someone else in the room.
1524async fn set_room_participant_role(
1525    request: proto::SetRoomParticipantRole,
1526    response: Response<proto::SetRoomParticipantRole>,
1527    session: Session,
1528) -> Result<()> {
1529    let user_id = UserId::from_proto(request.user_id);
1530    let role = ChannelRole::from(request.role());
1531
1532    let (livekit_room, can_publish) = {
1533        let room = session
1534            .db()
1535            .await
1536            .set_room_participant_role(
1537                session.user_id(),
1538                RoomId::from_proto(request.room_id),
1539                user_id,
1540                role,
1541            )
1542            .await?;
1543
1544        let livekit_room = room.livekit_room.clone();
1545        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1546        room_updated(&room, &session.peer);
1547        (livekit_room, can_publish)
1548    };
1549
1550    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1551        live_kit
1552            .update_participant(
1553                livekit_room.clone(),
1554                request.user_id.to_string(),
1555                livekit_api::proto::ParticipantPermission {
1556                    can_subscribe: true,
1557                    can_publish,
1558                    can_publish_data: can_publish,
1559                    hidden: false,
1560                    recorder: false,
1561                },
1562            )
1563            .await
1564            .trace_err();
1565    }
1566
1567    response.send(proto::Ack {})?;
1568    Ok(())
1569}
1570
1571/// Call someone else into the current room
1572async fn call(
1573    request: proto::Call,
1574    response: Response<proto::Call>,
1575    session: Session,
1576) -> Result<()> {
1577    let room_id = RoomId::from_proto(request.room_id);
1578    let calling_user_id = session.user_id();
1579    let calling_connection_id = session.connection_id;
1580    let called_user_id = UserId::from_proto(request.called_user_id);
1581    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1582    if !session
1583        .db()
1584        .await
1585        .has_contact(calling_user_id, called_user_id)
1586        .await?
1587    {
1588        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1589    }
1590
1591    let incoming_call = {
1592        let (room, incoming_call) = &mut *session
1593            .db()
1594            .await
1595            .call(
1596                room_id,
1597                calling_user_id,
1598                calling_connection_id,
1599                called_user_id,
1600                initial_project_id,
1601            )
1602            .await?;
1603        room_updated(room, &session.peer);
1604        mem::take(incoming_call)
1605    };
1606    update_user_contacts(called_user_id, &session).await?;
1607
1608    let mut calls = session
1609        .connection_pool()
1610        .await
1611        .user_connection_ids(called_user_id)
1612        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1613        .collect::<FuturesUnordered<_>>();
1614
1615    while let Some(call_response) = calls.next().await {
1616        match call_response.as_ref() {
1617            Ok(_) => {
1618                response.send(proto::Ack {})?;
1619                return Ok(());
1620            }
1621            Err(_) => {
1622                call_response.trace_err();
1623            }
1624        }
1625    }
1626
1627    {
1628        let room = session
1629            .db()
1630            .await
1631            .call_failed(room_id, called_user_id)
1632            .await?;
1633        room_updated(&room, &session.peer);
1634    }
1635    update_user_contacts(called_user_id, &session).await?;
1636
1637    Err(anyhow!("failed to ring user"))?
1638}
1639
1640/// Cancel an outgoing call.
1641async fn cancel_call(
1642    request: proto::CancelCall,
1643    response: Response<proto::CancelCall>,
1644    session: Session,
1645) -> Result<()> {
1646    let called_user_id = UserId::from_proto(request.called_user_id);
1647    let room_id = RoomId::from_proto(request.room_id);
1648    {
1649        let room = session
1650            .db()
1651            .await
1652            .cancel_call(room_id, session.connection_id, called_user_id)
1653            .await?;
1654        room_updated(&room, &session.peer);
1655    }
1656
1657    for connection_id in session
1658        .connection_pool()
1659        .await
1660        .user_connection_ids(called_user_id)
1661    {
1662        session
1663            .peer
1664            .send(
1665                connection_id,
1666                proto::CallCanceled {
1667                    room_id: room_id.to_proto(),
1668                },
1669            )
1670            .trace_err();
1671    }
1672    response.send(proto::Ack {})?;
1673
1674    update_user_contacts(called_user_id, &session).await?;
1675    Ok(())
1676}
1677
1678/// Decline an incoming call.
1679async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1680    let room_id = RoomId::from_proto(message.room_id);
1681    {
1682        let room = session
1683            .db()
1684            .await
1685            .decline_call(Some(room_id), session.user_id())
1686            .await?
1687            .ok_or_else(|| anyhow!("failed to decline call"))?;
1688        room_updated(&room, &session.peer);
1689    }
1690
1691    for connection_id in session
1692        .connection_pool()
1693        .await
1694        .user_connection_ids(session.user_id())
1695    {
1696        session
1697            .peer
1698            .send(
1699                connection_id,
1700                proto::CallCanceled {
1701                    room_id: room_id.to_proto(),
1702                },
1703            )
1704            .trace_err();
1705    }
1706    update_user_contacts(session.user_id(), &session).await?;
1707    Ok(())
1708}
1709
1710/// Updates other participants in the room with your current location.
1711async fn update_participant_location(
1712    request: proto::UpdateParticipantLocation,
1713    response: Response<proto::UpdateParticipantLocation>,
1714    session: Session,
1715) -> Result<()> {
1716    let room_id = RoomId::from_proto(request.room_id);
1717    let location = request
1718        .location
1719        .ok_or_else(|| anyhow!("invalid location"))?;
1720
1721    let db = session.db().await;
1722    let room = db
1723        .update_room_participant_location(room_id, session.connection_id, location)
1724        .await?;
1725
1726    room_updated(&room, &session.peer);
1727    response.send(proto::Ack {})?;
1728    Ok(())
1729}
1730
1731/// Share a project into the room.
1732async fn share_project(
1733    request: proto::ShareProject,
1734    response: Response<proto::ShareProject>,
1735    session: Session,
1736) -> Result<()> {
1737    let (project_id, room) = &*session
1738        .db()
1739        .await
1740        .share_project(
1741            RoomId::from_proto(request.room_id),
1742            session.connection_id,
1743            &request.worktrees,
1744            request.is_ssh_project,
1745        )
1746        .await?;
1747    response.send(proto::ShareProjectResponse {
1748        project_id: project_id.to_proto(),
1749    })?;
1750    room_updated(room, &session.peer);
1751
1752    Ok(())
1753}
1754
1755/// Unshare a project from the room.
1756async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1757    let project_id = ProjectId::from_proto(message.project_id);
1758    unshare_project_internal(project_id, session.connection_id, &session).await
1759}
1760
1761async fn unshare_project_internal(
1762    project_id: ProjectId,
1763    connection_id: ConnectionId,
1764    session: &Session,
1765) -> Result<()> {
1766    let delete = {
1767        let room_guard = session
1768            .db()
1769            .await
1770            .unshare_project(project_id, connection_id)
1771            .await?;
1772
1773        let (delete, room, guest_connection_ids) = &*room_guard;
1774
1775        let message = proto::UnshareProject {
1776            project_id: project_id.to_proto(),
1777        };
1778
1779        broadcast(
1780            Some(connection_id),
1781            guest_connection_ids.iter().copied(),
1782            |conn_id| session.peer.send(conn_id, message.clone()),
1783        );
1784        if let Some(room) = room {
1785            room_updated(room, &session.peer);
1786        }
1787
1788        *delete
1789    };
1790
1791    if delete {
1792        let db = session.db().await;
1793        db.delete_project(project_id).await?;
1794    }
1795
1796    Ok(())
1797}
1798
1799/// Join someone elses shared project.
1800async fn join_project(
1801    request: proto::JoinProject,
1802    response: Response<proto::JoinProject>,
1803    session: Session,
1804) -> Result<()> {
1805    let project_id = ProjectId::from_proto(request.project_id);
1806
1807    tracing::info!(%project_id, "join project");
1808
1809    let db = session.db().await;
1810    let (project, replica_id) = &mut *db
1811        .join_project(project_id, session.connection_id, session.user_id())
1812        .await?;
1813    drop(db);
1814    tracing::info!(%project_id, "join remote project");
1815    join_project_internal(response, session, project, replica_id)
1816}
1817
/// A sink for the `JoinProjectResponse` produced by `join_project_internal`.
///
/// `join_project_internal` accepts any `impl JoinProjectInternalResponse`;
/// in this section the only implementation is for
/// `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Sends `result`, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to the inherent `Response::send`. The explicit path makes it clear
// that this is not a recursive call to the trait method of the same name.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1826
1827fn join_project_internal(
1828    response: impl JoinProjectInternalResponse,
1829    session: Session,
1830    project: &mut Project,
1831    replica_id: &ReplicaId,
1832) -> Result<()> {
1833    let collaborators = project
1834        .collaborators
1835        .iter()
1836        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1837        .map(|collaborator| collaborator.to_proto())
1838        .collect::<Vec<_>>();
1839    let project_id = project.id;
1840    let guest_user_id = session.user_id();
1841
1842    let worktrees = project
1843        .worktrees
1844        .iter()
1845        .map(|(id, worktree)| proto::WorktreeMetadata {
1846            id: *id,
1847            root_name: worktree.root_name.clone(),
1848            visible: worktree.visible,
1849            abs_path: worktree.abs_path.clone(),
1850        })
1851        .collect::<Vec<_>>();
1852
1853    let add_project_collaborator = proto::AddProjectCollaborator {
1854        project_id: project_id.to_proto(),
1855        collaborator: Some(proto::Collaborator {
1856            peer_id: Some(session.connection_id.into()),
1857            replica_id: replica_id.0 as u32,
1858            user_id: guest_user_id.to_proto(),
1859            is_host: false,
1860        }),
1861    };
1862
1863    for collaborator in &collaborators {
1864        session
1865            .peer
1866            .send(
1867                collaborator.peer_id.unwrap().into(),
1868                add_project_collaborator.clone(),
1869            )
1870            .trace_err();
1871    }
1872
1873    // First, we send the metadata associated with each worktree.
1874    response.send(proto::JoinProjectResponse {
1875        project_id: project.id.0 as u64,
1876        worktrees: worktrees.clone(),
1877        replica_id: replica_id.0 as u32,
1878        collaborators: collaborators.clone(),
1879        language_servers: project.language_servers.clone(),
1880        role: project.role.into(),
1881    })?;
1882
1883    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1884        // Stream this worktree's entries.
1885        let message = proto::UpdateWorktree {
1886            project_id: project_id.to_proto(),
1887            worktree_id,
1888            abs_path: worktree.abs_path.clone(),
1889            root_name: worktree.root_name,
1890            updated_entries: worktree.entries,
1891            removed_entries: Default::default(),
1892            scan_id: worktree.scan_id,
1893            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1894            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1895            removed_repositories: Default::default(),
1896        };
1897        for update in proto::split_worktree_update(message) {
1898            session.peer.send(session.connection_id, update.clone())?;
1899        }
1900
1901        // Stream this worktree's diagnostics.
1902        for summary in worktree.diagnostic_summaries {
1903            session.peer.send(
1904                session.connection_id,
1905                proto::UpdateDiagnosticSummary {
1906                    project_id: project_id.to_proto(),
1907                    worktree_id: worktree.id,
1908                    summary: Some(summary),
1909                },
1910            )?;
1911        }
1912
1913        for settings_file in worktree.settings_files {
1914            session.peer.send(
1915                session.connection_id,
1916                proto::UpdateWorktreeSettings {
1917                    project_id: project_id.to_proto(),
1918                    worktree_id: worktree.id,
1919                    path: settings_file.path,
1920                    content: Some(settings_file.content),
1921                    kind: Some(settings_file.kind.to_proto() as i32),
1922                },
1923            )?;
1924        }
1925    }
1926
1927    for repository in mem::take(&mut project.repositories) {
1928        for update in split_repository_update(repository) {
1929            session.peer.send(session.connection_id, update)?;
1930        }
1931    }
1932
1933    for language_server in &project.language_servers {
1934        session.peer.send(
1935            session.connection_id,
1936            proto::UpdateLanguageServer {
1937                project_id: project_id.to_proto(),
1938                language_server_id: language_server.id,
1939                variant: Some(
1940                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1941                        proto::LspDiskBasedDiagnosticsUpdated {},
1942                    ),
1943                ),
1944            },
1945        )?;
1946    }
1947
1948    Ok(())
1949}
1950
1951/// Leave someone elses shared project.
1952async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1953    let sender_id = session.connection_id;
1954    let project_id = ProjectId::from_proto(request.project_id);
1955    let db = session.db().await;
1956
1957    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1958    tracing::info!(
1959        %project_id,
1960        "leave project"
1961    );
1962
1963    project_left(project, &session);
1964    if let Some(room) = room {
1965        room_updated(room, &session.peer);
1966    }
1967
1968    Ok(())
1969}
1970
1971/// Updates other participants with changes to the project
1972async fn update_project(
1973    request: proto::UpdateProject,
1974    response: Response<proto::UpdateProject>,
1975    session: Session,
1976) -> Result<()> {
1977    let project_id = ProjectId::from_proto(request.project_id);
1978    let (room, guest_connection_ids) = &*session
1979        .db()
1980        .await
1981        .update_project(project_id, session.connection_id, &request.worktrees)
1982        .await?;
1983    broadcast(
1984        Some(session.connection_id),
1985        guest_connection_ids.iter().copied(),
1986        |connection_id| {
1987            session
1988                .peer
1989                .forward_send(session.connection_id, connection_id, request.clone())
1990        },
1991    );
1992    if let Some(room) = room {
1993        room_updated(room, &session.peer);
1994    }
1995    response.send(proto::Ack {})?;
1996
1997    Ok(())
1998}
1999
2000/// Updates other participants with changes to the worktree
2001async fn update_worktree(
2002    request: proto::UpdateWorktree,
2003    response: Response<proto::UpdateWorktree>,
2004    session: Session,
2005) -> Result<()> {
2006    let guest_connection_ids = session
2007        .db()
2008        .await
2009        .update_worktree(&request, session.connection_id)
2010        .await?;
2011
2012    broadcast(
2013        Some(session.connection_id),
2014        guest_connection_ids.iter().copied(),
2015        |connection_id| {
2016            session
2017                .peer
2018                .forward_send(session.connection_id, connection_id, request.clone())
2019        },
2020    );
2021    response.send(proto::Ack {})?;
2022    Ok(())
2023}
2024
2025async fn update_repository(
2026    request: proto::UpdateRepository,
2027    response: Response<proto::UpdateRepository>,
2028    session: Session,
2029) -> Result<()> {
2030    let guest_connection_ids = session
2031        .db()
2032        .await
2033        .update_repository(&request, session.connection_id)
2034        .await?;
2035
2036    broadcast(
2037        Some(session.connection_id),
2038        guest_connection_ids.iter().copied(),
2039        |connection_id| {
2040            session
2041                .peer
2042                .forward_send(session.connection_id, connection_id, request.clone())
2043        },
2044    );
2045    response.send(proto::Ack {})?;
2046    Ok(())
2047}
2048
2049async fn remove_repository(
2050    request: proto::RemoveRepository,
2051    response: Response<proto::RemoveRepository>,
2052    session: Session,
2053) -> Result<()> {
2054    let guest_connection_ids = session
2055        .db()
2056        .await
2057        .remove_repository(&request, session.connection_id)
2058        .await?;
2059
2060    broadcast(
2061        Some(session.connection_id),
2062        guest_connection_ids.iter().copied(),
2063        |connection_id| {
2064            session
2065                .peer
2066                .forward_send(session.connection_id, connection_id, request.clone())
2067        },
2068    );
2069    response.send(proto::Ack {})?;
2070    Ok(())
2071}
2072
2073/// Updates other participants with changes to the diagnostics
2074async fn update_diagnostic_summary(
2075    message: proto::UpdateDiagnosticSummary,
2076    session: Session,
2077) -> Result<()> {
2078    let guest_connection_ids = session
2079        .db()
2080        .await
2081        .update_diagnostic_summary(&message, session.connection_id)
2082        .await?;
2083
2084    broadcast(
2085        Some(session.connection_id),
2086        guest_connection_ids.iter().copied(),
2087        |connection_id| {
2088            session
2089                .peer
2090                .forward_send(session.connection_id, connection_id, message.clone())
2091        },
2092    );
2093
2094    Ok(())
2095}
2096
2097/// Updates other participants with changes to the worktree settings
2098async fn update_worktree_settings(
2099    message: proto::UpdateWorktreeSettings,
2100    session: Session,
2101) -> Result<()> {
2102    let guest_connection_ids = session
2103        .db()
2104        .await
2105        .update_worktree_settings(&message, session.connection_id)
2106        .await?;
2107
2108    broadcast(
2109        Some(session.connection_id),
2110        guest_connection_ids.iter().copied(),
2111        |connection_id| {
2112            session
2113                .peer
2114                .forward_send(session.connection_id, connection_id, message.clone())
2115        },
2116    );
2117
2118    Ok(())
2119}
2120
2121/// Notify other participants that a language server has started.
2122async fn start_language_server(
2123    request: proto::StartLanguageServer,
2124    session: Session,
2125) -> Result<()> {
2126    let guest_connection_ids = session
2127        .db()
2128        .await
2129        .start_language_server(&request, session.connection_id)
2130        .await?;
2131
2132    broadcast(
2133        Some(session.connection_id),
2134        guest_connection_ids.iter().copied(),
2135        |connection_id| {
2136            session
2137                .peer
2138                .forward_send(session.connection_id, connection_id, request.clone())
2139        },
2140    );
2141    Ok(())
2142}
2143
2144/// Notify other participants that a language server has changed.
2145async fn update_language_server(
2146    request: proto::UpdateLanguageServer,
2147    session: Session,
2148) -> Result<()> {
2149    let project_id = ProjectId::from_proto(request.project_id);
2150    let project_connection_ids = session
2151        .db()
2152        .await
2153        .project_connection_ids(project_id, session.connection_id, true)
2154        .await?;
2155    broadcast(
2156        Some(session.connection_id),
2157        project_connection_ids.iter().copied(),
2158        |connection_id| {
2159            session
2160                .peer
2161                .forward_send(session.connection_id, connection_id, request.clone())
2162        },
2163    );
2164    Ok(())
2165}
2166
2167/// forward a project request to the host. These requests should be read only
2168/// as guests are allowed to send them.
2169async fn forward_read_only_project_request<T>(
2170    request: T,
2171    response: Response<T>,
2172    session: Session,
2173) -> Result<()>
2174where
2175    T: EntityMessage + RequestMessage,
2176{
2177    let project_id = ProjectId::from_proto(request.remote_entity_id());
2178    let host_connection_id = session
2179        .db()
2180        .await
2181        .host_for_read_only_project_request(project_id, session.connection_id)
2182        .await?;
2183    let payload = session
2184        .peer
2185        .forward_request(session.connection_id, host_connection_id, request)
2186        .await?;
2187    response.send(payload)?;
2188    Ok(())
2189}
2190
2191async fn forward_find_search_candidates_request(
2192    request: proto::FindSearchCandidates,
2193    response: Response<proto::FindSearchCandidates>,
2194    session: Session,
2195) -> Result<()> {
2196    let project_id = ProjectId::from_proto(request.remote_entity_id());
2197    let host_connection_id = session
2198        .db()
2199        .await
2200        .host_for_read_only_project_request(project_id, session.connection_id)
2201        .await?;
2202    let payload = session
2203        .peer
2204        .forward_request(session.connection_id, host_connection_id, request)
2205        .await?;
2206    response.send(payload)?;
2207    Ok(())
2208}
2209
2210/// forward a project request to the host. These requests are disallowed
2211/// for guests.
2212async fn forward_mutating_project_request<T>(
2213    request: T,
2214    response: Response<T>,
2215    session: Session,
2216) -> Result<()>
2217where
2218    T: EntityMessage + RequestMessage,
2219{
2220    let project_id = ProjectId::from_proto(request.remote_entity_id());
2221
2222    let host_connection_id = session
2223        .db()
2224        .await
2225        .host_for_mutating_project_request(project_id, session.connection_id)
2226        .await?;
2227    let payload = session
2228        .peer
2229        .forward_request(session.connection_id, host_connection_id, request)
2230        .await?;
2231    response.send(payload)?;
2232    Ok(())
2233}
2234
2235/// Notify other participants that a new buffer has been created
2236async fn create_buffer_for_peer(
2237    request: proto::CreateBufferForPeer,
2238    session: Session,
2239) -> Result<()> {
2240    session
2241        .db()
2242        .await
2243        .check_user_is_project_host(
2244            ProjectId::from_proto(request.project_id),
2245            session.connection_id,
2246        )
2247        .await?;
2248    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2249    session
2250        .peer
2251        .forward_send(session.connection_id, peer_id.into(), request)?;
2252    Ok(())
2253}
2254
2255/// Notify other participants that a buffer has been updated. This is
2256/// allowed for guests as long as the update is limited to selections.
2257async fn update_buffer(
2258    request: proto::UpdateBuffer,
2259    response: Response<proto::UpdateBuffer>,
2260    session: Session,
2261) -> Result<()> {
2262    let project_id = ProjectId::from_proto(request.project_id);
2263    let mut capability = Capability::ReadOnly;
2264
2265    for op in request.operations.iter() {
2266        match op.variant {
2267            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2268            Some(_) => capability = Capability::ReadWrite,
2269        }
2270    }
2271
2272    let host = {
2273        let guard = session
2274            .db()
2275            .await
2276            .connections_for_buffer_update(project_id, session.connection_id, capability)
2277            .await?;
2278
2279        let (host, guests) = &*guard;
2280
2281        broadcast(
2282            Some(session.connection_id),
2283            guests.clone(),
2284            |connection_id| {
2285                session
2286                    .peer
2287                    .forward_send(session.connection_id, connection_id, request.clone())
2288            },
2289        );
2290
2291        *host
2292    };
2293
2294    if host != session.connection_id {
2295        session
2296            .peer
2297            .forward_request(session.connection_id, host, request.clone())
2298            .await?;
2299    }
2300
2301    response.send(proto::Ack {})?;
2302    Ok(())
2303}
2304
2305async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2306    let project_id = ProjectId::from_proto(message.project_id);
2307
2308    let operation = message.operation.as_ref().context("invalid operation")?;
2309    let capability = match operation.variant.as_ref() {
2310        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2311            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2312                match buffer_op.variant {
2313                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2314                        Capability::ReadOnly
2315                    }
2316                    _ => Capability::ReadWrite,
2317                }
2318            } else {
2319                Capability::ReadWrite
2320            }
2321        }
2322        Some(_) => Capability::ReadWrite,
2323        None => Capability::ReadOnly,
2324    };
2325
2326    let guard = session
2327        .db()
2328        .await
2329        .connections_for_buffer_update(project_id, session.connection_id, capability)
2330        .await?;
2331
2332    let (host, guests) = &*guard;
2333
2334    broadcast(
2335        Some(session.connection_id),
2336        guests.iter().chain([host]).copied(),
2337        |connection_id| {
2338            session
2339                .peer
2340                .forward_send(session.connection_id, connection_id, message.clone())
2341        },
2342    );
2343
2344    Ok(())
2345}
2346
2347/// Notify other participants that a project has been updated.
2348async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2349    request: T,
2350    session: Session,
2351) -> Result<()> {
2352    let project_id = ProjectId::from_proto(request.remote_entity_id());
2353    let project_connection_ids = session
2354        .db()
2355        .await
2356        .project_connection_ids(project_id, session.connection_id, false)
2357        .await?;
2358
2359    broadcast(
2360        Some(session.connection_id),
2361        project_connection_ids.iter().copied(),
2362        |connection_id| {
2363            session
2364                .peer
2365                .forward_send(session.connection_id, connection_id, request.clone())
2366        },
2367    );
2368    Ok(())
2369}
2370
2371/// Start following another user in a call.
2372async fn follow(
2373    request: proto::Follow,
2374    response: Response<proto::Follow>,
2375    session: Session,
2376) -> Result<()> {
2377    let room_id = RoomId::from_proto(request.room_id);
2378    let project_id = request.project_id.map(ProjectId::from_proto);
2379    let leader_id = request
2380        .leader_id
2381        .ok_or_else(|| anyhow!("invalid leader id"))?
2382        .into();
2383    let follower_id = session.connection_id;
2384
2385    session
2386        .db()
2387        .await
2388        .check_room_participants(room_id, leader_id, session.connection_id)
2389        .await?;
2390
2391    let response_payload = session
2392        .peer
2393        .forward_request(session.connection_id, leader_id, request)
2394        .await?;
2395    response.send(response_payload)?;
2396
2397    if let Some(project_id) = project_id {
2398        let room = session
2399            .db()
2400            .await
2401            .follow(room_id, project_id, leader_id, follower_id)
2402            .await?;
2403        room_updated(&room, &session.peer);
2404    }
2405
2406    Ok(())
2407}
2408
2409/// Stop following another user in a call.
2410async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2411    let room_id = RoomId::from_proto(request.room_id);
2412    let project_id = request.project_id.map(ProjectId::from_proto);
2413    let leader_id = request
2414        .leader_id
2415        .ok_or_else(|| anyhow!("invalid leader id"))?
2416        .into();
2417    let follower_id = session.connection_id;
2418
2419    session
2420        .db()
2421        .await
2422        .check_room_participants(room_id, leader_id, session.connection_id)
2423        .await?;
2424
2425    session
2426        .peer
2427        .forward_send(session.connection_id, leader_id, request)?;
2428
2429    if let Some(project_id) = project_id {
2430        let room = session
2431            .db()
2432            .await
2433            .unfollow(room_id, project_id, leader_id, follower_id)
2434            .await?;
2435        room_updated(&room, &session.peer);
2436    }
2437
2438    Ok(())
2439}
2440
2441/// Notify everyone following you of your current location.
2442async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2443    let room_id = RoomId::from_proto(request.room_id);
2444    let database = session.db.lock().await;
2445
2446    let connection_ids = if let Some(project_id) = request.project_id {
2447        let project_id = ProjectId::from_proto(project_id);
2448        database
2449            .project_connection_ids(project_id, session.connection_id, true)
2450            .await?
2451    } else {
2452        database
2453            .room_connection_ids(room_id, session.connection_id)
2454            .await?
2455    };
2456
2457    // For now, don't send view update messages back to that view's current leader.
2458    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2459        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2460        _ => None,
2461    });
2462
2463    for connection_id in connection_ids.iter().cloned() {
2464        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2465            session
2466                .peer
2467                .forward_send(session.connection_id, connection_id, request.clone())?;
2468        }
2469    }
2470    Ok(())
2471}
2472
2473/// Get public data about users.
2474async fn get_users(
2475    request: proto::GetUsers,
2476    response: Response<proto::GetUsers>,
2477    session: Session,
2478) -> Result<()> {
2479    let user_ids = request
2480        .user_ids
2481        .into_iter()
2482        .map(UserId::from_proto)
2483        .collect();
2484    let users = session
2485        .db()
2486        .await
2487        .get_users_by_ids(user_ids)
2488        .await?
2489        .into_iter()
2490        .map(|user| proto::User {
2491            id: user.id.to_proto(),
2492            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2493            github_login: user.github_login,
2494            email: user.email_address,
2495            name: user.name,
2496        })
2497        .collect();
2498    response.send(proto::UsersResponse { users })?;
2499    Ok(())
2500}
2501
2502/// Search for users (to invite) buy Github login
2503async fn fuzzy_search_users(
2504    request: proto::FuzzySearchUsers,
2505    response: Response<proto::FuzzySearchUsers>,
2506    session: Session,
2507) -> Result<()> {
2508    let query = request.query;
2509    let users = match query.len() {
2510        0 => vec![],
2511        1 | 2 => session
2512            .db()
2513            .await
2514            .get_user_by_github_login(&query)
2515            .await?
2516            .into_iter()
2517            .collect(),
2518        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2519    };
2520    let users = users
2521        .into_iter()
2522        .filter(|user| user.id != session.user_id())
2523        .map(|user| proto::User {
2524            id: user.id.to_proto(),
2525            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2526            github_login: user.github_login,
2527            name: user.name,
2528            email: user.email_address,
2529        })
2530        .collect();
2531    response.send(proto::UsersResponse { users })?;
2532    Ok(())
2533}
2534
2535/// Send a contact request to another user.
2536async fn request_contact(
2537    request: proto::RequestContact,
2538    response: Response<proto::RequestContact>,
2539    session: Session,
2540) -> Result<()> {
2541    let requester_id = session.user_id();
2542    let responder_id = UserId::from_proto(request.responder_id);
2543    if requester_id == responder_id {
2544        return Err(anyhow!("cannot add yourself as a contact"))?;
2545    }
2546
2547    let notifications = session
2548        .db()
2549        .await
2550        .send_contact_request(requester_id, responder_id)
2551        .await?;
2552
2553    // Update outgoing contact requests of requester
2554    let mut update = proto::UpdateContacts::default();
2555    update.outgoing_requests.push(responder_id.to_proto());
2556    for connection_id in session
2557        .connection_pool()
2558        .await
2559        .user_connection_ids(requester_id)
2560    {
2561        session.peer.send(connection_id, update.clone())?;
2562    }
2563
2564    // Update incoming contact requests of responder
2565    let mut update = proto::UpdateContacts::default();
2566    update
2567        .incoming_requests
2568        .push(proto::IncomingContactRequest {
2569            requester_id: requester_id.to_proto(),
2570        });
2571    let connection_pool = session.connection_pool().await;
2572    for connection_id in connection_pool.user_connection_ids(responder_id) {
2573        session.peer.send(connection_id, update.clone())?;
2574    }
2575
2576    send_notifications(&connection_pool, &session.peer, notifications);
2577
2578    response.send(proto::Ack {})?;
2579    Ok(())
2580}
2581
2582/// Accept or decline a contact request
2583async fn respond_to_contact_request(
2584    request: proto::RespondToContactRequest,
2585    response: Response<proto::RespondToContactRequest>,
2586    session: Session,
2587) -> Result<()> {
2588    let responder_id = session.user_id();
2589    let requester_id = UserId::from_proto(request.requester_id);
2590    let db = session.db().await;
2591    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2592        db.dismiss_contact_notification(responder_id, requester_id)
2593            .await?;
2594    } else {
2595        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2596
2597        let notifications = db
2598            .respond_to_contact_request(responder_id, requester_id, accept)
2599            .await?;
2600        let requester_busy = db.is_user_busy(requester_id).await?;
2601        let responder_busy = db.is_user_busy(responder_id).await?;
2602
2603        let pool = session.connection_pool().await;
2604        // Update responder with new contact
2605        let mut update = proto::UpdateContacts::default();
2606        if accept {
2607            update
2608                .contacts
2609                .push(contact_for_user(requester_id, requester_busy, &pool));
2610        }
2611        update
2612            .remove_incoming_requests
2613            .push(requester_id.to_proto());
2614        for connection_id in pool.user_connection_ids(responder_id) {
2615            session.peer.send(connection_id, update.clone())?;
2616        }
2617
2618        // Update requester with new contact
2619        let mut update = proto::UpdateContacts::default();
2620        if accept {
2621            update
2622                .contacts
2623                .push(contact_for_user(responder_id, responder_busy, &pool));
2624        }
2625        update
2626            .remove_outgoing_requests
2627            .push(responder_id.to_proto());
2628
2629        for connection_id in pool.user_connection_ids(requester_id) {
2630            session.peer.send(connection_id, update.clone())?;
2631        }
2632
2633        send_notifications(&pool, &session.peer, notifications);
2634    }
2635
2636    response.send(proto::Ack {})?;
2637    Ok(())
2638}
2639
2640/// Remove a contact.
2641async fn remove_contact(
2642    request: proto::RemoveContact,
2643    response: Response<proto::RemoveContact>,
2644    session: Session,
2645) -> Result<()> {
2646    let requester_id = session.user_id();
2647    let responder_id = UserId::from_proto(request.user_id);
2648    let db = session.db().await;
2649    let (contact_accepted, deleted_notification_id) =
2650        db.remove_contact(requester_id, responder_id).await?;
2651
2652    let pool = session.connection_pool().await;
2653    // Update outgoing contact requests of requester
2654    let mut update = proto::UpdateContacts::default();
2655    if contact_accepted {
2656        update.remove_contacts.push(responder_id.to_proto());
2657    } else {
2658        update
2659            .remove_outgoing_requests
2660            .push(responder_id.to_proto());
2661    }
2662    for connection_id in pool.user_connection_ids(requester_id) {
2663        session.peer.send(connection_id, update.clone())?;
2664    }
2665
2666    // Update incoming contact requests of responder
2667    let mut update = proto::UpdateContacts::default();
2668    if contact_accepted {
2669        update.remove_contacts.push(requester_id.to_proto());
2670    } else {
2671        update
2672            .remove_incoming_requests
2673            .push(requester_id.to_proto());
2674    }
2675    for connection_id in pool.user_connection_ids(responder_id) {
2676        session.peer.send(connection_id, update.clone())?;
2677        if let Some(notification_id) = deleted_notification_id {
2678            session.peer.send(
2679                connection_id,
2680                proto::DeleteNotification {
2681                    notification_id: notification_id.to_proto(),
2682                },
2683            )?;
2684        }
2685    }
2686
2687    response.send(proto::Ack {})?;
2688    Ok(())
2689}
2690
2691fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2692    version.0.minor() < 139
2693}
2694
2695async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2696    if is_staff {
2697        return Ok(proto::Plan::ZedPro);
2698    }
2699
2700    let subscription = db.get_active_billing_subscription(user_id).await?;
2701    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2702
2703    let plan = if let Some(subscription_kind) = subscription_kind {
2704        match subscription_kind {
2705            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2706            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2707            SubscriptionKind::ZedFree => proto::Plan::Free,
2708        }
2709    } else {
2710        proto::Plan::Free
2711    };
2712
2713    Ok(plan)
2714}
2715
2716async fn make_update_user_plan_message(
2717    db: &Arc<Database>,
2718    llm_db: Option<Arc<LlmDatabase>>,
2719    user_id: UserId,
2720    is_staff: bool,
2721) -> Result<proto::UpdateUserPlan> {
2722    let feature_flags = db.get_user_flags(user_id).await?;
2723    let plan = current_plan(db, user_id, is_staff).await?;
2724    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2725    let billing_preferences = db.get_billing_preferences(user_id).await?;
2726
2727    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2728        let subscription = db.get_active_billing_subscription(user_id).await?;
2729
2730        let subscription_period =
2731            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2732
2733        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2734            llm_db
2735                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2736                .await?
2737        } else {
2738            None
2739        };
2740
2741        (subscription_period, usage)
2742    } else {
2743        (None, None)
2744    };
2745
2746    Ok(proto::UpdateUserPlan {
2747        plan: plan.into(),
2748        trial_started_at: billing_customer
2749            .and_then(|billing_customer| billing_customer.trial_started_at)
2750            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2751        is_usage_based_billing_enabled: if is_staff {
2752            Some(true)
2753        } else {
2754            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2755        },
2756        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2757            proto::SubscriptionPeriod {
2758                started_at: started_at.timestamp() as u64,
2759                ended_at: ended_at.timestamp() as u64,
2760            }
2761        }),
2762        usage: usage.map(|usage| {
2763            let plan = match plan {
2764                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2765                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2766                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2767            };
2768
2769            let model_requests_limit = match plan.model_requests_limit() {
2770                zed_llm_client::UsageLimit::Limited(limit) => {
2771                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
2772                        && feature_flags
2773                            .iter()
2774                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2775                    {
2776                        1_000
2777                    } else {
2778                        limit
2779                    };
2780
2781                    zed_llm_client::UsageLimit::Limited(limit)
2782                }
2783                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2784            };
2785
2786            proto::SubscriptionUsage {
2787                model_requests_usage_amount: usage.model_requests as u32,
2788                model_requests_usage_limit: Some(proto::UsageLimit {
2789                    variant: Some(match model_requests_limit {
2790                        zed_llm_client::UsageLimit::Limited(limit) => {
2791                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2792                                limit: limit as u32,
2793                            })
2794                        }
2795                        zed_llm_client::UsageLimit::Unlimited => {
2796                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2797                        }
2798                    }),
2799                }),
2800                edit_predictions_usage_amount: usage.edit_predictions as u32,
2801                edit_predictions_usage_limit: Some(proto::UsageLimit {
2802                    variant: Some(match plan.edit_predictions_limit() {
2803                        zed_llm_client::UsageLimit::Limited(limit) => {
2804                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2805                                limit: limit as u32,
2806                            })
2807                        }
2808                        zed_llm_client::UsageLimit::Unlimited => {
2809                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2810                        }
2811                    }),
2812                }),
2813            }
2814        }),
2815    })
2816}
2817
2818async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2819    let db = session.db().await;
2820
2821    let update_user_plan = make_update_user_plan_message(
2822        &db.0,
2823        session.app_state.llm_db.clone(),
2824        user_id,
2825        session.is_staff(),
2826    )
2827    .await?;
2828
2829    session
2830        .peer
2831        .send(session.connection_id, update_user_plan)
2832        .trace_err();
2833
2834    Ok(())
2835}
2836
2837async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2838    subscribe_user_to_channels(session.user_id(), &session).await?;
2839    Ok(())
2840}
2841
2842async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2843    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2844    let mut pool = session.connection_pool().await;
2845    for membership in &channels_for_user.channel_memberships {
2846        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2847    }
2848    session.peer.send(
2849        session.connection_id,
2850        build_update_user_channels(&channels_for_user),
2851    )?;
2852    session.peer.send(
2853        session.connection_id,
2854        build_channels_update(channels_for_user),
2855    )?;
2856    Ok(())
2857}
2858
2859/// Creates a new channel.
2860async fn create_channel(
2861    request: proto::CreateChannel,
2862    response: Response<proto::CreateChannel>,
2863    session: Session,
2864) -> Result<()> {
2865    let db = session.db().await;
2866
2867    let parent_id = request.parent_id.map(ChannelId::from_proto);
2868    let (channel, membership) = db
2869        .create_channel(&request.name, parent_id, session.user_id())
2870        .await?;
2871
2872    let root_id = channel.root_id();
2873    let channel = Channel::from_model(channel);
2874
2875    response.send(proto::CreateChannelResponse {
2876        channel: Some(channel.to_proto()),
2877        parent_id: request.parent_id,
2878    })?;
2879
2880    let mut connection_pool = session.connection_pool().await;
2881    if let Some(membership) = membership {
2882        connection_pool.subscribe_to_channel(
2883            membership.user_id,
2884            membership.channel_id,
2885            membership.role,
2886        );
2887        let update = proto::UpdateUserChannels {
2888            channel_memberships: vec![proto::ChannelMembership {
2889                channel_id: membership.channel_id.to_proto(),
2890                role: membership.role.into(),
2891            }],
2892            ..Default::default()
2893        };
2894        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2895            session.peer.send(connection_id, update.clone())?;
2896        }
2897    }
2898
2899    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2900        if !role.can_see_channel(channel.visibility) {
2901            continue;
2902        }
2903
2904        let update = proto::UpdateChannels {
2905            channels: vec![channel.to_proto()],
2906            ..Default::default()
2907        };
2908        session.peer.send(connection_id, update.clone())?;
2909    }
2910
2911    Ok(())
2912}
2913
2914/// Delete a channel
2915async fn delete_channel(
2916    request: proto::DeleteChannel,
2917    response: Response<proto::DeleteChannel>,
2918    session: Session,
2919) -> Result<()> {
2920    let db = session.db().await;
2921
2922    let channel_id = request.channel_id;
2923    let (root_channel, removed_channels) = db
2924        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2925        .await?;
2926    response.send(proto::Ack {})?;
2927
2928    // Notify members of removed channels
2929    let mut update = proto::UpdateChannels::default();
2930    update
2931        .delete_channels
2932        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2933
2934    let connection_pool = session.connection_pool().await;
2935    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2936        session.peer.send(connection_id, update.clone())?;
2937    }
2938
2939    Ok(())
2940}
2941
2942/// Invite someone to join a channel.
2943async fn invite_channel_member(
2944    request: proto::InviteChannelMember,
2945    response: Response<proto::InviteChannelMember>,
2946    session: Session,
2947) -> Result<()> {
2948    let db = session.db().await;
2949    let channel_id = ChannelId::from_proto(request.channel_id);
2950    let invitee_id = UserId::from_proto(request.user_id);
2951    let InviteMemberResult {
2952        channel,
2953        notifications,
2954    } = db
2955        .invite_channel_member(
2956            channel_id,
2957            invitee_id,
2958            session.user_id(),
2959            request.role().into(),
2960        )
2961        .await?;
2962
2963    let update = proto::UpdateChannels {
2964        channel_invitations: vec![channel.to_proto()],
2965        ..Default::default()
2966    };
2967
2968    let connection_pool = session.connection_pool().await;
2969    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2970        session.peer.send(connection_id, update.clone())?;
2971    }
2972
2973    send_notifications(&connection_pool, &session.peer, notifications);
2974
2975    response.send(proto::Ack {})?;
2976    Ok(())
2977}
2978
2979/// remove someone from a channel
2980async fn remove_channel_member(
2981    request: proto::RemoveChannelMember,
2982    response: Response<proto::RemoveChannelMember>,
2983    session: Session,
2984) -> Result<()> {
2985    let db = session.db().await;
2986    let channel_id = ChannelId::from_proto(request.channel_id);
2987    let member_id = UserId::from_proto(request.user_id);
2988
2989    let RemoveChannelMemberResult {
2990        membership_update,
2991        notification_id,
2992    } = db
2993        .remove_channel_member(channel_id, member_id, session.user_id())
2994        .await?;
2995
2996    let mut connection_pool = session.connection_pool().await;
2997    notify_membership_updated(
2998        &mut connection_pool,
2999        membership_update,
3000        member_id,
3001        &session.peer,
3002    );
3003    for connection_id in connection_pool.user_connection_ids(member_id) {
3004        if let Some(notification_id) = notification_id {
3005            session
3006                .peer
3007                .send(
3008                    connection_id,
3009                    proto::DeleteNotification {
3010                        notification_id: notification_id.to_proto(),
3011                    },
3012                )
3013                .trace_err();
3014        }
3015    }
3016
3017    response.send(proto::Ack {})?;
3018    Ok(())
3019}
3020
3021/// Toggle the channel between public and private.
3022/// Care is taken to maintain the invariant that public channels only descend from public channels,
3023/// (though members-only channels can appear at any point in the hierarchy).
3024async fn set_channel_visibility(
3025    request: proto::SetChannelVisibility,
3026    response: Response<proto::SetChannelVisibility>,
3027    session: Session,
3028) -> Result<()> {
3029    let db = session.db().await;
3030    let channel_id = ChannelId::from_proto(request.channel_id);
3031    let visibility = request.visibility().into();
3032
3033    let channel_model = db
3034        .set_channel_visibility(channel_id, visibility, session.user_id())
3035        .await?;
3036    let root_id = channel_model.root_id();
3037    let channel = Channel::from_model(channel_model);
3038
3039    let mut connection_pool = session.connection_pool().await;
3040    for (user_id, role) in connection_pool
3041        .channel_user_ids(root_id)
3042        .collect::<Vec<_>>()
3043        .into_iter()
3044    {
3045        let update = if role.can_see_channel(channel.visibility) {
3046            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3047            proto::UpdateChannels {
3048                channels: vec![channel.to_proto()],
3049                ..Default::default()
3050            }
3051        } else {
3052            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3053            proto::UpdateChannels {
3054                delete_channels: vec![channel.id.to_proto()],
3055                ..Default::default()
3056            }
3057        };
3058
3059        for connection_id in connection_pool.user_connection_ids(user_id) {
3060            session.peer.send(connection_id, update.clone())?;
3061        }
3062    }
3063
3064    response.send(proto::Ack {})?;
3065    Ok(())
3066}
3067
3068/// Alter the role for a user in the channel.
3069async fn set_channel_member_role(
3070    request: proto::SetChannelMemberRole,
3071    response: Response<proto::SetChannelMemberRole>,
3072    session: Session,
3073) -> Result<()> {
3074    let db = session.db().await;
3075    let channel_id = ChannelId::from_proto(request.channel_id);
3076    let member_id = UserId::from_proto(request.user_id);
3077    let result = db
3078        .set_channel_member_role(
3079            channel_id,
3080            session.user_id(),
3081            member_id,
3082            request.role().into(),
3083        )
3084        .await?;
3085
3086    match result {
3087        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3088            let mut connection_pool = session.connection_pool().await;
3089            notify_membership_updated(
3090                &mut connection_pool,
3091                membership_update,
3092                member_id,
3093                &session.peer,
3094            )
3095        }
3096        db::SetMemberRoleResult::InviteUpdated(channel) => {
3097            let update = proto::UpdateChannels {
3098                channel_invitations: vec![channel.to_proto()],
3099                ..Default::default()
3100            };
3101
3102            for connection_id in session
3103                .connection_pool()
3104                .await
3105                .user_connection_ids(member_id)
3106            {
3107                session.peer.send(connection_id, update.clone())?;
3108            }
3109        }
3110    }
3111
3112    response.send(proto::Ack {})?;
3113    Ok(())
3114}
3115
3116/// Change the name of a channel
3117async fn rename_channel(
3118    request: proto::RenameChannel,
3119    response: Response<proto::RenameChannel>,
3120    session: Session,
3121) -> Result<()> {
3122    let db = session.db().await;
3123    let channel_id = ChannelId::from_proto(request.channel_id);
3124    let channel_model = db
3125        .rename_channel(channel_id, session.user_id(), &request.name)
3126        .await?;
3127    let root_id = channel_model.root_id();
3128    let channel = Channel::from_model(channel_model);
3129
3130    response.send(proto::RenameChannelResponse {
3131        channel: Some(channel.to_proto()),
3132    })?;
3133
3134    let connection_pool = session.connection_pool().await;
3135    let update = proto::UpdateChannels {
3136        channels: vec![channel.to_proto()],
3137        ..Default::default()
3138    };
3139    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3140        if role.can_see_channel(channel.visibility) {
3141            session.peer.send(connection_id, update.clone())?;
3142        }
3143    }
3144
3145    Ok(())
3146}
3147
3148/// Move a channel to a new parent.
3149async fn move_channel(
3150    request: proto::MoveChannel,
3151    response: Response<proto::MoveChannel>,
3152    session: Session,
3153) -> Result<()> {
3154    let channel_id = ChannelId::from_proto(request.channel_id);
3155    let to = ChannelId::from_proto(request.to);
3156
3157    let (root_id, channels) = session
3158        .db()
3159        .await
3160        .move_channel(channel_id, to, session.user_id())
3161        .await?;
3162
3163    let connection_pool = session.connection_pool().await;
3164    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3165        let channels = channels
3166            .iter()
3167            .filter_map(|channel| {
3168                if role.can_see_channel(channel.visibility) {
3169                    Some(channel.to_proto())
3170                } else {
3171                    None
3172                }
3173            })
3174            .collect::<Vec<_>>();
3175        if channels.is_empty() {
3176            continue;
3177        }
3178
3179        let update = proto::UpdateChannels {
3180            channels,
3181            ..Default::default()
3182        };
3183
3184        session.peer.send(connection_id, update.clone())?;
3185    }
3186
3187    response.send(Ack {})?;
3188    Ok(())
3189}
3190
3191/// Get the list of channel members
3192async fn get_channel_members(
3193    request: proto::GetChannelMembers,
3194    response: Response<proto::GetChannelMembers>,
3195    session: Session,
3196) -> Result<()> {
3197    let db = session.db().await;
3198    let channel_id = ChannelId::from_proto(request.channel_id);
3199    let limit = if request.limit == 0 {
3200        u16::MAX as u64
3201    } else {
3202        request.limit
3203    };
3204    let (members, users) = db
3205        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3206        .await?;
3207    response.send(proto::GetChannelMembersResponse { members, users })?;
3208    Ok(())
3209}
3210
3211/// Accept or decline a channel invitation.
3212async fn respond_to_channel_invite(
3213    request: proto::RespondToChannelInvite,
3214    response: Response<proto::RespondToChannelInvite>,
3215    session: Session,
3216) -> Result<()> {
3217    let db = session.db().await;
3218    let channel_id = ChannelId::from_proto(request.channel_id);
3219    let RespondToChannelInvite {
3220        membership_update,
3221        notifications,
3222    } = db
3223        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3224        .await?;
3225
3226    let mut connection_pool = session.connection_pool().await;
3227    if let Some(membership_update) = membership_update {
3228        notify_membership_updated(
3229            &mut connection_pool,
3230            membership_update,
3231            session.user_id(),
3232            &session.peer,
3233        );
3234    } else {
3235        let update = proto::UpdateChannels {
3236            remove_channel_invitations: vec![channel_id.to_proto()],
3237            ..Default::default()
3238        };
3239
3240        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3241            session.peer.send(connection_id, update.clone())?;
3242        }
3243    };
3244
3245    send_notifications(&connection_pool, &session.peer, notifications);
3246
3247    response.send(proto::Ack {})?;
3248
3249    Ok(())
3250}
3251
3252/// Join the channels' room
3253async fn join_channel(
3254    request: proto::JoinChannel,
3255    response: Response<proto::JoinChannel>,
3256    session: Session,
3257) -> Result<()> {
3258    let channel_id = ChannelId::from_proto(request.channel_id);
3259    join_channel_internal(channel_id, Box::new(response), session).await
3260}
3261
3262trait JoinChannelInternalResponse {
3263    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3264}
3265impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3266    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3267        Response::<proto::JoinChannel>::send(self, result)
3268    }
3269}
3270impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3271    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3272        Response::<proto::JoinRoom>::send(self, result)
3273    }
3274}
3275
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// - Evicts any stale room connection the user still has.
/// - Joins the caller's connection to the channel's room in the database.
/// - Replies with the room, the channel id and (when LiveKit is configured) LiveKit
///   connection info.
/// - Applies any membership change to the connection pool.
/// - Notifies room participants, channel observers, and the user's contacts.
///
/// # Errors
/// Returns an error if a database call fails, the response cannot be sent, or the database
/// did not return the joined channel. That last check happens only after the response has
/// already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The db guard is held for this whole block (except while evicting a stale connection).
    // The connection pool is also locked inside the block; both guards are released before
    // `channel_updated` re-locks the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard: `leave_room_for_session` takes its own session locks.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a token that does not allow publishing; other roles may publish.
        // If token generation fails, the error is logged (`trace_err`) and the client gets
        // no LiveKit info instead of the whole join failing.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining may have changed the user's channel memberships; update subscriptions and
        // tell the user's connections.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the channel's new participant list to connections allowed to see it.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3371
3372/// Start editing the channel notes
3373async fn join_channel_buffer(
3374    request: proto::JoinChannelBuffer,
3375    response: Response<proto::JoinChannelBuffer>,
3376    session: Session,
3377) -> Result<()> {
3378    let db = session.db().await;
3379    let channel_id = ChannelId::from_proto(request.channel_id);
3380
3381    let open_response = db
3382        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3383        .await?;
3384
3385    let collaborators = open_response.collaborators.clone();
3386    response.send(open_response)?;
3387
3388    let update = UpdateChannelBufferCollaborators {
3389        channel_id: channel_id.to_proto(),
3390        collaborators: collaborators.clone(),
3391    };
3392    channel_buffer_updated(
3393        session.connection_id,
3394        collaborators
3395            .iter()
3396            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3397        &update,
3398        &session.peer,
3399    );
3400
3401    Ok(())
3402}
3403
3404/// Edit the channel notes
3405async fn update_channel_buffer(
3406    request: proto::UpdateChannelBuffer,
3407    session: Session,
3408) -> Result<()> {
3409    let db = session.db().await;
3410    let channel_id = ChannelId::from_proto(request.channel_id);
3411
3412    let (collaborators, epoch, version) = db
3413        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3414        .await?;
3415
3416    channel_buffer_updated(
3417        session.connection_id,
3418        collaborators.clone(),
3419        &proto::UpdateChannelBuffer {
3420            channel_id: channel_id.to_proto(),
3421            operations: request.operations,
3422        },
3423        &session.peer,
3424    );
3425
3426    let pool = &*session.connection_pool().await;
3427
3428    let non_collaborators =
3429        pool.channel_connection_ids(channel_id)
3430            .filter_map(|(connection_id, _)| {
3431                if collaborators.contains(&connection_id) {
3432                    None
3433                } else {
3434                    Some(connection_id)
3435                }
3436            });
3437
3438    broadcast(None, non_collaborators, |peer_id| {
3439        session.peer.send(
3440            peer_id,
3441            proto::UpdateChannels {
3442                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3443                    channel_id: channel_id.to_proto(),
3444                    epoch: epoch as u64,
3445                    version: version.clone(),
3446                }],
3447                ..Default::default()
3448            },
3449        )
3450    });
3451
3452    Ok(())
3453}
3454
3455/// Rejoin the channel notes after a connection blip
3456async fn rejoin_channel_buffers(
3457    request: proto::RejoinChannelBuffers,
3458    response: Response<proto::RejoinChannelBuffers>,
3459    session: Session,
3460) -> Result<()> {
3461    let db = session.db().await;
3462    let buffers = db
3463        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3464        .await?;
3465
3466    for rejoined_buffer in &buffers {
3467        let collaborators_to_notify = rejoined_buffer
3468            .buffer
3469            .collaborators
3470            .iter()
3471            .filter_map(|c| Some(c.peer_id?.into()));
3472        channel_buffer_updated(
3473            session.connection_id,
3474            collaborators_to_notify,
3475            &proto::UpdateChannelBufferCollaborators {
3476                channel_id: rejoined_buffer.buffer.channel_id,
3477                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3478            },
3479            &session.peer,
3480        );
3481    }
3482
3483    response.send(proto::RejoinChannelBuffersResponse {
3484        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3485    })?;
3486
3487    Ok(())
3488}
3489
3490/// Stop editing the channel notes
3491async fn leave_channel_buffer(
3492    request: proto::LeaveChannelBuffer,
3493    response: Response<proto::LeaveChannelBuffer>,
3494    session: Session,
3495) -> Result<()> {
3496    let db = session.db().await;
3497    let channel_id = ChannelId::from_proto(request.channel_id);
3498
3499    let left_buffer = db
3500        .leave_channel_buffer(channel_id, session.connection_id)
3501        .await?;
3502
3503    response.send(Ack {})?;
3504
3505    channel_buffer_updated(
3506        session.connection_id,
3507        left_buffer.connections,
3508        &proto::UpdateChannelBufferCollaborators {
3509            channel_id: channel_id.to_proto(),
3510            collaborators: left_buffer.collaborators,
3511        },
3512        &session.peer,
3513    );
3514
3515    Ok(())
3516}
3517
3518fn channel_buffer_updated<T: EnvelopedMessage>(
3519    sender_id: ConnectionId,
3520    collaborators: impl IntoIterator<Item = ConnectionId>,
3521    message: &T,
3522    peer: &Peer,
3523) {
3524    broadcast(Some(sender_id), collaborators, |peer_id| {
3525        peer.send(peer_id, message.clone())
3526    });
3527}
3528
3529fn send_notifications(
3530    connection_pool: &ConnectionPool,
3531    peer: &Peer,
3532    notifications: db::NotificationBatch,
3533) {
3534    for (user_id, notification) in notifications {
3535        for connection_id in connection_pool.user_connection_ids(user_id) {
3536            if let Err(error) = peer.send(
3537                connection_id,
3538                proto::AddNotification {
3539                    notification: Some(notification.clone()),
3540                },
3541            ) {
3542                tracing::error!(
3543                    "failed to send notification to {:?} {}",
3544                    connection_id,
3545                    error
3546                );
3547            }
3548        }
3549    }
3550}
3551
3552/// Send a message to the channel
3553async fn send_channel_message(
3554    request: proto::SendChannelMessage,
3555    response: Response<proto::SendChannelMessage>,
3556    session: Session,
3557) -> Result<()> {
3558    // Validate the message body.
3559    let body = request.body.trim().to_string();
3560    if body.len() > MAX_MESSAGE_LEN {
3561        return Err(anyhow!("message is too long"))?;
3562    }
3563    if body.is_empty() {
3564        return Err(anyhow!("message can't be blank"))?;
3565    }
3566
3567    // TODO: adjust mentions if body is trimmed
3568
3569    let timestamp = OffsetDateTime::now_utc();
3570    let nonce = request
3571        .nonce
3572        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3573
3574    let channel_id = ChannelId::from_proto(request.channel_id);
3575    let CreatedChannelMessage {
3576        message_id,
3577        participant_connection_ids,
3578        notifications,
3579    } = session
3580        .db()
3581        .await
3582        .create_channel_message(
3583            channel_id,
3584            session.user_id(),
3585            &body,
3586            &request.mentions,
3587            timestamp,
3588            nonce.clone().into(),
3589            request.reply_to_message_id.map(MessageId::from_proto),
3590        )
3591        .await?;
3592
3593    let message = proto::ChannelMessage {
3594        sender_id: session.user_id().to_proto(),
3595        id: message_id.to_proto(),
3596        body,
3597        mentions: request.mentions,
3598        timestamp: timestamp.unix_timestamp() as u64,
3599        nonce: Some(nonce),
3600        reply_to_message_id: request.reply_to_message_id,
3601        edited_at: None,
3602    };
3603    broadcast(
3604        Some(session.connection_id),
3605        participant_connection_ids.clone(),
3606        |connection| {
3607            session.peer.send(
3608                connection,
3609                proto::ChannelMessageSent {
3610                    channel_id: channel_id.to_proto(),
3611                    message: Some(message.clone()),
3612                },
3613            )
3614        },
3615    );
3616    response.send(proto::SendChannelMessageResponse {
3617        message: Some(message),
3618    })?;
3619
3620    let pool = &*session.connection_pool().await;
3621    let non_participants =
3622        pool.channel_connection_ids(channel_id)
3623            .filter_map(|(connection_id, _)| {
3624                if participant_connection_ids.contains(&connection_id) {
3625                    None
3626                } else {
3627                    Some(connection_id)
3628                }
3629            });
3630    broadcast(None, non_participants, |peer_id| {
3631        session.peer.send(
3632            peer_id,
3633            proto::UpdateChannels {
3634                latest_channel_message_ids: vec![proto::ChannelMessageId {
3635                    channel_id: channel_id.to_proto(),
3636                    message_id: message_id.to_proto(),
3637                }],
3638                ..Default::default()
3639            },
3640        )
3641    });
3642    send_notifications(pool, &session.peer, notifications);
3643
3644    Ok(())
3645}
3646
3647/// Delete a channel message
3648async fn remove_channel_message(
3649    request: proto::RemoveChannelMessage,
3650    response: Response<proto::RemoveChannelMessage>,
3651    session: Session,
3652) -> Result<()> {
3653    let channel_id = ChannelId::from_proto(request.channel_id);
3654    let message_id = MessageId::from_proto(request.message_id);
3655    let (connection_ids, existing_notification_ids) = session
3656        .db()
3657        .await
3658        .remove_channel_message(channel_id, message_id, session.user_id())
3659        .await?;
3660
3661    broadcast(
3662        Some(session.connection_id),
3663        connection_ids,
3664        move |connection| {
3665            session.peer.send(connection, request.clone())?;
3666
3667            for notification_id in &existing_notification_ids {
3668                session.peer.send(
3669                    connection,
3670                    proto::DeleteNotification {
3671                        notification_id: (*notification_id).to_proto(),
3672                    },
3673                )?;
3674            }
3675
3676            Ok(())
3677        },
3678    );
3679    response.send(proto::Ack {})?;
3680    Ok(())
3681}
3682
3683async fn update_channel_message(
3684    request: proto::UpdateChannelMessage,
3685    response: Response<proto::UpdateChannelMessage>,
3686    session: Session,
3687) -> Result<()> {
3688    let channel_id = ChannelId::from_proto(request.channel_id);
3689    let message_id = MessageId::from_proto(request.message_id);
3690    let updated_at = OffsetDateTime::now_utc();
3691    let UpdatedChannelMessage {
3692        message_id,
3693        participant_connection_ids,
3694        notifications,
3695        reply_to_message_id,
3696        timestamp,
3697        deleted_mention_notification_ids,
3698        updated_mention_notifications,
3699    } = session
3700        .db()
3701        .await
3702        .update_channel_message(
3703            channel_id,
3704            message_id,
3705            session.user_id(),
3706            request.body.as_str(),
3707            &request.mentions,
3708            updated_at,
3709        )
3710        .await?;
3711
3712    let nonce = request
3713        .nonce
3714        .clone()
3715        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3716
3717    let message = proto::ChannelMessage {
3718        sender_id: session.user_id().to_proto(),
3719        id: message_id.to_proto(),
3720        body: request.body.clone(),
3721        mentions: request.mentions.clone(),
3722        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3723        nonce: Some(nonce),
3724        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3725        edited_at: Some(updated_at.unix_timestamp() as u64),
3726    };
3727
3728    response.send(proto::Ack {})?;
3729
3730    let pool = &*session.connection_pool().await;
3731    broadcast(
3732        Some(session.connection_id),
3733        participant_connection_ids,
3734        |connection| {
3735            session.peer.send(
3736                connection,
3737                proto::ChannelMessageUpdate {
3738                    channel_id: channel_id.to_proto(),
3739                    message: Some(message.clone()),
3740                },
3741            )?;
3742
3743            for notification_id in &deleted_mention_notification_ids {
3744                session.peer.send(
3745                    connection,
3746                    proto::DeleteNotification {
3747                        notification_id: (*notification_id).to_proto(),
3748                    },
3749                )?;
3750            }
3751
3752            for notification in &updated_mention_notifications {
3753                session.peer.send(
3754                    connection,
3755                    proto::UpdateNotification {
3756                        notification: Some(notification.clone()),
3757                    },
3758                )?;
3759            }
3760
3761            Ok(())
3762        },
3763    );
3764
3765    send_notifications(pool, &session.peer, notifications);
3766
3767    Ok(())
3768}
3769
3770/// Mark a channel message as read
3771async fn acknowledge_channel_message(
3772    request: proto::AckChannelMessage,
3773    session: Session,
3774) -> Result<()> {
3775    let channel_id = ChannelId::from_proto(request.channel_id);
3776    let message_id = MessageId::from_proto(request.message_id);
3777    let notifications = session
3778        .db()
3779        .await
3780        .observe_channel_message(channel_id, session.user_id(), message_id)
3781        .await?;
3782    send_notifications(
3783        &*session.connection_pool().await,
3784        &session.peer,
3785        notifications,
3786    );
3787    Ok(())
3788}
3789
3790/// Mark a buffer version as synced
3791async fn acknowledge_buffer_version(
3792    request: proto::AckBufferOperation,
3793    session: Session,
3794) -> Result<()> {
3795    let buffer_id = BufferId::from_proto(request.buffer_id);
3796    session
3797        .db()
3798        .await
3799        .observe_buffer_version(
3800            buffer_id,
3801            session.user_id(),
3802            request.epoch as i32,
3803            &request.version,
3804        )
3805        .await?;
3806    Ok(())
3807}
3808
3809/// Get a Supermaven API key for the user
3810async fn get_supermaven_api_key(
3811    _request: proto::GetSupermavenApiKey,
3812    response: Response<proto::GetSupermavenApiKey>,
3813    session: Session,
3814) -> Result<()> {
3815    let user_id: String = session.user_id().to_string();
3816    if !session.is_staff() {
3817        return Err(anyhow!("supermaven not enabled for this account"))?;
3818    }
3819
3820    let email = session
3821        .email()
3822        .ok_or_else(|| anyhow!("user must have an email"))?;
3823
3824    let supermaven_admin_api = session
3825        .supermaven_client
3826        .as_ref()
3827        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3828
3829    let result = supermaven_admin_api
3830        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3831        .await?;
3832
3833    response.send(proto::GetSupermavenApiKeyResponse {
3834        api_key: result.api_key,
3835    })?;
3836
3837    Ok(())
3838}
3839
3840/// Start receiving chat updates for a channel
3841async fn join_channel_chat(
3842    request: proto::JoinChannelChat,
3843    response: Response<proto::JoinChannelChat>,
3844    session: Session,
3845) -> Result<()> {
3846    let channel_id = ChannelId::from_proto(request.channel_id);
3847
3848    let db = session.db().await;
3849    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3850        .await?;
3851    let messages = db
3852        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3853        .await?;
3854    response.send(proto::JoinChannelChatResponse {
3855        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3856        messages,
3857    })?;
3858    Ok(())
3859}
3860
3861/// Stop receiving chat updates for a channel
3862async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3863    let channel_id = ChannelId::from_proto(request.channel_id);
3864    session
3865        .db()
3866        .await
3867        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3868        .await?;
3869    Ok(())
3870}
3871
3872/// Retrieve the chat history for a channel
3873async fn get_channel_messages(
3874    request: proto::GetChannelMessages,
3875    response: Response<proto::GetChannelMessages>,
3876    session: Session,
3877) -> Result<()> {
3878    let channel_id = ChannelId::from_proto(request.channel_id);
3879    let messages = session
3880        .db()
3881        .await
3882        .get_channel_messages(
3883            channel_id,
3884            session.user_id(),
3885            MESSAGE_COUNT_PER_PAGE,
3886            Some(MessageId::from_proto(request.before_message_id)),
3887        )
3888        .await?;
3889    response.send(proto::GetChannelMessagesResponse {
3890        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3891        messages,
3892    })?;
3893    Ok(())
3894}
3895
3896/// Retrieve specific chat messages
3897async fn get_channel_messages_by_id(
3898    request: proto::GetChannelMessagesById,
3899    response: Response<proto::GetChannelMessagesById>,
3900    session: Session,
3901) -> Result<()> {
3902    let message_ids = request
3903        .message_ids
3904        .iter()
3905        .map(|id| MessageId::from_proto(*id))
3906        .collect::<Vec<_>>();
3907    let messages = session
3908        .db()
3909        .await
3910        .get_channel_messages_by_id(session.user_id(), &message_ids)
3911        .await?;
3912    response.send(proto::GetChannelMessagesResponse {
3913        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914        messages,
3915    })?;
3916    Ok(())
3917}
3918
3919/// Retrieve the current users notifications
3920async fn get_notifications(
3921    request: proto::GetNotifications,
3922    response: Response<proto::GetNotifications>,
3923    session: Session,
3924) -> Result<()> {
3925    let notifications = session
3926        .db()
3927        .await
3928        .get_notifications(
3929            session.user_id(),
3930            NOTIFICATION_COUNT_PER_PAGE,
3931            request.before_id.map(db::NotificationId::from_proto),
3932        )
3933        .await?;
3934    response.send(proto::GetNotificationsResponse {
3935        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3936        notifications,
3937    })?;
3938    Ok(())
3939}
3940
3941/// Mark notifications as read
3942async fn mark_notification_as_read(
3943    request: proto::MarkNotificationRead,
3944    response: Response<proto::MarkNotificationRead>,
3945    session: Session,
3946) -> Result<()> {
3947    let database = &session.db().await;
3948    let notifications = database
3949        .mark_notification_as_read_by_id(
3950            session.user_id(),
3951            NotificationId::from_proto(request.notification_id),
3952        )
3953        .await?;
3954    send_notifications(
3955        &*session.connection_pool().await,
3956        &session.peer,
3957        notifications,
3958    );
3959    response.send(proto::Ack {})?;
3960    Ok(())
3961}
3962
3963/// Get the current users information
3964async fn get_private_user_info(
3965    _request: proto::GetPrivateUserInfo,
3966    response: Response<proto::GetPrivateUserInfo>,
3967    session: Session,
3968) -> Result<()> {
3969    let db = session.db().await;
3970
3971    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3972    let user = db
3973        .get_user_by_id(session.user_id())
3974        .await?
3975        .ok_or_else(|| anyhow!("user not found"))?;
3976    let flags = db.get_user_flags(session.user_id()).await?;
3977
3978    response.send(proto::GetPrivateUserInfoResponse {
3979        metrics_id,
3980        staff: user.admin,
3981        flags,
3982        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3983    })?;
3984    Ok(())
3985}
3986
3987/// Accept the terms of service (tos) on behalf of the current user
3988async fn accept_terms_of_service(
3989    _request: proto::AcceptTermsOfService,
3990    response: Response<proto::AcceptTermsOfService>,
3991    session: Session,
3992) -> Result<()> {
3993    let db = session.db().await;
3994
3995    let accepted_tos_at = Utc::now();
3996    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3997        .await?;
3998
3999    response.send(proto::AcceptTermsOfServiceResponse {
4000        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4001    })?;
4002    Ok(())
4003}
4004
/// The minimum account age an account must have in order to use the LLM service.
///
/// NOTE(review): this constant is not referenced in this part of the file (for example,
/// `get_llm_api_token` below does not check it). Confirm where, or whether, it is enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4007
4008async fn get_llm_api_token(
4009    _request: proto::GetLlmToken,
4010    response: Response<proto::GetLlmToken>,
4011    session: Session,
4012) -> Result<()> {
4013    let db = session.db().await;
4014
4015    let flags = db.get_user_flags(session.user_id()).await?;
4016
4017    let user_id = session.user_id();
4018    let user = db
4019        .get_user_by_id(user_id)
4020        .await?
4021        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4022
4023    if user.accepted_tos_at.is_none() {
4024        Err(anyhow!("terms of service not accepted"))?
4025    }
4026
4027    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4028    let billing_preferences = db.get_billing_preferences(user.id).await?;
4029
4030    let token = LlmTokenClaims::create(
4031        &user,
4032        session.is_staff(),
4033        billing_preferences,
4034        &flags,
4035        billing_subscription,
4036        session.system_id.clone(),
4037        &session.app_state.config,
4038    )?;
4039    response.send(proto::GetLlmTokenResponse { token })?;
4040    Ok(())
4041}
4042
4043fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4044    let message = match message {
4045        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4046        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4047        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4048        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4049        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4050            code: frame.code.into(),
4051            reason: frame.reason.as_str().to_owned().into(),
4052        })),
4053        // We should never receive a frame while reading the message, according
4054        // to the `tungstenite` maintainers:
4055        //
4056        // > It cannot occur when you read messages from the WebSocket, but it
4057        // > can be used when you want to send the raw frames (e.g. you want to
4058        // > send the frames to the WebSocket without composing the full message first).
4059        // >
4060        // > — https://github.com/snapview/tungstenite-rs/issues/268
4061        TungsteniteMessage::Frame(_) => {
4062            bail!("received an unexpected frame while reading the message")
4063        }
4064    };
4065
4066    Ok(message)
4067}
4068
4069fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4070    match message {
4071        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4072        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4073        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4074        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4075        AxumMessage::Close(frame) => {
4076            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4077                code: frame.code.into(),
4078                reason: frame.reason.as_ref().into(),
4079            }))
4080        }
4081    }
4082}
4083
4084fn notify_membership_updated(
4085    connection_pool: &mut ConnectionPool,
4086    result: MembershipUpdated,
4087    user_id: UserId,
4088    peer: &Peer,
4089) {
4090    for membership in &result.new_channels.channel_memberships {
4091        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4092    }
4093    for channel_id in &result.removed_channels {
4094        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4095    }
4096
4097    let user_channels_update = proto::UpdateUserChannels {
4098        channel_memberships: result
4099            .new_channels
4100            .channel_memberships
4101            .iter()
4102            .map(|cm| proto::ChannelMembership {
4103                channel_id: cm.channel_id.to_proto(),
4104                role: cm.role.into(),
4105            })
4106            .collect(),
4107        ..Default::default()
4108    };
4109
4110    let mut update = build_channels_update(result.new_channels);
4111    update.delete_channels = result
4112        .removed_channels
4113        .into_iter()
4114        .map(|id| id.to_proto())
4115        .collect();
4116    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4117
4118    for connection_id in connection_pool.user_connection_ids(user_id) {
4119        peer.send(connection_id, user_channels_update.clone())
4120            .trace_err();
4121        peer.send(connection_id, update.clone()).trace_err();
4122    }
4123}
4124
4125fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4126    proto::UpdateUserChannels {
4127        channel_memberships: channels
4128            .channel_memberships
4129            .iter()
4130            .map(|m| proto::ChannelMembership {
4131                channel_id: m.channel_id.to_proto(),
4132                role: m.role.into(),
4133            })
4134            .collect(),
4135        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4136        observed_channel_message_id: channels.observed_channel_messages.clone(),
4137    }
4138}
4139
4140fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4141    let mut update = proto::UpdateChannels::default();
4142
4143    for channel in channels.channels {
4144        update.channels.push(channel.to_proto());
4145    }
4146
4147    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4148    update.latest_channel_message_ids = channels.latest_channel_messages;
4149
4150    for (channel_id, participants) in channels.channel_participants {
4151        update
4152            .channel_participants
4153            .push(proto::ChannelParticipants {
4154                channel_id: channel_id.to_proto(),
4155                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4156            });
4157    }
4158
4159    for channel in channels.invited_channels {
4160        update.channel_invitations.push(channel.to_proto());
4161    }
4162
4163    update
4164}
4165
4166fn build_initial_contacts_update(
4167    contacts: Vec<db::Contact>,
4168    pool: &ConnectionPool,
4169) -> proto::UpdateContacts {
4170    let mut update = proto::UpdateContacts::default();
4171
4172    for contact in contacts {
4173        match contact {
4174            db::Contact::Accepted { user_id, busy } => {
4175                update.contacts.push(contact_for_user(user_id, busy, pool));
4176            }
4177            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4178            db::Contact::Incoming { user_id } => {
4179                update
4180                    .incoming_requests
4181                    .push(proto::IncomingContactRequest {
4182                        requester_id: user_id.to_proto(),
4183                    })
4184            }
4185        }
4186    }
4187
4188    update
4189}
4190
4191fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4192    proto::Contact {
4193        user_id: user_id.to_proto(),
4194        online: pool.is_user_online(user_id),
4195        busy,
4196    }
4197}
4198
4199fn room_updated(room: &proto::Room, peer: &Peer) {
4200    broadcast(
4201        None,
4202        room.participants
4203            .iter()
4204            .filter_map(|participant| Some(participant.peer_id?.into())),
4205        |peer_id| {
4206            peer.send(
4207                peer_id,
4208                proto::RoomUpdated {
4209                    room: Some(room.clone()),
4210                },
4211            )
4212        },
4213    );
4214}
4215
4216fn channel_updated(
4217    channel: &db::channel::Model,
4218    room: &proto::Room,
4219    peer: &Peer,
4220    pool: &ConnectionPool,
4221) {
4222    let participants = room
4223        .participants
4224        .iter()
4225        .map(|p| p.user_id)
4226        .collect::<Vec<_>>();
4227
4228    broadcast(
4229        None,
4230        pool.channel_connection_ids(channel.root_id())
4231            .filter_map(|(channel_id, role)| {
4232                role.can_see_channel(channel.visibility)
4233                    .then_some(channel_id)
4234            }),
4235        |peer_id| {
4236            peer.send(
4237                peer_id,
4238                proto::UpdateChannels {
4239                    channel_participants: vec![proto::ChannelParticipants {
4240                        channel_id: channel.id.to_proto(),
4241                        participant_user_ids: participants.clone(),
4242                    }],
4243                    ..Default::default()
4244                },
4245            )
4246        },
4247    );
4248}
4249
4250async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4251    let db = session.db().await;
4252
4253    let contacts = db.get_contacts(user_id).await?;
4254    let busy = db.is_user_busy(user_id).await?;
4255
4256    let pool = session.connection_pool().await;
4257    let updated_contact = contact_for_user(user_id, busy, &pool);
4258    for contact in contacts {
4259        if let db::Contact::Accepted {
4260            user_id: contact_user_id,
4261            ..
4262        } = contact
4263        {
4264            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4265                session
4266                    .peer
4267                    .send(
4268                        contact_conn_id,
4269                        proto::UpdateContacts {
4270                            contacts: vec![updated_contact.clone()],
4271                            remove_contacts: Default::default(),
4272                            incoming_requests: Default::default(),
4273                            remove_incoming_requests: Default::default(),
4274                            outgoing_requests: Default::default(),
4275                            remove_outgoing_requests: Default::default(),
4276                        },
4277                    )
4278                    .trace_err();
4279            }
4280        }
4281    }
4282    Ok(())
4283}
4284
/// Removes `connection_id` from the room it is in (via `db.leave_room`) and
/// propagates the consequences to other peers.
///
/// Steps performed:
/// - notifies the projects the session left (`project_left`);
/// - broadcasts the updated room (`room_updated`) and, if the room belongs to a
///   channel, the channel's participant list (`channel_updated`);
/// - sends `CallCanceled` to every user whose pending call was canceled;
/// - refreshes contact entries for this user and those users;
/// - removes the user from the LiveKit room, and deletes that room when
///   `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately when the database reports no room to leave.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here but assigned inside the `if let` below, so `left_room` is
    // dropped at the end of that block, before the async work that follows.
    // NOTE(review): presumably `left_room` holds a lock/transaction guard that
    // should not be held across the awaits below — confirm before restructuring.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        // This user's own contact entry changes (they are no longer in a room).
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below carries
        // the default (empty) value for `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which acquires
    // the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?`, which skips the LiveKit
    // participant removal / room deletion below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged and ignored.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4358
4359async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4360    let left_channel_buffers = session
4361        .db()
4362        .await
4363        .leave_channel_buffers(session.connection_id)
4364        .await?;
4365
4366    for left_buffer in left_channel_buffers {
4367        channel_buffer_updated(
4368            session.connection_id,
4369            left_buffer.connections,
4370            &proto::UpdateChannelBufferCollaborators {
4371                channel_id: left_buffer.channel_id.to_proto(),
4372                collaborators: left_buffer.collaborators,
4373            },
4374            &session.peer,
4375        );
4376    }
4377
4378    Ok(())
4379}
4380
4381fn project_left(project: &db::LeftProject, session: &Session) {
4382    for connection_id in &project.connection_ids {
4383        if project.should_unshare {
4384            session
4385                .peer
4386                .send(
4387                    *connection_id,
4388                    proto::UnshareProject {
4389                        project_id: project.id.to_proto(),
4390                    },
4391                )
4392                .trace_err();
4393        } else {
4394            session
4395                .peer
4396                .send(
4397                    *connection_id,
4398                    proto::RemoveProjectCollaborator {
4399                        project_id: project.id.to_proto(),
4400                        peer_id: Some(session.connection_id.into()),
4401                    },
4402                )
4403                .trace_err();
4404        }
4405    }
4406}
4407
4408pub trait ResultExt {
4409    type Ok;
4410
4411    fn trace_err(self) -> Option<Self::Ok>;
4412}
4413
4414impl<T, E> ResultExt for Result<T, E>
4415where
4416    E: std::fmt::Debug,
4417{
4418    type Ok = T;
4419
4420    #[track_caller]
4421    fn trace_err(self) -> Option<T> {
4422        match self {
4423            Ok(value) => Some(value),
4424            Err(error) => {
4425                tracing::error!("{:?}", error);
4426                None
4427            }
4428        }
4429    }
4430}