rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    http_client: Arc<dyn HttpClient>,
 137    /// The GeoIP country code for the user.
 138    #[allow(unused)]
 139    geoip_country_code: Option<String>,
 140    system_id: Option<String>,
 141    _executor: Executor,
 142}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
 236pub struct Server {
 237    id: parking_lot::Mutex<ServerId>,
 238    peer: Arc<Peer>,
 239    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 240    app_state: Arc<AppState>,
 241    handlers: HashMap<TypeId, MessageHandler>,
 242    teardown: watch::Sender<bool>,
 243}
 244
 245pub(crate) struct ConnectionPoolGuard<'a> {
 246    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 247    _not_send: PhantomData<Rc<()>>,
 248}
 249
 250#[derive(Serialize)]
 251pub struct ServerSnapshot<'a> {
 252    peer: &'a Peer,
 253    #[serde(serialize_with = "serialize_deref")]
 254    connection_pool: ConnectionPoolGuard<'a>,
 255}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
    /// Creates a server with the given id and registers a handler for every RPC message it
    /// accepts.
    ///
    /// Request handlers (`add_request_handler`) must answer through their `Response`;
    /// message handlers (`add_message_handler`) send no reply. Registering two handlers for
    /// the same message type panics.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates out-of-range ids; presumably
            // server ids always fit — confirm.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            // Connectivity, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project/worktree updates.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded via `forward_read_only_project_request`
            // (plus search candidates).
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            // Project requests forwarded via `forward_mutating_project_request`.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer traffic and host-originated broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel membership and channel buffers.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            // Channel chat.
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            // Moving channels.
            .add_request_handler(move_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            // Account and LLM access.
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            // Acknowledgements.
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Supermaven.
            .add_request_handler(get_supermaven_api_key)
            // Contexts: forwarded to / broadcast from the project host.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Token counting needs the server config, so the handler captures `app_state`
            // and clones it per call so the returned future owns its copy.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            // Embeddings; `compute_embeddings` receives the configured OpenAI API key.
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 421
    /// Spawns a detached task that cleans up resources left behind by stale servers, then
    /// returns immediately (always `Ok(())`; failures inside the task are not reported here).
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task fetches the room and channel ids returned by
    /// `stale_server_resource_ids` for this environment and server id, and then:
    /// - clears stale channel-buffer collaborators and sends the refreshed collaborator list
    ///   to the buffer's remaining connections;
    /// - for each room: clears stale participants, broadcasts the updated room (and its
    ///   channel, if any), sends `CallCanceled` to users whose calls were canceled, sends each
    ///   affected user's refreshed contact entry to their accepted contacts, and deletes the
    ///   LiveKit room once no participants remain;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Errors are passed through `trace_err()` and skipped rather than aborting the cleanup.
    /// NOTE(review): which servers count as "stale" is decided by the database queries —
    /// confirm there.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators and tell everyone still in the buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails, making the steps
                        // below no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool lock is a temporary released at the end of this
                                // statement.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Pool lock held only for these synchronous sends (no `.await` inside).
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query the database first; the pool is locked only after both
                            // awaits have completed.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 567
 568    pub fn teardown(&self) {
 569        self.peer.teardown();
 570        self.connection_pool.lock().reset();
 571        let _ = self.teardown.send(true);
 572    }
 573
 574    #[cfg(test)]
 575    pub fn reset(&self, id: ServerId) {
 576        self.teardown();
 577        *self.id.lock() = id;
 578        self.peer.reset(id.0 as u32);
 579        let _ = self.teardown.send(false);
 580    }
 581
 582    #[cfg(test)]
 583    pub fn id(&self) -> ServerId {
 584        *self.id.lock()
 585    }
 586
 587    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 588    where
 589        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 590        Fut: 'static + Send + Future<Output = Result<()>>,
 591        M: EnvelopedMessage,
 592    {
 593        let prev_handler = self.handlers.insert(
 594            TypeId::of::<M>(),
 595            Box::new(move |envelope, session| {
 596                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 597                let received_at = envelope.received_at;
 598                tracing::info!("message received");
 599                let start_time = Instant::now();
 600                let future = (handler)(*envelope, session);
 601                async move {
 602                    let result = future.await;
 603                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 604                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 605                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 606                    let payload_type = M::NAME;
 607
 608                    match result {
 609                        Err(error) => {
 610                            tracing::error!(
 611                                ?error,
 612                                total_duration_ms,
 613                                processing_duration_ms,
 614                                queue_duration_ms,
 615                                payload_type,
 616                                "error handling message"
 617                            )
 618                        }
 619                        Ok(()) => tracing::info!(
 620                            total_duration_ms,
 621                            processing_duration_ms,
 622                            queue_duration_ms,
 623                            "finished handling message"
 624                        ),
 625                    }
 626                }
 627                .boxed()
 628            }),
 629        );
 630        if prev_handler.is_some() {
 631            panic!("registered a handler for the same message twice");
 632        }
 633        self
 634    }
 635
 636    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 637    where
 638        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 639        Fut: 'static + Send + Future<Output = Result<()>>,
 640        M: EnvelopedMessage,
 641    {
 642        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 643        self
 644    }
 645
 646    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 647    where
 648        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 649        Fut: Send + Future<Output = Result<()>>,
 650        M: RequestMessage,
 651    {
 652        let handler = Arc::new(handler);
 653        self.add_handler(move |envelope, session| {
 654            let receipt = envelope.receipt();
 655            let handler = handler.clone();
 656            async move {
 657                let peer = session.peer.clone();
 658                let responded = Arc::new(AtomicBool::default());
 659                let response = Response {
 660                    peer: peer.clone(),
 661                    responded: responded.clone(),
 662                    receipt,
 663                };
 664                match (handler)(envelope.payload, response, session).await {
 665                    Ok(()) => {
 666                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 667                            Ok(())
 668                        } else {
 669                            Err(anyhow!("handler did not send a response"))?
 670                        }
 671                    }
 672                    Err(error) => {
 673                        let proto_err = match &error {
 674                            Error::Internal(err) => err.to_proto(),
 675                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 676                        };
 677                        peer.respond_with_error(receipt, proto_err)?;
 678                        Err(error)
 679                    }
 680                }
 681            }
 682        })
 683    }
 684
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future does the following:
    /// - registers `connection` with the peer;
    /// - sends the initial client state via `send_initial_client_update`;
    /// - dispatches incoming messages to the registered handlers until the
    ///   client closes the connection or the connection's I/O future completes;
    /// - then runs `connection_lost`.
    ///
    /// If the server's teardown flag is already set, or changes while the loop
    /// is running, the future returns without calling `connection_lost`.
    ///
    /// `send_connection_id`, when provided, is forwarded to
    /// `send_initial_client_update`, which reports the assigned `ConnectionId`
    /// through it right after sending `Hello`. `executor` supplies the sleep
    /// function handed to the peer and spawns background message handlers.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are recorded later: `connection_id` once the
        // peer assigns one below, and presumably the user fields by
        // `principal.update_span` (TODO confirm in `Principal::update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // The teardown flag is checked on entry and raced against the message loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve a new connection once teardown has begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // Register the connection with the peer. The closure lets the peer sleep on this executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Each connection builds its own HTTP client. Without one the session
            // cannot be built, so the connection is abandoned.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // A Supermaven admin client is created only when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            // If the handshake fails, return before the message loop (and without `connection_lost`).
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers are in flight per connection. A permit is acquired
            // before the next message is read and released only when that message's
            // handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in order: teardown first, then
                // connection I/O, then in-flight foreground handlers, and only then
                // the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                // The handler future is constructed inside the entered span.
                                // The guard is dropped before the future is stored, and
                                // `.instrument(span)` attaches the span to every later poll.
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Free the concurrency slot only once the handler has completed.
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached and never hold up this loop.
                                // Foreground handlers are polled by the select above.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 834
 835    async fn send_initial_client_update(
 836        &self,
 837        connection_id: ConnectionId,
 838        principal: &Principal,
 839        zed_version: ZedVersion,
 840        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 841        session: &Session,
 842    ) -> Result<()> {
 843        self.peer.send(
 844            connection_id,
 845            proto::Hello {
 846                peer_id: Some(connection_id.into()),
 847            },
 848        )?;
 849        tracing::info!("sent hello message");
 850        if let Some(send_connection_id) = send_connection_id.take() {
 851            let _ = send_connection_id.send(connection_id);
 852        }
 853
 854        match principal {
 855            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 856                if !user.connected_once {
 857                    self.peer.send(connection_id, proto::ShowContacts {})?;
 858                    self.app_state
 859                        .db
 860                        .set_user_connected_once(user.id, true)
 861                        .await?;
 862                }
 863
 864                update_user_plan(user.id, session).await?;
 865
 866                let contacts = self.app_state.db.get_contacts(user.id).await?;
 867
 868                {
 869                    let mut pool = self.connection_pool.lock();
 870                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 871                    self.peer.send(
 872                        connection_id,
 873                        build_initial_contacts_update(contacts, &pool),
 874                    )?;
 875                }
 876
 877                if should_auto_subscribe_to_channels(zed_version) {
 878                    subscribe_user_to_channels(user.id, session).await?;
 879                }
 880
 881                if let Some(incoming_call) =
 882                    self.app_state.db.incoming_call_for_user(user.id).await?
 883                {
 884                    self.peer.send(connection_id, incoming_call)?;
 885                }
 886
 887                update_user_contacts(user.id, session).await?;
 888            }
 889        }
 890
 891        Ok(())
 892    }
 893
 894    pub async fn invite_code_redeemed(
 895        self: &Arc<Self>,
 896        inviter_id: UserId,
 897        invitee_id: UserId,
 898    ) -> Result<()> {
 899        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 900            if let Some(code) = &user.invite_code {
 901                let pool = self.connection_pool.lock();
 902                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 903                for connection_id in pool.user_connection_ids(inviter_id) {
 904                    self.peer.send(
 905                        connection_id,
 906                        proto::UpdateContacts {
 907                            contacts: vec![invitee_contact.clone()],
 908                            ..Default::default()
 909                        },
 910                    )?;
 911                    self.peer.send(
 912                        connection_id,
 913                        proto::UpdateInviteInfo {
 914                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 915                            count: user.invite_count as u32,
 916                        },
 917                    )?;
 918                }
 919            }
 920        }
 921        Ok(())
 922    }
 923
 924    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 925        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 926            if let Some(invite_code) = &user.invite_code {
 927                let pool = self.connection_pool.lock();
 928                for connection_id in pool.user_connection_ids(user_id) {
 929                    self.peer.send(
 930                        connection_id,
 931                        proto::UpdateInviteInfo {
 932                            url: format!(
 933                                "{}{}",
 934                                self.app_state.config.invite_link_prefix, invite_code
 935                            ),
 936                            count: user.invite_count as u32,
 937                        },
 938                    )?;
 939                }
 940            }
 941        }
 942        Ok(())
 943    }
 944
 945    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 946        let pool = self.connection_pool.lock();
 947        for connection_id in pool.user_connection_ids(user_id) {
 948            self.peer
 949                .send(connection_id, proto::RefreshLlmToken {})
 950                .trace_err();
 951        }
 952    }
 953
 954    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 955        ServerSnapshot {
 956            connection_pool: ConnectionPoolGuard {
 957                guard: self.connection_pool.lock(),
 958                _not_send: PhantomData,
 959            },
 960            peer: &self.peer,
 961        }
 962    }
 963}
 964
 965impl<'a> Deref for ConnectionPoolGuard<'a> {
 966    type Target = ConnectionPool;
 967
 968    fn deref(&self) -> &Self::Target {
 969        &self.guard
 970    }
 971}
 972
 973impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 974    fn deref_mut(&mut self) -> &mut Self::Target {
 975        &mut self.guard
 976    }
 977}
 978
 979impl<'a> Drop for ConnectionPoolGuard<'a> {
 980    fn drop(&mut self) {
 981        #[cfg(test)]
 982        self.check_invariants();
 983    }
 984}
 985
 986fn broadcast<F>(
 987    sender_id: Option<ConnectionId>,
 988    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 989    mut f: F,
 990) where
 991    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 992{
 993    for receiver_id in receiver_ids {
 994        if Some(receiver_id) != sender_id {
 995            if let Err(error) = f(receiver_id) {
 996                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 997            }
 998        }
 999    }
1000}
1001
1002pub struct ProtocolVersion(u32);
1003
1004impl Header for ProtocolVersion {
1005    fn name() -> &'static HeaderName {
1006        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1007        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1008    }
1009
1010    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1011    where
1012        Self: Sized,
1013        I: Iterator<Item = &'i axum::http::HeaderValue>,
1014    {
1015        let version = values
1016            .next()
1017            .ok_or_else(axum::headers::Error::invalid)?
1018            .to_str()
1019            .map_err(|_| axum::headers::Error::invalid())?
1020            .parse()
1021            .map_err(|_| axum::headers::Error::invalid())?;
1022        Ok(Self(version))
1023    }
1024
1025    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1026        values.extend([self.0.to_string().parse().unwrap()]);
1027    }
1028}
1029
1030pub struct AppVersionHeader(SemanticVersion);
1031impl Header for AppVersionHeader {
1032    fn name() -> &'static HeaderName {
1033        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1034        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1035    }
1036
1037    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1038    where
1039        Self: Sized,
1040        I: Iterator<Item = &'i axum::http::HeaderValue>,
1041    {
1042        let version = values
1043            .next()
1044            .ok_or_else(axum::headers::Error::invalid)?
1045            .to_str()
1046            .map_err(|_| axum::headers::Error::invalid())?
1047            .parse()
1048            .map_err(|_| axum::headers::Error::invalid())?;
1049        Ok(Self(version))
1050    }
1051
1052    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1053        values.extend([self.0.to_string().parse().unwrap()]);
1054    }
1055}
1056
1057pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1058    Router::new()
1059        .route("/rpc", get(handle_websocket_request))
1060        .layer(
1061            ServiceBuilder::new()
1062                .layer(Extension(server.app_state.clone()))
1063                .layer(middleware::from_fn(auth::validate_header)),
1064        )
1065        .route("/metrics", get(handle_metrics))
1066        .layer(Extension(server))
1067}
1068
1069#[allow(clippy::too_many_arguments)]
1070pub async fn handle_websocket_request(
1071    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1072    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1073    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1074    Extension(server): Extension<Arc<Server>>,
1075    Extension(principal): Extension<Principal>,
1076    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1077    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1078    ws: WebSocketUpgrade,
1079) -> axum::response::Response {
1080    if protocol_version != rpc::PROTOCOL_VERSION {
1081        return (
1082            StatusCode::UPGRADE_REQUIRED,
1083            "client must be upgraded".to_string(),
1084        )
1085            .into_response();
1086    }
1087
1088    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1089        return (
1090            StatusCode::UPGRADE_REQUIRED,
1091            "no version header found".to_string(),
1092        )
1093            .into_response();
1094    };
1095
1096    if !version.can_collaborate() {
1097        return (
1098            StatusCode::UPGRADE_REQUIRED,
1099            "client must be upgraded".to_string(),
1100        )
1101            .into_response();
1102    }
1103
1104    let socket_address = socket_address.to_string();
1105    ws.on_upgrade(move |socket| {
1106        let socket = socket
1107            .map_ok(to_tungstenite_message)
1108            .err_into()
1109            .with(|message| async move { to_axum_message(message) });
1110        let connection = Connection::new(Box::pin(socket));
1111        async move {
1112            server
1113                .handle_connection(
1114                    connection,
1115                    socket_address,
1116                    principal,
1117                    version,
1118                    country_code_header.map(|header| header.to_string()),
1119                    system_id_header.map(|header| header.to_string()),
1120                    None,
1121                    Executor::Production,
1122                )
1123                .await;
1124        }
1125    })
1126}
1127
1128pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1129    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1130    let connections_metric = CONNECTIONS_METRIC
1131        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1132
1133    let connections = server
1134        .connection_pool
1135        .lock()
1136        .connections()
1137        .filter(|connection| !connection.admin)
1138        .count();
1139    connections_metric.set(connections as _);
1140
1141    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1142    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1143        register_int_gauge!(
1144            "shared_projects",
1145            "number of open projects with one or more guests"
1146        )
1147        .unwrap()
1148    });
1149
1150    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1151    shared_projects_metric.set(shared_projects as _);
1152
1153    let encoder = prometheus::TextEncoder::new();
1154    let metric_families = prometheus::gather();
1155    let encoded_metrics = encoder
1156        .encode_to_string(&metric_families)
1157        .map_err(|err| anyhow!("{}", err))?;
1158    Ok(encoded_metrics)
1159}
1160
1161#[instrument(err, skip(executor))]
1162async fn connection_lost(
1163    session: Session,
1164    mut teardown: watch::Receiver<bool>,
1165    executor: Executor,
1166) -> Result<()> {
1167    session.peer.disconnect(session.connection_id);
1168    session
1169        .connection_pool()
1170        .await
1171        .remove_connection(session.connection_id)?;
1172
1173    session
1174        .db()
1175        .await
1176        .connection_lost(session.connection_id)
1177        .await
1178        .trace_err();
1179
1180    futures::select_biased! {
1181        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1182
1183            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1184            leave_room_for_session(&session, session.connection_id).await.trace_err();
1185            leave_channel_buffers_for_session(&session)
1186                .await
1187                .trace_err();
1188
1189            if !session
1190                .connection_pool()
1191                .await
1192                .is_user_online(session.user_id())
1193            {
1194                let db = session.db().await;
1195                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1196                    room_updated(&room, &session.peer);
1197                }
1198            }
1199
1200            update_user_contacts(session.user_id(), &session).await?;
1201        },
1202        _ = teardown.changed().fuse() => {}
1203    }
1204
1205    Ok(())
1206}
1207
1208/// Acknowledges a ping from a client, used to keep the connection alive.
1209async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1210    response.send(proto::Ack {})?;
1211    Ok(())
1212}
1213
1214/// Creates a new room for calling (outside of channels)
1215async fn create_room(
1216    _request: proto::CreateRoom,
1217    response: Response<proto::CreateRoom>,
1218    session: Session,
1219) -> Result<()> {
1220    let livekit_room = nanoid::nanoid!(30);
1221
1222    let live_kit_connection_info = util::maybe!(async {
1223        let live_kit = session.app_state.livekit_client.as_ref();
1224        let live_kit = live_kit?;
1225        let user_id = session.user_id().to_string();
1226
1227        let token = live_kit
1228            .room_token(&livekit_room, &user_id.to_string())
1229            .trace_err()?;
1230
1231        Some(proto::LiveKitConnectionInfo {
1232            server_url: live_kit.url().into(),
1233            token,
1234            can_publish: true,
1235        })
1236    })
1237    .await;
1238
1239    let room = session
1240        .db()
1241        .await
1242        .create_room(session.user_id(), session.connection_id, &livekit_room)
1243        .await?;
1244
1245    response.send(proto::CreateRoomResponse {
1246        room: Some(room.clone()),
1247        live_kit_connection_info,
1248    })?;
1249
1250    update_user_contacts(session.user_id(), &session).await?;
1251    Ok(())
1252}
1253
1254/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1255async fn join_room(
1256    request: proto::JoinRoom,
1257    response: Response<proto::JoinRoom>,
1258    session: Session,
1259) -> Result<()> {
1260    let room_id = RoomId::from_proto(request.id);
1261
1262    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1263
1264    if let Some(channel_id) = channel_id {
1265        return join_channel_internal(channel_id, Box::new(response), session).await;
1266    }
1267
1268    let joined_room = {
1269        let room = session
1270            .db()
1271            .await
1272            .join_room(room_id, session.user_id(), session.connection_id)
1273            .await?;
1274        room_updated(&room.room, &session.peer);
1275        room.into_inner()
1276    };
1277
1278    for connection_id in session
1279        .connection_pool()
1280        .await
1281        .user_connection_ids(session.user_id())
1282    {
1283        session
1284            .peer
1285            .send(
1286                connection_id,
1287                proto::CallCanceled {
1288                    room_id: room_id.to_proto(),
1289                },
1290            )
1291            .trace_err();
1292    }
1293
1294    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1295    {
1296        live_kit
1297            .room_token(
1298                &joined_room.room.livekit_room,
1299                &session.user_id().to_string(),
1300            )
1301            .trace_err()
1302            .map(|token| proto::LiveKitConnectionInfo {
1303                server_url: live_kit.url().into(),
1304                token,
1305                can_publish: true,
1306            })
1307    } else {
1308        None
1309    };
1310
1311    response.send(proto::JoinRoomResponse {
1312        room: Some(joined_room.room),
1313        channel_id: None,
1314        live_kit_connection_info,
1315    })?;
1316
1317    update_user_contacts(session.user_id(), &session).await?;
1318    Ok(())
1319}
1320
1321/// Rejoin room is used to reconnect to a room after connection errors.
1322async fn rejoin_room(
1323    request: proto::RejoinRoom,
1324    response: Response<proto::RejoinRoom>,
1325    session: Session,
1326) -> Result<()> {
1327    let room;
1328    let channel;
1329    {
1330        let mut rejoined_room = session
1331            .db()
1332            .await
1333            .rejoin_room(request, session.user_id(), session.connection_id)
1334            .await?;
1335
1336        response.send(proto::RejoinRoomResponse {
1337            room: Some(rejoined_room.room.clone()),
1338            reshared_projects: rejoined_room
1339                .reshared_projects
1340                .iter()
1341                .map(|project| proto::ResharedProject {
1342                    id: project.id.to_proto(),
1343                    collaborators: project
1344                        .collaborators
1345                        .iter()
1346                        .map(|collaborator| collaborator.to_proto())
1347                        .collect(),
1348                })
1349                .collect(),
1350            rejoined_projects: rejoined_room
1351                .rejoined_projects
1352                .iter()
1353                .map(|rejoined_project| rejoined_project.to_proto())
1354                .collect(),
1355        })?;
1356        room_updated(&rejoined_room.room, &session.peer);
1357
1358        for project in &rejoined_room.reshared_projects {
1359            for collaborator in &project.collaborators {
1360                session
1361                    .peer
1362                    .send(
1363                        collaborator.connection_id,
1364                        proto::UpdateProjectCollaborator {
1365                            project_id: project.id.to_proto(),
1366                            old_peer_id: Some(project.old_connection_id.into()),
1367                            new_peer_id: Some(session.connection_id.into()),
1368                        },
1369                    )
1370                    .trace_err();
1371            }
1372
1373            broadcast(
1374                Some(session.connection_id),
1375                project
1376                    .collaborators
1377                    .iter()
1378                    .map(|collaborator| collaborator.connection_id),
1379                |connection_id| {
1380                    session.peer.forward_send(
1381                        session.connection_id,
1382                        connection_id,
1383                        proto::UpdateProject {
1384                            project_id: project.id.to_proto(),
1385                            worktrees: project.worktrees.clone(),
1386                        },
1387                    )
1388                },
1389            );
1390        }
1391
1392        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1393
1394        let rejoined_room = rejoined_room.into_inner();
1395
1396        room = rejoined_room.room;
1397        channel = rejoined_room.channel;
1398    }
1399
1400    if let Some(channel) = channel {
1401        channel_updated(
1402            &channel,
1403            &room,
1404            &session.peer,
1405            &*session.connection_pool().await,
1406        );
1407    }
1408
1409    update_user_contacts(session.user_id(), &session).await?;
1410    Ok(())
1411}
1412
1413fn notify_rejoined_projects(
1414    rejoined_projects: &mut Vec<RejoinedProject>,
1415    session: &Session,
1416) -> Result<()> {
1417    for project in rejoined_projects.iter() {
1418        for collaborator in &project.collaborators {
1419            session
1420                .peer
1421                .send(
1422                    collaborator.connection_id,
1423                    proto::UpdateProjectCollaborator {
1424                        project_id: project.id.to_proto(),
1425                        old_peer_id: Some(project.old_connection_id.into()),
1426                        new_peer_id: Some(session.connection_id.into()),
1427                    },
1428                )
1429                .trace_err();
1430        }
1431    }
1432
1433    for project in rejoined_projects {
1434        for worktree in mem::take(&mut project.worktrees) {
1435            // Stream this worktree's entries.
1436            let message = proto::UpdateWorktree {
1437                project_id: project.id.to_proto(),
1438                worktree_id: worktree.id,
1439                abs_path: worktree.abs_path.clone(),
1440                root_name: worktree.root_name,
1441                updated_entries: worktree.updated_entries,
1442                removed_entries: worktree.removed_entries,
1443                scan_id: worktree.scan_id,
1444                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1445                updated_repositories: worktree.updated_repositories,
1446                removed_repositories: worktree.removed_repositories,
1447            };
1448            for update in proto::split_worktree_update(message) {
1449                session.peer.send(session.connection_id, update.clone())?;
1450            }
1451
1452            // Stream this worktree's diagnostics.
1453            for summary in worktree.diagnostic_summaries {
1454                session.peer.send(
1455                    session.connection_id,
1456                    proto::UpdateDiagnosticSummary {
1457                        project_id: project.id.to_proto(),
1458                        worktree_id: worktree.id,
1459                        summary: Some(summary),
1460                    },
1461                )?;
1462            }
1463
1464            for settings_file in worktree.settings_files {
1465                session.peer.send(
1466                    session.connection_id,
1467                    proto::UpdateWorktreeSettings {
1468                        project_id: project.id.to_proto(),
1469                        worktree_id: worktree.id,
1470                        path: settings_file.path,
1471                        content: Some(settings_file.content),
1472                        kind: Some(settings_file.kind.to_proto().into()),
1473                    },
1474                )?;
1475            }
1476        }
1477
1478        for language_server in &project.language_servers {
1479            session.peer.send(
1480                session.connection_id,
1481                proto::UpdateLanguageServer {
1482                    project_id: project.id.to_proto(),
1483                    language_server_id: language_server.id,
1484                    variant: Some(
1485                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1486                            proto::LspDiskBasedDiagnosticsUpdated {},
1487                        ),
1488                    ),
1489                },
1490            )?;
1491        }
1492    }
1493    Ok(())
1494}
1495
1496/// leave room disconnects from the room.
1497async fn leave_room(
1498    _: proto::LeaveRoom,
1499    response: Response<proto::LeaveRoom>,
1500    session: Session,
1501) -> Result<()> {
1502    leave_room_for_session(&session, session.connection_id).await?;
1503    response.send(proto::Ack {})?;
1504    Ok(())
1505}
1506
1507/// Updates the permissions of someone else in the room.
1508async fn set_room_participant_role(
1509    request: proto::SetRoomParticipantRole,
1510    response: Response<proto::SetRoomParticipantRole>,
1511    session: Session,
1512) -> Result<()> {
1513    let user_id = UserId::from_proto(request.user_id);
1514    let role = ChannelRole::from(request.role());
1515
1516    let (livekit_room, can_publish) = {
1517        let room = session
1518            .db()
1519            .await
1520            .set_room_participant_role(
1521                session.user_id(),
1522                RoomId::from_proto(request.room_id),
1523                user_id,
1524                role,
1525            )
1526            .await?;
1527
1528        let livekit_room = room.livekit_room.clone();
1529        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1530        room_updated(&room, &session.peer);
1531        (livekit_room, can_publish)
1532    };
1533
1534    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1535        live_kit
1536            .update_participant(
1537                livekit_room.clone(),
1538                request.user_id.to_string(),
1539                livekit_server::proto::ParticipantPermission {
1540                    can_subscribe: true,
1541                    can_publish,
1542                    can_publish_data: can_publish,
1543                    hidden: false,
1544                    recorder: false,
1545                },
1546            )
1547            .await
1548            .trace_err();
1549    }
1550
1551    response.send(proto::Ack {})?;
1552    Ok(())
1553}
1554
1555/// Call someone else into the current room
1556async fn call(
1557    request: proto::Call,
1558    response: Response<proto::Call>,
1559    session: Session,
1560) -> Result<()> {
1561    let room_id = RoomId::from_proto(request.room_id);
1562    let calling_user_id = session.user_id();
1563    let calling_connection_id = session.connection_id;
1564    let called_user_id = UserId::from_proto(request.called_user_id);
1565    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1566    if !session
1567        .db()
1568        .await
1569        .has_contact(calling_user_id, called_user_id)
1570        .await?
1571    {
1572        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1573    }
1574
1575    let incoming_call = {
1576        let (room, incoming_call) = &mut *session
1577            .db()
1578            .await
1579            .call(
1580                room_id,
1581                calling_user_id,
1582                calling_connection_id,
1583                called_user_id,
1584                initial_project_id,
1585            )
1586            .await?;
1587        room_updated(room, &session.peer);
1588        mem::take(incoming_call)
1589    };
1590    update_user_contacts(called_user_id, &session).await?;
1591
1592    let mut calls = session
1593        .connection_pool()
1594        .await
1595        .user_connection_ids(called_user_id)
1596        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1597        .collect::<FuturesUnordered<_>>();
1598
1599    while let Some(call_response) = calls.next().await {
1600        match call_response.as_ref() {
1601            Ok(_) => {
1602                response.send(proto::Ack {})?;
1603                return Ok(());
1604            }
1605            Err(_) => {
1606                call_response.trace_err();
1607            }
1608        }
1609    }
1610
1611    {
1612        let room = session
1613            .db()
1614            .await
1615            .call_failed(room_id, called_user_id)
1616            .await?;
1617        room_updated(&room, &session.peer);
1618    }
1619    update_user_contacts(called_user_id, &session).await?;
1620
1621    Err(anyhow!("failed to ring user"))?
1622}
1623
1624/// Cancel an outgoing call.
1625async fn cancel_call(
1626    request: proto::CancelCall,
1627    response: Response<proto::CancelCall>,
1628    session: Session,
1629) -> Result<()> {
1630    let called_user_id = UserId::from_proto(request.called_user_id);
1631    let room_id = RoomId::from_proto(request.room_id);
1632    {
1633        let room = session
1634            .db()
1635            .await
1636            .cancel_call(room_id, session.connection_id, called_user_id)
1637            .await?;
1638        room_updated(&room, &session.peer);
1639    }
1640
1641    for connection_id in session
1642        .connection_pool()
1643        .await
1644        .user_connection_ids(called_user_id)
1645    {
1646        session
1647            .peer
1648            .send(
1649                connection_id,
1650                proto::CallCanceled {
1651                    room_id: room_id.to_proto(),
1652                },
1653            )
1654            .trace_err();
1655    }
1656    response.send(proto::Ack {})?;
1657
1658    update_user_contacts(called_user_id, &session).await?;
1659    Ok(())
1660}
1661
1662/// Decline an incoming call.
1663async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1664    let room_id = RoomId::from_proto(message.room_id);
1665    {
1666        let room = session
1667            .db()
1668            .await
1669            .decline_call(Some(room_id), session.user_id())
1670            .await?
1671            .ok_or_else(|| anyhow!("failed to decline call"))?;
1672        room_updated(&room, &session.peer);
1673    }
1674
1675    for connection_id in session
1676        .connection_pool()
1677        .await
1678        .user_connection_ids(session.user_id())
1679    {
1680        session
1681            .peer
1682            .send(
1683                connection_id,
1684                proto::CallCanceled {
1685                    room_id: room_id.to_proto(),
1686                },
1687            )
1688            .trace_err();
1689    }
1690    update_user_contacts(session.user_id(), &session).await?;
1691    Ok(())
1692}
1693
1694/// Updates other participants in the room with your current location.
1695async fn update_participant_location(
1696    request: proto::UpdateParticipantLocation,
1697    response: Response<proto::UpdateParticipantLocation>,
1698    session: Session,
1699) -> Result<()> {
1700    let room_id = RoomId::from_proto(request.room_id);
1701    let location = request
1702        .location
1703        .ok_or_else(|| anyhow!("invalid location"))?;
1704
1705    let db = session.db().await;
1706    let room = db
1707        .update_room_participant_location(room_id, session.connection_id, location)
1708        .await?;
1709
1710    room_updated(&room, &session.peer);
1711    response.send(proto::Ack {})?;
1712    Ok(())
1713}
1714
1715/// Share a project into the room.
1716async fn share_project(
1717    request: proto::ShareProject,
1718    response: Response<proto::ShareProject>,
1719    session: Session,
1720) -> Result<()> {
1721    let (project_id, room) = &*session
1722        .db()
1723        .await
1724        .share_project(
1725            RoomId::from_proto(request.room_id),
1726            session.connection_id,
1727            &request.worktrees,
1728            request.is_ssh_project,
1729        )
1730        .await?;
1731    response.send(proto::ShareProjectResponse {
1732        project_id: project_id.to_proto(),
1733    })?;
1734    room_updated(room, &session.peer);
1735
1736    Ok(())
1737}
1738
1739/// Unshare a project from the room.
1740async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1741    let project_id = ProjectId::from_proto(message.project_id);
1742    unshare_project_internal(project_id, session.connection_id, &session).await
1743}
1744
1745async fn unshare_project_internal(
1746    project_id: ProjectId,
1747    connection_id: ConnectionId,
1748    session: &Session,
1749) -> Result<()> {
1750    let delete = {
1751        let room_guard = session
1752            .db()
1753            .await
1754            .unshare_project(project_id, connection_id)
1755            .await?;
1756
1757        let (delete, room, guest_connection_ids) = &*room_guard;
1758
1759        let message = proto::UnshareProject {
1760            project_id: project_id.to_proto(),
1761        };
1762
1763        broadcast(
1764            Some(connection_id),
1765            guest_connection_ids.iter().copied(),
1766            |conn_id| session.peer.send(conn_id, message.clone()),
1767        );
1768        if let Some(room) = room {
1769            room_updated(room, &session.peer);
1770        }
1771
1772        *delete
1773    };
1774
1775    if delete {
1776        let db = session.db().await;
1777        db.delete_project(project_id).await?;
1778    }
1779
1780    Ok(())
1781}
1782
1783/// Join someone elses shared project.
1784async fn join_project(
1785    request: proto::JoinProject,
1786    response: Response<proto::JoinProject>,
1787    session: Session,
1788) -> Result<()> {
1789    let project_id = ProjectId::from_proto(request.project_id);
1790
1791    tracing::info!(%project_id, "join project");
1792
1793    let db = session.db().await;
1794    let (project, replica_id) = &mut *db
1795        .join_project(project_id, session.connection_id, session.user_id())
1796        .await?;
1797    drop(db);
1798    tracing::info!(%project_id, "join remote project");
1799    join_project_internal(response, session, project, replica_id)
1800}
1801
/// Delivers a `JoinProjectResponse` to the joining client.
///
/// `join_project_internal` accepts any implementor of this trait; the impl below
/// covers the regular `JoinProject` request's `Response`.
trait JoinProjectInternalResponse {
    /// Sends `result` back to the client that is joining the project.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully-qualified call to the inherent `Response::send`. Inherent associated
        // functions take precedence over trait ones in path resolution, so this
        // does not recurse into the trait method above.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1810
1811fn join_project_internal(
1812    response: impl JoinProjectInternalResponse,
1813    session: Session,
1814    project: &mut Project,
1815    replica_id: &ReplicaId,
1816) -> Result<()> {
1817    let collaborators = project
1818        .collaborators
1819        .iter()
1820        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1821        .map(|collaborator| collaborator.to_proto())
1822        .collect::<Vec<_>>();
1823    let project_id = project.id;
1824    let guest_user_id = session.user_id();
1825
1826    let worktrees = project
1827        .worktrees
1828        .iter()
1829        .map(|(id, worktree)| proto::WorktreeMetadata {
1830            id: *id,
1831            root_name: worktree.root_name.clone(),
1832            visible: worktree.visible,
1833            abs_path: worktree.abs_path.clone(),
1834        })
1835        .collect::<Vec<_>>();
1836
1837    let add_project_collaborator = proto::AddProjectCollaborator {
1838        project_id: project_id.to_proto(),
1839        collaborator: Some(proto::Collaborator {
1840            peer_id: Some(session.connection_id.into()),
1841            replica_id: replica_id.0 as u32,
1842            user_id: guest_user_id.to_proto(),
1843            is_host: false,
1844        }),
1845    };
1846
1847    for collaborator in &collaborators {
1848        session
1849            .peer
1850            .send(
1851                collaborator.peer_id.unwrap().into(),
1852                add_project_collaborator.clone(),
1853            )
1854            .trace_err();
1855    }
1856
1857    // First, we send the metadata associated with each worktree.
1858    response.send(proto::JoinProjectResponse {
1859        project_id: project.id.0 as u64,
1860        worktrees: worktrees.clone(),
1861        replica_id: replica_id.0 as u32,
1862        collaborators: collaborators.clone(),
1863        language_servers: project.language_servers.clone(),
1864        role: project.role.into(),
1865    })?;
1866
1867    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1868        // Stream this worktree's entries.
1869        let message = proto::UpdateWorktree {
1870            project_id: project_id.to_proto(),
1871            worktree_id,
1872            abs_path: worktree.abs_path.clone(),
1873            root_name: worktree.root_name,
1874            updated_entries: worktree.entries,
1875            removed_entries: Default::default(),
1876            scan_id: worktree.scan_id,
1877            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1878            updated_repositories: worktree.repository_entries.into_values().collect(),
1879            removed_repositories: Default::default(),
1880        };
1881        for update in proto::split_worktree_update(message) {
1882            session.peer.send(session.connection_id, update.clone())?;
1883        }
1884
1885        // Stream this worktree's diagnostics.
1886        for summary in worktree.diagnostic_summaries {
1887            session.peer.send(
1888                session.connection_id,
1889                proto::UpdateDiagnosticSummary {
1890                    project_id: project_id.to_proto(),
1891                    worktree_id: worktree.id,
1892                    summary: Some(summary),
1893                },
1894            )?;
1895        }
1896
1897        for settings_file in worktree.settings_files {
1898            session.peer.send(
1899                session.connection_id,
1900                proto::UpdateWorktreeSettings {
1901                    project_id: project_id.to_proto(),
1902                    worktree_id: worktree.id,
1903                    path: settings_file.path,
1904                    content: Some(settings_file.content),
1905                    kind: Some(settings_file.kind.to_proto() as i32),
1906                },
1907            )?;
1908        }
1909    }
1910
1911    for language_server in &project.language_servers {
1912        session.peer.send(
1913            session.connection_id,
1914            proto::UpdateLanguageServer {
1915                project_id: project_id.to_proto(),
1916                language_server_id: language_server.id,
1917                variant: Some(
1918                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1919                        proto::LspDiskBasedDiagnosticsUpdated {},
1920                    ),
1921                ),
1922            },
1923        )?;
1924    }
1925
1926    Ok(())
1927}
1928
1929/// Leave someone elses shared project.
1930async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1931    let sender_id = session.connection_id;
1932    let project_id = ProjectId::from_proto(request.project_id);
1933    let db = session.db().await;
1934
1935    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1936    tracing::info!(
1937        %project_id,
1938        "leave project"
1939    );
1940
1941    project_left(project, &session);
1942    if let Some(room) = room {
1943        room_updated(room, &session.peer);
1944    }
1945
1946    Ok(())
1947}
1948
1949/// Updates other participants with changes to the project
1950async fn update_project(
1951    request: proto::UpdateProject,
1952    response: Response<proto::UpdateProject>,
1953    session: Session,
1954) -> Result<()> {
1955    let project_id = ProjectId::from_proto(request.project_id);
1956    let (room, guest_connection_ids) = &*session
1957        .db()
1958        .await
1959        .update_project(project_id, session.connection_id, &request.worktrees)
1960        .await?;
1961    broadcast(
1962        Some(session.connection_id),
1963        guest_connection_ids.iter().copied(),
1964        |connection_id| {
1965            session
1966                .peer
1967                .forward_send(session.connection_id, connection_id, request.clone())
1968        },
1969    );
1970    if let Some(room) = room {
1971        room_updated(room, &session.peer);
1972    }
1973    response.send(proto::Ack {})?;
1974
1975    Ok(())
1976}
1977
1978/// Updates other participants with changes to the worktree
1979async fn update_worktree(
1980    request: proto::UpdateWorktree,
1981    response: Response<proto::UpdateWorktree>,
1982    session: Session,
1983) -> Result<()> {
1984    let guest_connection_ids = session
1985        .db()
1986        .await
1987        .update_worktree(&request, session.connection_id)
1988        .await?;
1989
1990    broadcast(
1991        Some(session.connection_id),
1992        guest_connection_ids.iter().copied(),
1993        |connection_id| {
1994            session
1995                .peer
1996                .forward_send(session.connection_id, connection_id, request.clone())
1997        },
1998    );
1999    response.send(proto::Ack {})?;
2000    Ok(())
2001}
2002
2003/// Updates other participants with changes to the diagnostics
2004async fn update_diagnostic_summary(
2005    message: proto::UpdateDiagnosticSummary,
2006    session: Session,
2007) -> Result<()> {
2008    let guest_connection_ids = session
2009        .db()
2010        .await
2011        .update_diagnostic_summary(&message, session.connection_id)
2012        .await?;
2013
2014    broadcast(
2015        Some(session.connection_id),
2016        guest_connection_ids.iter().copied(),
2017        |connection_id| {
2018            session
2019                .peer
2020                .forward_send(session.connection_id, connection_id, message.clone())
2021        },
2022    );
2023
2024    Ok(())
2025}
2026
2027/// Updates other participants with changes to the worktree settings
2028async fn update_worktree_settings(
2029    message: proto::UpdateWorktreeSettings,
2030    session: Session,
2031) -> Result<()> {
2032    let guest_connection_ids = session
2033        .db()
2034        .await
2035        .update_worktree_settings(&message, session.connection_id)
2036        .await?;
2037
2038    broadcast(
2039        Some(session.connection_id),
2040        guest_connection_ids.iter().copied(),
2041        |connection_id| {
2042            session
2043                .peer
2044                .forward_send(session.connection_id, connection_id, message.clone())
2045        },
2046    );
2047
2048    Ok(())
2049}
2050
2051/// Notify other participants that a  language server has started.
2052async fn start_language_server(
2053    request: proto::StartLanguageServer,
2054    session: Session,
2055) -> Result<()> {
2056    let guest_connection_ids = session
2057        .db()
2058        .await
2059        .start_language_server(&request, session.connection_id)
2060        .await?;
2061
2062    broadcast(
2063        Some(session.connection_id),
2064        guest_connection_ids.iter().copied(),
2065        |connection_id| {
2066            session
2067                .peer
2068                .forward_send(session.connection_id, connection_id, request.clone())
2069        },
2070    );
2071    Ok(())
2072}
2073
2074/// Notify other participants that a language server has changed.
2075async fn update_language_server(
2076    request: proto::UpdateLanguageServer,
2077    session: Session,
2078) -> Result<()> {
2079    let project_id = ProjectId::from_proto(request.project_id);
2080    let project_connection_ids = session
2081        .db()
2082        .await
2083        .project_connection_ids(project_id, session.connection_id, true)
2084        .await?;
2085    broadcast(
2086        Some(session.connection_id),
2087        project_connection_ids.iter().copied(),
2088        |connection_id| {
2089            session
2090                .peer
2091                .forward_send(session.connection_id, connection_id, request.clone())
2092        },
2093    );
2094    Ok(())
2095}
2096
2097/// forward a project request to the host. These requests should be read only
2098/// as guests are allowed to send them.
2099async fn forward_read_only_project_request<T>(
2100    request: T,
2101    response: Response<T>,
2102    session: Session,
2103) -> Result<()>
2104where
2105    T: EntityMessage + RequestMessage,
2106{
2107    let project_id = ProjectId::from_proto(request.remote_entity_id());
2108    let host_connection_id = session
2109        .db()
2110        .await
2111        .host_for_read_only_project_request(project_id, session.connection_id)
2112        .await?;
2113    let payload = session
2114        .peer
2115        .forward_request(session.connection_id, host_connection_id, request)
2116        .await?;
2117    response.send(payload)?;
2118    Ok(())
2119}
2120
2121async fn forward_find_search_candidates_request(
2122    request: proto::FindSearchCandidates,
2123    response: Response<proto::FindSearchCandidates>,
2124    session: Session,
2125) -> Result<()> {
2126    let project_id = ProjectId::from_proto(request.remote_entity_id());
2127    let host_connection_id = session
2128        .db()
2129        .await
2130        .host_for_read_only_project_request(project_id, session.connection_id)
2131        .await?;
2132    let payload = session
2133        .peer
2134        .forward_request(session.connection_id, host_connection_id, request)
2135        .await?;
2136    response.send(payload)?;
2137    Ok(())
2138}
2139
2140/// forward a project request to the host. These requests are disallowed
2141/// for guests.
2142async fn forward_mutating_project_request<T>(
2143    request: T,
2144    response: Response<T>,
2145    session: Session,
2146) -> Result<()>
2147where
2148    T: EntityMessage + RequestMessage,
2149{
2150    let project_id = ProjectId::from_proto(request.remote_entity_id());
2151
2152    let host_connection_id = session
2153        .db()
2154        .await
2155        .host_for_mutating_project_request(project_id, session.connection_id)
2156        .await?;
2157    let payload = session
2158        .peer
2159        .forward_request(session.connection_id, host_connection_id, request)
2160        .await?;
2161    response.send(payload)?;
2162    Ok(())
2163}
2164
2165/// Notify other participants that a new buffer has been created
2166async fn create_buffer_for_peer(
2167    request: proto::CreateBufferForPeer,
2168    session: Session,
2169) -> Result<()> {
2170    session
2171        .db()
2172        .await
2173        .check_user_is_project_host(
2174            ProjectId::from_proto(request.project_id),
2175            session.connection_id,
2176        )
2177        .await?;
2178    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2179    session
2180        .peer
2181        .forward_send(session.connection_id, peer_id.into(), request)?;
2182    Ok(())
2183}
2184
2185/// Notify other participants that a buffer has been updated. This is
2186/// allowed for guests as long as the update is limited to selections.
2187async fn update_buffer(
2188    request: proto::UpdateBuffer,
2189    response: Response<proto::UpdateBuffer>,
2190    session: Session,
2191) -> Result<()> {
2192    let project_id = ProjectId::from_proto(request.project_id);
2193    let mut capability = Capability::ReadOnly;
2194
2195    for op in request.operations.iter() {
2196        match op.variant {
2197            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2198            Some(_) => capability = Capability::ReadWrite,
2199        }
2200    }
2201
2202    let host = {
2203        let guard = session
2204            .db()
2205            .await
2206            .connections_for_buffer_update(project_id, session.connection_id, capability)
2207            .await?;
2208
2209        let (host, guests) = &*guard;
2210
2211        broadcast(
2212            Some(session.connection_id),
2213            guests.clone(),
2214            |connection_id| {
2215                session
2216                    .peer
2217                    .forward_send(session.connection_id, connection_id, request.clone())
2218            },
2219        );
2220
2221        *host
2222    };
2223
2224    if host != session.connection_id {
2225        session
2226            .peer
2227            .forward_request(session.connection_id, host, request.clone())
2228            .await?;
2229    }
2230
2231    response.send(proto::Ack {})?;
2232    Ok(())
2233}
2234
2235async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2236    let project_id = ProjectId::from_proto(message.project_id);
2237
2238    let operation = message.operation.as_ref().context("invalid operation")?;
2239    let capability = match operation.variant.as_ref() {
2240        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2241            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2242                match buffer_op.variant {
2243                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2244                        Capability::ReadOnly
2245                    }
2246                    _ => Capability::ReadWrite,
2247                }
2248            } else {
2249                Capability::ReadWrite
2250            }
2251        }
2252        Some(_) => Capability::ReadWrite,
2253        None => Capability::ReadOnly,
2254    };
2255
2256    let guard = session
2257        .db()
2258        .await
2259        .connections_for_buffer_update(project_id, session.connection_id, capability)
2260        .await?;
2261
2262    let (host, guests) = &*guard;
2263
2264    broadcast(
2265        Some(session.connection_id),
2266        guests.iter().chain([host]).copied(),
2267        |connection_id| {
2268            session
2269                .peer
2270                .forward_send(session.connection_id, connection_id, message.clone())
2271        },
2272    );
2273
2274    Ok(())
2275}
2276
2277/// Notify other participants that a project has been updated.
2278async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2279    request: T,
2280    session: Session,
2281) -> Result<()> {
2282    let project_id = ProjectId::from_proto(request.remote_entity_id());
2283    let project_connection_ids = session
2284        .db()
2285        .await
2286        .project_connection_ids(project_id, session.connection_id, false)
2287        .await?;
2288
2289    broadcast(
2290        Some(session.connection_id),
2291        project_connection_ids.iter().copied(),
2292        |connection_id| {
2293            session
2294                .peer
2295                .forward_send(session.connection_id, connection_id, request.clone())
2296        },
2297    );
2298    Ok(())
2299}
2300
2301/// Start following another user in a call.
2302async fn follow(
2303    request: proto::Follow,
2304    response: Response<proto::Follow>,
2305    session: Session,
2306) -> Result<()> {
2307    let room_id = RoomId::from_proto(request.room_id);
2308    let project_id = request.project_id.map(ProjectId::from_proto);
2309    let leader_id = request
2310        .leader_id
2311        .ok_or_else(|| anyhow!("invalid leader id"))?
2312        .into();
2313    let follower_id = session.connection_id;
2314
2315    session
2316        .db()
2317        .await
2318        .check_room_participants(room_id, leader_id, session.connection_id)
2319        .await?;
2320
2321    let response_payload = session
2322        .peer
2323        .forward_request(session.connection_id, leader_id, request)
2324        .await?;
2325    response.send(response_payload)?;
2326
2327    if let Some(project_id) = project_id {
2328        let room = session
2329            .db()
2330            .await
2331            .follow(room_id, project_id, leader_id, follower_id)
2332            .await?;
2333        room_updated(&room, &session.peer);
2334    }
2335
2336    Ok(())
2337}
2338
2339/// Stop following another user in a call.
2340async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2341    let room_id = RoomId::from_proto(request.room_id);
2342    let project_id = request.project_id.map(ProjectId::from_proto);
2343    let leader_id = request
2344        .leader_id
2345        .ok_or_else(|| anyhow!("invalid leader id"))?
2346        .into();
2347    let follower_id = session.connection_id;
2348
2349    session
2350        .db()
2351        .await
2352        .check_room_participants(room_id, leader_id, session.connection_id)
2353        .await?;
2354
2355    session
2356        .peer
2357        .forward_send(session.connection_id, leader_id, request)?;
2358
2359    if let Some(project_id) = project_id {
2360        let room = session
2361            .db()
2362            .await
2363            .unfollow(room_id, project_id, leader_id, follower_id)
2364            .await?;
2365        room_updated(&room, &session.peer);
2366    }
2367
2368    Ok(())
2369}
2370
2371/// Notify everyone following you of your current location.
2372async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2373    let room_id = RoomId::from_proto(request.room_id);
2374    let database = session.db.lock().await;
2375
2376    let connection_ids = if let Some(project_id) = request.project_id {
2377        let project_id = ProjectId::from_proto(project_id);
2378        database
2379            .project_connection_ids(project_id, session.connection_id, true)
2380            .await?
2381    } else {
2382        database
2383            .room_connection_ids(room_id, session.connection_id)
2384            .await?
2385    };
2386
2387    // For now, don't send view update messages back to that view's current leader.
2388    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2389        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2390        _ => None,
2391    });
2392
2393    for connection_id in connection_ids.iter().cloned() {
2394        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2395            session
2396                .peer
2397                .forward_send(session.connection_id, connection_id, request.clone())?;
2398        }
2399    }
2400    Ok(())
2401}
2402
2403/// Get public data about users.
2404async fn get_users(
2405    request: proto::GetUsers,
2406    response: Response<proto::GetUsers>,
2407    session: Session,
2408) -> Result<()> {
2409    let user_ids = request
2410        .user_ids
2411        .into_iter()
2412        .map(UserId::from_proto)
2413        .collect();
2414    let users = session
2415        .db()
2416        .await
2417        .get_users_by_ids(user_ids)
2418        .await?
2419        .into_iter()
2420        .map(|user| proto::User {
2421            id: user.id.to_proto(),
2422            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2423            github_login: user.github_login,
2424            email: user.email_address,
2425            name: user.name,
2426        })
2427        .collect();
2428    response.send(proto::UsersResponse { users })?;
2429    Ok(())
2430}
2431
2432/// Search for users (to invite) buy Github login
2433async fn fuzzy_search_users(
2434    request: proto::FuzzySearchUsers,
2435    response: Response<proto::FuzzySearchUsers>,
2436    session: Session,
2437) -> Result<()> {
2438    let query = request.query;
2439    let users = match query.len() {
2440        0 => vec![],
2441        1 | 2 => session
2442            .db()
2443            .await
2444            .get_user_by_github_login(&query)
2445            .await?
2446            .into_iter()
2447            .collect(),
2448        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2449    };
2450    let users = users
2451        .into_iter()
2452        .filter(|user| user.id != session.user_id())
2453        .map(|user| proto::User {
2454            id: user.id.to_proto(),
2455            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2456            github_login: user.github_login,
2457            name: user.name,
2458            email: user.email_address,
2459        })
2460        .collect();
2461    response.send(proto::UsersResponse { users })?;
2462    Ok(())
2463}
2464
2465/// Send a contact request to another user.
2466async fn request_contact(
2467    request: proto::RequestContact,
2468    response: Response<proto::RequestContact>,
2469    session: Session,
2470) -> Result<()> {
2471    let requester_id = session.user_id();
2472    let responder_id = UserId::from_proto(request.responder_id);
2473    if requester_id == responder_id {
2474        return Err(anyhow!("cannot add yourself as a contact"))?;
2475    }
2476
2477    let notifications = session
2478        .db()
2479        .await
2480        .send_contact_request(requester_id, responder_id)
2481        .await?;
2482
2483    // Update outgoing contact requests of requester
2484    let mut update = proto::UpdateContacts::default();
2485    update.outgoing_requests.push(responder_id.to_proto());
2486    for connection_id in session
2487        .connection_pool()
2488        .await
2489        .user_connection_ids(requester_id)
2490    {
2491        session.peer.send(connection_id, update.clone())?;
2492    }
2493
2494    // Update incoming contact requests of responder
2495    let mut update = proto::UpdateContacts::default();
2496    update
2497        .incoming_requests
2498        .push(proto::IncomingContactRequest {
2499            requester_id: requester_id.to_proto(),
2500        });
2501    let connection_pool = session.connection_pool().await;
2502    for connection_id in connection_pool.user_connection_ids(responder_id) {
2503        session.peer.send(connection_id, update.clone())?;
2504    }
2505
2506    send_notifications(&connection_pool, &session.peer, notifications);
2507
2508    response.send(proto::Ack {})?;
2509    Ok(())
2510}
2511
2512/// Accept or decline a contact request
2513async fn respond_to_contact_request(
2514    request: proto::RespondToContactRequest,
2515    response: Response<proto::RespondToContactRequest>,
2516    session: Session,
2517) -> Result<()> {
2518    let responder_id = session.user_id();
2519    let requester_id = UserId::from_proto(request.requester_id);
2520    let db = session.db().await;
2521    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2522        db.dismiss_contact_notification(responder_id, requester_id)
2523            .await?;
2524    } else {
2525        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2526
2527        let notifications = db
2528            .respond_to_contact_request(responder_id, requester_id, accept)
2529            .await?;
2530        let requester_busy = db.is_user_busy(requester_id).await?;
2531        let responder_busy = db.is_user_busy(responder_id).await?;
2532
2533        let pool = session.connection_pool().await;
2534        // Update responder with new contact
2535        let mut update = proto::UpdateContacts::default();
2536        if accept {
2537            update
2538                .contacts
2539                .push(contact_for_user(requester_id, requester_busy, &pool));
2540        }
2541        update
2542            .remove_incoming_requests
2543            .push(requester_id.to_proto());
2544        for connection_id in pool.user_connection_ids(responder_id) {
2545            session.peer.send(connection_id, update.clone())?;
2546        }
2547
2548        // Update requester with new contact
2549        let mut update = proto::UpdateContacts::default();
2550        if accept {
2551            update
2552                .contacts
2553                .push(contact_for_user(responder_id, responder_busy, &pool));
2554        }
2555        update
2556            .remove_outgoing_requests
2557            .push(responder_id.to_proto());
2558
2559        for connection_id in pool.user_connection_ids(requester_id) {
2560            session.peer.send(connection_id, update.clone())?;
2561        }
2562
2563        send_notifications(&pool, &session.peer, notifications);
2564    }
2565
2566    response.send(proto::Ack {})?;
2567    Ok(())
2568}
2569
2570/// Remove a contact.
2571async fn remove_contact(
2572    request: proto::RemoveContact,
2573    response: Response<proto::RemoveContact>,
2574    session: Session,
2575) -> Result<()> {
2576    let requester_id = session.user_id();
2577    let responder_id = UserId::from_proto(request.user_id);
2578    let db = session.db().await;
2579    let (contact_accepted, deleted_notification_id) =
2580        db.remove_contact(requester_id, responder_id).await?;
2581
2582    let pool = session.connection_pool().await;
2583    // Update outgoing contact requests of requester
2584    let mut update = proto::UpdateContacts::default();
2585    if contact_accepted {
2586        update.remove_contacts.push(responder_id.to_proto());
2587    } else {
2588        update
2589            .remove_outgoing_requests
2590            .push(responder_id.to_proto());
2591    }
2592    for connection_id in pool.user_connection_ids(requester_id) {
2593        session.peer.send(connection_id, update.clone())?;
2594    }
2595
2596    // Update incoming contact requests of responder
2597    let mut update = proto::UpdateContacts::default();
2598    if contact_accepted {
2599        update.remove_contacts.push(requester_id.to_proto());
2600    } else {
2601        update
2602            .remove_incoming_requests
2603            .push(requester_id.to_proto());
2604    }
2605    for connection_id in pool.user_connection_ids(responder_id) {
2606        session.peer.send(connection_id, update.clone())?;
2607        if let Some(notification_id) = deleted_notification_id {
2608            session.peer.send(
2609                connection_id,
2610                proto::DeleteNotification {
2611                    notification_id: notification_id.to_proto(),
2612                },
2613            )?;
2614        }
2615    }
2616
2617    response.send(proto::Ack {})?;
2618    Ok(())
2619}
2620
2621fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2622    version.0.minor() < 139
2623}
2624
2625async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2626    let plan = session.current_plan(&session.db().await).await?;
2627
2628    session
2629        .peer
2630        .send(
2631            session.connection_id,
2632            proto::UpdateUserPlan { plan: plan.into() },
2633        )
2634        .trace_err();
2635
2636    Ok(())
2637}
2638
2639async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2640    subscribe_user_to_channels(session.user_id(), &session).await?;
2641    Ok(())
2642}
2643
2644async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2645    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2646    let mut pool = session.connection_pool().await;
2647    for membership in &channels_for_user.channel_memberships {
2648        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2649    }
2650    session.peer.send(
2651        session.connection_id,
2652        build_update_user_channels(&channels_for_user),
2653    )?;
2654    session.peer.send(
2655        session.connection_id,
2656        build_channels_update(channels_for_user),
2657    )?;
2658    Ok(())
2659}
2660
2661/// Creates a new channel.
2662async fn create_channel(
2663    request: proto::CreateChannel,
2664    response: Response<proto::CreateChannel>,
2665    session: Session,
2666) -> Result<()> {
2667    let db = session.db().await;
2668
2669    let parent_id = request.parent_id.map(ChannelId::from_proto);
2670    let (channel, membership) = db
2671        .create_channel(&request.name, parent_id, session.user_id())
2672        .await?;
2673
2674    let root_id = channel.root_id();
2675    let channel = Channel::from_model(channel);
2676
2677    response.send(proto::CreateChannelResponse {
2678        channel: Some(channel.to_proto()),
2679        parent_id: request.parent_id,
2680    })?;
2681
2682    let mut connection_pool = session.connection_pool().await;
2683    if let Some(membership) = membership {
2684        connection_pool.subscribe_to_channel(
2685            membership.user_id,
2686            membership.channel_id,
2687            membership.role,
2688        );
2689        let update = proto::UpdateUserChannels {
2690            channel_memberships: vec![proto::ChannelMembership {
2691                channel_id: membership.channel_id.to_proto(),
2692                role: membership.role.into(),
2693            }],
2694            ..Default::default()
2695        };
2696        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2697            session.peer.send(connection_id, update.clone())?;
2698        }
2699    }
2700
2701    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2702        if !role.can_see_channel(channel.visibility) {
2703            continue;
2704        }
2705
2706        let update = proto::UpdateChannels {
2707            channels: vec![channel.to_proto()],
2708            ..Default::default()
2709        };
2710        session.peer.send(connection_id, update.clone())?;
2711    }
2712
2713    Ok(())
2714}
2715
2716/// Delete a channel
2717async fn delete_channel(
2718    request: proto::DeleteChannel,
2719    response: Response<proto::DeleteChannel>,
2720    session: Session,
2721) -> Result<()> {
2722    let db = session.db().await;
2723
2724    let channel_id = request.channel_id;
2725    let (root_channel, removed_channels) = db
2726        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2727        .await?;
2728    response.send(proto::Ack {})?;
2729
2730    // Notify members of removed channels
2731    let mut update = proto::UpdateChannels::default();
2732    update
2733        .delete_channels
2734        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2735
2736    let connection_pool = session.connection_pool().await;
2737    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2738        session.peer.send(connection_id, update.clone())?;
2739    }
2740
2741    Ok(())
2742}
2743
2744/// Invite someone to join a channel.
2745async fn invite_channel_member(
2746    request: proto::InviteChannelMember,
2747    response: Response<proto::InviteChannelMember>,
2748    session: Session,
2749) -> Result<()> {
2750    let db = session.db().await;
2751    let channel_id = ChannelId::from_proto(request.channel_id);
2752    let invitee_id = UserId::from_proto(request.user_id);
2753    let InviteMemberResult {
2754        channel,
2755        notifications,
2756    } = db
2757        .invite_channel_member(
2758            channel_id,
2759            invitee_id,
2760            session.user_id(),
2761            request.role().into(),
2762        )
2763        .await?;
2764
2765    let update = proto::UpdateChannels {
2766        channel_invitations: vec![channel.to_proto()],
2767        ..Default::default()
2768    };
2769
2770    let connection_pool = session.connection_pool().await;
2771    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2772        session.peer.send(connection_id, update.clone())?;
2773    }
2774
2775    send_notifications(&connection_pool, &session.peer, notifications);
2776
2777    response.send(proto::Ack {})?;
2778    Ok(())
2779}
2780
2781/// remove someone from a channel
2782async fn remove_channel_member(
2783    request: proto::RemoveChannelMember,
2784    response: Response<proto::RemoveChannelMember>,
2785    session: Session,
2786) -> Result<()> {
2787    let db = session.db().await;
2788    let channel_id = ChannelId::from_proto(request.channel_id);
2789    let member_id = UserId::from_proto(request.user_id);
2790
2791    let RemoveChannelMemberResult {
2792        membership_update,
2793        notification_id,
2794    } = db
2795        .remove_channel_member(channel_id, member_id, session.user_id())
2796        .await?;
2797
2798    let mut connection_pool = session.connection_pool().await;
2799    notify_membership_updated(
2800        &mut connection_pool,
2801        membership_update,
2802        member_id,
2803        &session.peer,
2804    );
2805    for connection_id in connection_pool.user_connection_ids(member_id) {
2806        if let Some(notification_id) = notification_id {
2807            session
2808                .peer
2809                .send(
2810                    connection_id,
2811                    proto::DeleteNotification {
2812                        notification_id: notification_id.to_proto(),
2813                    },
2814                )
2815                .trace_err();
2816        }
2817    }
2818
2819    response.send(proto::Ack {})?;
2820    Ok(())
2821}
2822
2823/// Toggle the channel between public and private.
2824/// Care is taken to maintain the invariant that public channels only descend from public channels,
2825/// (though members-only channels can appear at any point in the hierarchy).
2826async fn set_channel_visibility(
2827    request: proto::SetChannelVisibility,
2828    response: Response<proto::SetChannelVisibility>,
2829    session: Session,
2830) -> Result<()> {
2831    let db = session.db().await;
2832    let channel_id = ChannelId::from_proto(request.channel_id);
2833    let visibility = request.visibility().into();
2834
2835    let channel_model = db
2836        .set_channel_visibility(channel_id, visibility, session.user_id())
2837        .await?;
2838    let root_id = channel_model.root_id();
2839    let channel = Channel::from_model(channel_model);
2840
2841    let mut connection_pool = session.connection_pool().await;
2842    for (user_id, role) in connection_pool
2843        .channel_user_ids(root_id)
2844        .collect::<Vec<_>>()
2845        .into_iter()
2846    {
2847        let update = if role.can_see_channel(channel.visibility) {
2848            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2849            proto::UpdateChannels {
2850                channels: vec![channel.to_proto()],
2851                ..Default::default()
2852            }
2853        } else {
2854            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2855            proto::UpdateChannels {
2856                delete_channels: vec![channel.id.to_proto()],
2857                ..Default::default()
2858            }
2859        };
2860
2861        for connection_id in connection_pool.user_connection_ids(user_id) {
2862            session.peer.send(connection_id, update.clone())?;
2863        }
2864    }
2865
2866    response.send(proto::Ack {})?;
2867    Ok(())
2868}
2869
2870/// Alter the role for a user in the channel.
2871async fn set_channel_member_role(
2872    request: proto::SetChannelMemberRole,
2873    response: Response<proto::SetChannelMemberRole>,
2874    session: Session,
2875) -> Result<()> {
2876    let db = session.db().await;
2877    let channel_id = ChannelId::from_proto(request.channel_id);
2878    let member_id = UserId::from_proto(request.user_id);
2879    let result = db
2880        .set_channel_member_role(
2881            channel_id,
2882            session.user_id(),
2883            member_id,
2884            request.role().into(),
2885        )
2886        .await?;
2887
2888    match result {
2889        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2890            let mut connection_pool = session.connection_pool().await;
2891            notify_membership_updated(
2892                &mut connection_pool,
2893                membership_update,
2894                member_id,
2895                &session.peer,
2896            )
2897        }
2898        db::SetMemberRoleResult::InviteUpdated(channel) => {
2899            let update = proto::UpdateChannels {
2900                channel_invitations: vec![channel.to_proto()],
2901                ..Default::default()
2902            };
2903
2904            for connection_id in session
2905                .connection_pool()
2906                .await
2907                .user_connection_ids(member_id)
2908            {
2909                session.peer.send(connection_id, update.clone())?;
2910            }
2911        }
2912    }
2913
2914    response.send(proto::Ack {})?;
2915    Ok(())
2916}
2917
2918/// Change the name of a channel
2919async fn rename_channel(
2920    request: proto::RenameChannel,
2921    response: Response<proto::RenameChannel>,
2922    session: Session,
2923) -> Result<()> {
2924    let db = session.db().await;
2925    let channel_id = ChannelId::from_proto(request.channel_id);
2926    let channel_model = db
2927        .rename_channel(channel_id, session.user_id(), &request.name)
2928        .await?;
2929    let root_id = channel_model.root_id();
2930    let channel = Channel::from_model(channel_model);
2931
2932    response.send(proto::RenameChannelResponse {
2933        channel: Some(channel.to_proto()),
2934    })?;
2935
2936    let connection_pool = session.connection_pool().await;
2937    let update = proto::UpdateChannels {
2938        channels: vec![channel.to_proto()],
2939        ..Default::default()
2940    };
2941    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2942        if role.can_see_channel(channel.visibility) {
2943            session.peer.send(connection_id, update.clone())?;
2944        }
2945    }
2946
2947    Ok(())
2948}
2949
2950/// Move a channel to a new parent.
2951async fn move_channel(
2952    request: proto::MoveChannel,
2953    response: Response<proto::MoveChannel>,
2954    session: Session,
2955) -> Result<()> {
2956    let channel_id = ChannelId::from_proto(request.channel_id);
2957    let to = ChannelId::from_proto(request.to);
2958
2959    let (root_id, channels) = session
2960        .db()
2961        .await
2962        .move_channel(channel_id, to, session.user_id())
2963        .await?;
2964
2965    let connection_pool = session.connection_pool().await;
2966    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2967        let channels = channels
2968            .iter()
2969            .filter_map(|channel| {
2970                if role.can_see_channel(channel.visibility) {
2971                    Some(channel.to_proto())
2972                } else {
2973                    None
2974                }
2975            })
2976            .collect::<Vec<_>>();
2977        if channels.is_empty() {
2978            continue;
2979        }
2980
2981        let update = proto::UpdateChannels {
2982            channels,
2983            ..Default::default()
2984        };
2985
2986        session.peer.send(connection_id, update.clone())?;
2987    }
2988
2989    response.send(Ack {})?;
2990    Ok(())
2991}
2992
2993/// Get the list of channel members
2994async fn get_channel_members(
2995    request: proto::GetChannelMembers,
2996    response: Response<proto::GetChannelMembers>,
2997    session: Session,
2998) -> Result<()> {
2999    let db = session.db().await;
3000    let channel_id = ChannelId::from_proto(request.channel_id);
3001    let limit = if request.limit == 0 {
3002        u16::MAX as u64
3003    } else {
3004        request.limit
3005    };
3006    let (members, users) = db
3007        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3008        .await?;
3009    response.send(proto::GetChannelMembersResponse { members, users })?;
3010    Ok(())
3011}
3012
3013/// Accept or decline a channel invitation.
3014async fn respond_to_channel_invite(
3015    request: proto::RespondToChannelInvite,
3016    response: Response<proto::RespondToChannelInvite>,
3017    session: Session,
3018) -> Result<()> {
3019    let db = session.db().await;
3020    let channel_id = ChannelId::from_proto(request.channel_id);
3021    let RespondToChannelInvite {
3022        membership_update,
3023        notifications,
3024    } = db
3025        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3026        .await?;
3027
3028    let mut connection_pool = session.connection_pool().await;
3029    if let Some(membership_update) = membership_update {
3030        notify_membership_updated(
3031            &mut connection_pool,
3032            membership_update,
3033            session.user_id(),
3034            &session.peer,
3035        );
3036    } else {
3037        let update = proto::UpdateChannels {
3038            remove_channel_invitations: vec![channel_id.to_proto()],
3039            ..Default::default()
3040        };
3041
3042        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3043            session.peer.send(connection_id, update.clone())?;
3044        }
3045    };
3046
3047    send_notifications(&connection_pool, &session.peer, notifications);
3048
3049    response.send(proto::Ack {})?;
3050
3051    Ok(())
3052}
3053
3054/// Join the channels' room
3055async fn join_channel(
3056    request: proto::JoinChannel,
3057    response: Response<proto::JoinChannel>,
3058    session: Session,
3059) -> Result<()> {
3060    let channel_id = ChannelId::from_proto(request.channel_id);
3061    join_channel_internal(channel_id, Box::new(response), session).await
3062}
3063
/// Abstracts over the request types whose reply is a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` answer both `JoinChannel` and `JoinRoom`
/// requests.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the underlying request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type path: presumably resolves to `Response`'s own `send`
        // rather than this same-named trait method (which would recurse).
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type path, for the same reason as the `JoinChannel` impl above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3077
/// Shared implementation for joining a channel's room.
///
/// Used by handlers whose reply is a `proto::JoinRoomResponse`; see
/// `JoinChannelInternalResponse`. The steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The db handle is released around that call.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update, and the user's role.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token and `can_publish = false`;
///    - all other roles get a room token and `can_publish = true`.
///
///    A token error goes through `trace_err` and results in no connection
///    info, rather than failing the join.
/// 4. Reply to the caller, propagate any membership update, and broadcast the
///    room update.
/// 5. Once the inner block's handles are released, broadcast the channel
///    update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any of these fail:
/// - a database call;
/// - leaving the stale room;
/// - sending the reply;
/// - updating contacts.
///
/// Also fails with "channel not returned" if the join produced no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If Zed quits without leaving the room, and the user re-opens Zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle while leaving the stale room (presumably
            // `leave_room_for_session` acquires it itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token
        // failed (the error is passed to `trace_err`). Guests may not publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Held, together with `db`, until the end of this block.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The block above has ended, so its db and connection-pool handles are gone;
    // the channel update acquires the connection pool afresh.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3173
3174/// Start editing the channel notes
3175async fn join_channel_buffer(
3176    request: proto::JoinChannelBuffer,
3177    response: Response<proto::JoinChannelBuffer>,
3178    session: Session,
3179) -> Result<()> {
3180    let db = session.db().await;
3181    let channel_id = ChannelId::from_proto(request.channel_id);
3182
3183    let open_response = db
3184        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3185        .await?;
3186
3187    let collaborators = open_response.collaborators.clone();
3188    response.send(open_response)?;
3189
3190    let update = UpdateChannelBufferCollaborators {
3191        channel_id: channel_id.to_proto(),
3192        collaborators: collaborators.clone(),
3193    };
3194    channel_buffer_updated(
3195        session.connection_id,
3196        collaborators
3197            .iter()
3198            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3199        &update,
3200        &session.peer,
3201    );
3202
3203    Ok(())
3204}
3205
3206/// Edit the channel notes
3207async fn update_channel_buffer(
3208    request: proto::UpdateChannelBuffer,
3209    session: Session,
3210) -> Result<()> {
3211    let db = session.db().await;
3212    let channel_id = ChannelId::from_proto(request.channel_id);
3213
3214    let (collaborators, epoch, version) = db
3215        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3216        .await?;
3217
3218    channel_buffer_updated(
3219        session.connection_id,
3220        collaborators.clone(),
3221        &proto::UpdateChannelBuffer {
3222            channel_id: channel_id.to_proto(),
3223            operations: request.operations,
3224        },
3225        &session.peer,
3226    );
3227
3228    let pool = &*session.connection_pool().await;
3229
3230    let non_collaborators =
3231        pool.channel_connection_ids(channel_id)
3232            .filter_map(|(connection_id, _)| {
3233                if collaborators.contains(&connection_id) {
3234                    None
3235                } else {
3236                    Some(connection_id)
3237                }
3238            });
3239
3240    broadcast(None, non_collaborators, |peer_id| {
3241        session.peer.send(
3242            peer_id,
3243            proto::UpdateChannels {
3244                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3245                    channel_id: channel_id.to_proto(),
3246                    epoch: epoch as u64,
3247                    version: version.clone(),
3248                }],
3249                ..Default::default()
3250            },
3251        )
3252    });
3253
3254    Ok(())
3255}
3256
3257/// Rejoin the channel notes after a connection blip
3258async fn rejoin_channel_buffers(
3259    request: proto::RejoinChannelBuffers,
3260    response: Response<proto::RejoinChannelBuffers>,
3261    session: Session,
3262) -> Result<()> {
3263    let db = session.db().await;
3264    let buffers = db
3265        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3266        .await?;
3267
3268    for rejoined_buffer in &buffers {
3269        let collaborators_to_notify = rejoined_buffer
3270            .buffer
3271            .collaborators
3272            .iter()
3273            .filter_map(|c| Some(c.peer_id?.into()));
3274        channel_buffer_updated(
3275            session.connection_id,
3276            collaborators_to_notify,
3277            &proto::UpdateChannelBufferCollaborators {
3278                channel_id: rejoined_buffer.buffer.channel_id,
3279                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3280            },
3281            &session.peer,
3282        );
3283    }
3284
3285    response.send(proto::RejoinChannelBuffersResponse {
3286        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3287    })?;
3288
3289    Ok(())
3290}
3291
3292/// Stop editing the channel notes
3293async fn leave_channel_buffer(
3294    request: proto::LeaveChannelBuffer,
3295    response: Response<proto::LeaveChannelBuffer>,
3296    session: Session,
3297) -> Result<()> {
3298    let db = session.db().await;
3299    let channel_id = ChannelId::from_proto(request.channel_id);
3300
3301    let left_buffer = db
3302        .leave_channel_buffer(channel_id, session.connection_id)
3303        .await?;
3304
3305    response.send(Ack {})?;
3306
3307    channel_buffer_updated(
3308        session.connection_id,
3309        left_buffer.connections,
3310        &proto::UpdateChannelBufferCollaborators {
3311            channel_id: channel_id.to_proto(),
3312            collaborators: left_buffer.collaborators,
3313        },
3314        &session.peer,
3315    );
3316
3317    Ok(())
3318}
3319
3320fn channel_buffer_updated<T: EnvelopedMessage>(
3321    sender_id: ConnectionId,
3322    collaborators: impl IntoIterator<Item = ConnectionId>,
3323    message: &T,
3324    peer: &Peer,
3325) {
3326    broadcast(Some(sender_id), collaborators, |peer_id| {
3327        peer.send(peer_id, message.clone())
3328    });
3329}
3330
3331fn send_notifications(
3332    connection_pool: &ConnectionPool,
3333    peer: &Peer,
3334    notifications: db::NotificationBatch,
3335) {
3336    for (user_id, notification) in notifications {
3337        for connection_id in connection_pool.user_connection_ids(user_id) {
3338            if let Err(error) = peer.send(
3339                connection_id,
3340                proto::AddNotification {
3341                    notification: Some(notification.clone()),
3342                },
3343            ) {
3344                tracing::error!(
3345                    "failed to send notification to {:?} {}",
3346                    connection_id,
3347                    error
3348                );
3349            }
3350        }
3351    }
3352}
3353
3354/// Send a message to the channel
3355async fn send_channel_message(
3356    request: proto::SendChannelMessage,
3357    response: Response<proto::SendChannelMessage>,
3358    session: Session,
3359) -> Result<()> {
3360    // Validate the message body.
3361    let body = request.body.trim().to_string();
3362    if body.len() > MAX_MESSAGE_LEN {
3363        return Err(anyhow!("message is too long"))?;
3364    }
3365    if body.is_empty() {
3366        return Err(anyhow!("message can't be blank"))?;
3367    }
3368
3369    // TODO: adjust mentions if body is trimmed
3370
3371    let timestamp = OffsetDateTime::now_utc();
3372    let nonce = request
3373        .nonce
3374        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3375
3376    let channel_id = ChannelId::from_proto(request.channel_id);
3377    let CreatedChannelMessage {
3378        message_id,
3379        participant_connection_ids,
3380        notifications,
3381    } = session
3382        .db()
3383        .await
3384        .create_channel_message(
3385            channel_id,
3386            session.user_id(),
3387            &body,
3388            &request.mentions,
3389            timestamp,
3390            nonce.clone().into(),
3391            request.reply_to_message_id.map(MessageId::from_proto),
3392        )
3393        .await?;
3394
3395    let message = proto::ChannelMessage {
3396        sender_id: session.user_id().to_proto(),
3397        id: message_id.to_proto(),
3398        body,
3399        mentions: request.mentions,
3400        timestamp: timestamp.unix_timestamp() as u64,
3401        nonce: Some(nonce),
3402        reply_to_message_id: request.reply_to_message_id,
3403        edited_at: None,
3404    };
3405    broadcast(
3406        Some(session.connection_id),
3407        participant_connection_ids.clone(),
3408        |connection| {
3409            session.peer.send(
3410                connection,
3411                proto::ChannelMessageSent {
3412                    channel_id: channel_id.to_proto(),
3413                    message: Some(message.clone()),
3414                },
3415            )
3416        },
3417    );
3418    response.send(proto::SendChannelMessageResponse {
3419        message: Some(message),
3420    })?;
3421
3422    let pool = &*session.connection_pool().await;
3423    let non_participants =
3424        pool.channel_connection_ids(channel_id)
3425            .filter_map(|(connection_id, _)| {
3426                if participant_connection_ids.contains(&connection_id) {
3427                    None
3428                } else {
3429                    Some(connection_id)
3430                }
3431            });
3432    broadcast(None, non_participants, |peer_id| {
3433        session.peer.send(
3434            peer_id,
3435            proto::UpdateChannels {
3436                latest_channel_message_ids: vec![proto::ChannelMessageId {
3437                    channel_id: channel_id.to_proto(),
3438                    message_id: message_id.to_proto(),
3439                }],
3440                ..Default::default()
3441            },
3442        )
3443    });
3444    send_notifications(pool, &session.peer, notifications);
3445
3446    Ok(())
3447}
3448
3449/// Delete a channel message
3450async fn remove_channel_message(
3451    request: proto::RemoveChannelMessage,
3452    response: Response<proto::RemoveChannelMessage>,
3453    session: Session,
3454) -> Result<()> {
3455    let channel_id = ChannelId::from_proto(request.channel_id);
3456    let message_id = MessageId::from_proto(request.message_id);
3457    let (connection_ids, existing_notification_ids) = session
3458        .db()
3459        .await
3460        .remove_channel_message(channel_id, message_id, session.user_id())
3461        .await?;
3462
3463    broadcast(
3464        Some(session.connection_id),
3465        connection_ids,
3466        move |connection| {
3467            session.peer.send(connection, request.clone())?;
3468
3469            for notification_id in &existing_notification_ids {
3470                session.peer.send(
3471                    connection,
3472                    proto::DeleteNotification {
3473                        notification_id: (*notification_id).to_proto(),
3474                    },
3475                )?;
3476            }
3477
3478            Ok(())
3479        },
3480    );
3481    response.send(proto::Ack {})?;
3482    Ok(())
3483}
3484
3485async fn update_channel_message(
3486    request: proto::UpdateChannelMessage,
3487    response: Response<proto::UpdateChannelMessage>,
3488    session: Session,
3489) -> Result<()> {
3490    let channel_id = ChannelId::from_proto(request.channel_id);
3491    let message_id = MessageId::from_proto(request.message_id);
3492    let updated_at = OffsetDateTime::now_utc();
3493    let UpdatedChannelMessage {
3494        message_id,
3495        participant_connection_ids,
3496        notifications,
3497        reply_to_message_id,
3498        timestamp,
3499        deleted_mention_notification_ids,
3500        updated_mention_notifications,
3501    } = session
3502        .db()
3503        .await
3504        .update_channel_message(
3505            channel_id,
3506            message_id,
3507            session.user_id(),
3508            request.body.as_str(),
3509            &request.mentions,
3510            updated_at,
3511        )
3512        .await?;
3513
3514    let nonce = request
3515        .nonce
3516        .clone()
3517        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3518
3519    let message = proto::ChannelMessage {
3520        sender_id: session.user_id().to_proto(),
3521        id: message_id.to_proto(),
3522        body: request.body.clone(),
3523        mentions: request.mentions.clone(),
3524        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3525        nonce: Some(nonce),
3526        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3527        edited_at: Some(updated_at.unix_timestamp() as u64),
3528    };
3529
3530    response.send(proto::Ack {})?;
3531
3532    let pool = &*session.connection_pool().await;
3533    broadcast(
3534        Some(session.connection_id),
3535        participant_connection_ids,
3536        |connection| {
3537            session.peer.send(
3538                connection,
3539                proto::ChannelMessageUpdate {
3540                    channel_id: channel_id.to_proto(),
3541                    message: Some(message.clone()),
3542                },
3543            )?;
3544
3545            for notification_id in &deleted_mention_notification_ids {
3546                session.peer.send(
3547                    connection,
3548                    proto::DeleteNotification {
3549                        notification_id: (*notification_id).to_proto(),
3550                    },
3551                )?;
3552            }
3553
3554            for notification in &updated_mention_notifications {
3555                session.peer.send(
3556                    connection,
3557                    proto::UpdateNotification {
3558                        notification: Some(notification.clone()),
3559                    },
3560                )?;
3561            }
3562
3563            Ok(())
3564        },
3565    );
3566
3567    send_notifications(pool, &session.peer, notifications);
3568
3569    Ok(())
3570}
3571
3572/// Mark a channel message as read
3573async fn acknowledge_channel_message(
3574    request: proto::AckChannelMessage,
3575    session: Session,
3576) -> Result<()> {
3577    let channel_id = ChannelId::from_proto(request.channel_id);
3578    let message_id = MessageId::from_proto(request.message_id);
3579    let notifications = session
3580        .db()
3581        .await
3582        .observe_channel_message(channel_id, session.user_id(), message_id)
3583        .await?;
3584    send_notifications(
3585        &*session.connection_pool().await,
3586        &session.peer,
3587        notifications,
3588    );
3589    Ok(())
3590}
3591
3592/// Mark a buffer version as synced
3593async fn acknowledge_buffer_version(
3594    request: proto::AckBufferOperation,
3595    session: Session,
3596) -> Result<()> {
3597    let buffer_id = BufferId::from_proto(request.buffer_id);
3598    session
3599        .db()
3600        .await
3601        .observe_buffer_version(
3602            buffer_id,
3603            session.user_id(),
3604            request.epoch as i32,
3605            &request.version,
3606        )
3607        .await?;
3608    Ok(())
3609}
3610
3611async fn count_language_model_tokens(
3612    request: proto::CountLanguageModelTokens,
3613    response: Response<proto::CountLanguageModelTokens>,
3614    session: Session,
3615    config: &Config,
3616) -> Result<()> {
3617    authorize_access_to_legacy_llm_endpoints(&session).await?;
3618
3619    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3620        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3621        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3622    };
3623
3624    session
3625        .app_state
3626        .rate_limiter
3627        .check(&*rate_limit, session.user_id())
3628        .await?;
3629
3630    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3631        Some(proto::LanguageModelProvider::Google) => {
3632            let api_key = config
3633                .google_ai_api_key
3634                .as_ref()
3635                .context("no Google AI API key configured on the server")?;
3636            google_ai::count_tokens(
3637                session.http_client.as_ref(),
3638                google_ai::API_URL,
3639                api_key,
3640                serde_json::from_str(&request.request)?,
3641            )
3642            .await?
3643        }
3644        _ => return Err(anyhow!("unsupported provider"))?,
3645    };
3646
3647    response.send(proto::CountLanguageModelTokensResponse {
3648        token_count: result.total_tokens as u32,
3649    })?;
3650
3651    Ok(())
3652}
3653
3654struct ZedProCountLanguageModelTokensRateLimit;
3655
3656impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3657    fn capacity(&self) -> usize {
3658        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3659            .ok()
3660            .and_then(|v| v.parse().ok())
3661            .unwrap_or(600) // Picked arbitrarily
3662    }
3663
3664    fn refill_duration(&self) -> chrono::Duration {
3665        chrono::Duration::hours(1)
3666    }
3667
3668    fn db_name(&self) -> &'static str {
3669        "zed-pro:count-language-model-tokens"
3670    }
3671}
3672
3673struct FreeCountLanguageModelTokensRateLimit;
3674
3675impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3676    fn capacity(&self) -> usize {
3677        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3678            .ok()
3679            .and_then(|v| v.parse().ok())
3680            .unwrap_or(600 / 10) // Picked arbitrarily
3681    }
3682
3683    fn refill_duration(&self) -> chrono::Duration {
3684        chrono::Duration::hours(1)
3685    }
3686
3687    fn db_name(&self) -> &'static str {
3688        "free:count-language-model-tokens"
3689    }
3690}
3691
3692struct ZedProComputeEmbeddingsRateLimit;
3693
3694impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3695    fn capacity(&self) -> usize {
3696        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3697            .ok()
3698            .and_then(|v| v.parse().ok())
3699            .unwrap_or(5000) // Picked arbitrarily
3700    }
3701
3702    fn refill_duration(&self) -> chrono::Duration {
3703        chrono::Duration::hours(1)
3704    }
3705
3706    fn db_name(&self) -> &'static str {
3707        "zed-pro:compute-embeddings"
3708    }
3709}
3710
3711struct FreeComputeEmbeddingsRateLimit;
3712
3713impl RateLimit for FreeComputeEmbeddingsRateLimit {
3714    fn capacity(&self) -> usize {
3715        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3716            .ok()
3717            .and_then(|v| v.parse().ok())
3718            .unwrap_or(5000 / 10) // Picked arbitrarily
3719    }
3720
3721    fn refill_duration(&self) -> chrono::Duration {
3722        chrono::Duration::hours(1)
3723    }
3724
3725    fn db_name(&self) -> &'static str {
3726        "free:compute-embeddings"
3727    }
3728}
3729
3730async fn compute_embeddings(
3731    request: proto::ComputeEmbeddings,
3732    response: Response<proto::ComputeEmbeddings>,
3733    session: Session,
3734    api_key: Option<Arc<str>>,
3735) -> Result<()> {
3736    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3737    authorize_access_to_legacy_llm_endpoints(&session).await?;
3738
3739    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3740        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3741        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3742    };
3743
3744    session
3745        .app_state
3746        .rate_limiter
3747        .check(&*rate_limit, session.user_id())
3748        .await?;
3749
3750    let embeddings = match request.model.as_str() {
3751        "openai/text-embedding-3-small" => {
3752            open_ai::embed(
3753                session.http_client.as_ref(),
3754                OPEN_AI_API_URL,
3755                &api_key,
3756                OpenAiEmbeddingModel::TextEmbedding3Small,
3757                request.texts.iter().map(|text| text.as_str()),
3758            )
3759            .await?
3760        }
3761        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3762    };
3763
3764    let embeddings = request
3765        .texts
3766        .iter()
3767        .map(|text| {
3768            let mut hasher = sha2::Sha256::new();
3769            hasher.update(text.as_bytes());
3770            let result = hasher.finalize();
3771            result.to_vec()
3772        })
3773        .zip(
3774            embeddings
3775                .data
3776                .into_iter()
3777                .map(|embedding| embedding.embedding),
3778        )
3779        .collect::<HashMap<_, _>>();
3780
3781    let db = session.db().await;
3782    db.save_embeddings(&request.model, &embeddings)
3783        .await
3784        .context("failed to save embeddings")
3785        .trace_err();
3786
3787    response.send(proto::ComputeEmbeddingsResponse {
3788        embeddings: embeddings
3789            .into_iter()
3790            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3791            .collect(),
3792    })?;
3793    Ok(())
3794}
3795
3796async fn get_cached_embeddings(
3797    request: proto::GetCachedEmbeddings,
3798    response: Response<proto::GetCachedEmbeddings>,
3799    session: Session,
3800) -> Result<()> {
3801    authorize_access_to_legacy_llm_endpoints(&session).await?;
3802
3803    let db = session.db().await;
3804    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3805
3806    response.send(proto::GetCachedEmbeddingsResponse {
3807        embeddings: embeddings
3808            .into_iter()
3809            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3810            .collect(),
3811    })?;
3812    Ok(())
3813}
3814
3815/// This is leftover from before the LLM service.
3816///
3817/// The endpoints protected by this check will be moved there eventually.
3818async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3819    if session.is_staff() {
3820        Ok(())
3821    } else {
3822        Err(anyhow!("permission denied"))?
3823    }
3824}
3825
3826/// Get a Supermaven API key for the user
3827async fn get_supermaven_api_key(
3828    _request: proto::GetSupermavenApiKey,
3829    response: Response<proto::GetSupermavenApiKey>,
3830    session: Session,
3831) -> Result<()> {
3832    let user_id: String = session.user_id().to_string();
3833    if !session.is_staff() {
3834        return Err(anyhow!("supermaven not enabled for this account"))?;
3835    }
3836
3837    let email = session
3838        .email()
3839        .ok_or_else(|| anyhow!("user must have an email"))?;
3840
3841    let supermaven_admin_api = session
3842        .supermaven_client
3843        .as_ref()
3844        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3845
3846    let result = supermaven_admin_api
3847        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3848        .await?;
3849
3850    response.send(proto::GetSupermavenApiKeyResponse {
3851        api_key: result.api_key,
3852    })?;
3853
3854    Ok(())
3855}
3856
3857/// Start receiving chat updates for a channel
3858async fn join_channel_chat(
3859    request: proto::JoinChannelChat,
3860    response: Response<proto::JoinChannelChat>,
3861    session: Session,
3862) -> Result<()> {
3863    let channel_id = ChannelId::from_proto(request.channel_id);
3864
3865    let db = session.db().await;
3866    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3867        .await?;
3868    let messages = db
3869        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3870        .await?;
3871    response.send(proto::JoinChannelChatResponse {
3872        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3873        messages,
3874    })?;
3875    Ok(())
3876}
3877
3878/// Stop receiving chat updates for a channel
3879async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3880    let channel_id = ChannelId::from_proto(request.channel_id);
3881    session
3882        .db()
3883        .await
3884        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3885        .await?;
3886    Ok(())
3887}
3888
3889/// Retrieve the chat history for a channel
3890async fn get_channel_messages(
3891    request: proto::GetChannelMessages,
3892    response: Response<proto::GetChannelMessages>,
3893    session: Session,
3894) -> Result<()> {
3895    let channel_id = ChannelId::from_proto(request.channel_id);
3896    let messages = session
3897        .db()
3898        .await
3899        .get_channel_messages(
3900            channel_id,
3901            session.user_id(),
3902            MESSAGE_COUNT_PER_PAGE,
3903            Some(MessageId::from_proto(request.before_message_id)),
3904        )
3905        .await?;
3906    response.send(proto::GetChannelMessagesResponse {
3907        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3908        messages,
3909    })?;
3910    Ok(())
3911}
3912
3913/// Retrieve specific chat messages
3914async fn get_channel_messages_by_id(
3915    request: proto::GetChannelMessagesById,
3916    response: Response<proto::GetChannelMessagesById>,
3917    session: Session,
3918) -> Result<()> {
3919    let message_ids = request
3920        .message_ids
3921        .iter()
3922        .map(|id| MessageId::from_proto(*id))
3923        .collect::<Vec<_>>();
3924    let messages = session
3925        .db()
3926        .await
3927        .get_channel_messages_by_id(session.user_id(), &message_ids)
3928        .await?;
3929    response.send(proto::GetChannelMessagesResponse {
3930        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3931        messages,
3932    })?;
3933    Ok(())
3934}
3935
3936/// Retrieve the current users notifications
3937async fn get_notifications(
3938    request: proto::GetNotifications,
3939    response: Response<proto::GetNotifications>,
3940    session: Session,
3941) -> Result<()> {
3942    let notifications = session
3943        .db()
3944        .await
3945        .get_notifications(
3946            session.user_id(),
3947            NOTIFICATION_COUNT_PER_PAGE,
3948            request.before_id.map(db::NotificationId::from_proto),
3949        )
3950        .await?;
3951    response.send(proto::GetNotificationsResponse {
3952        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3953        notifications,
3954    })?;
3955    Ok(())
3956}
3957
3958/// Mark notifications as read
3959async fn mark_notification_as_read(
3960    request: proto::MarkNotificationRead,
3961    response: Response<proto::MarkNotificationRead>,
3962    session: Session,
3963) -> Result<()> {
3964    let database = &session.db().await;
3965    let notifications = database
3966        .mark_notification_as_read_by_id(
3967            session.user_id(),
3968            NotificationId::from_proto(request.notification_id),
3969        )
3970        .await?;
3971    send_notifications(
3972        &*session.connection_pool().await,
3973        &session.peer,
3974        notifications,
3975    );
3976    response.send(proto::Ack {})?;
3977    Ok(())
3978}
3979
3980/// Get the current users information
3981async fn get_private_user_info(
3982    _request: proto::GetPrivateUserInfo,
3983    response: Response<proto::GetPrivateUserInfo>,
3984    session: Session,
3985) -> Result<()> {
3986    let db = session.db().await;
3987
3988    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3989    let user = db
3990        .get_user_by_id(session.user_id())
3991        .await?
3992        .ok_or_else(|| anyhow!("user not found"))?;
3993    let flags = db.get_user_flags(session.user_id()).await?;
3994
3995    response.send(proto::GetPrivateUserInfoResponse {
3996        metrics_id,
3997        staff: user.admin,
3998        flags,
3999        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4000    })?;
4001    Ok(())
4002}
4003
4004/// Accept the terms of service (tos) on behalf of the current user
4005async fn accept_terms_of_service(
4006    _request: proto::AcceptTermsOfService,
4007    response: Response<proto::AcceptTermsOfService>,
4008    session: Session,
4009) -> Result<()> {
4010    let db = session.db().await;
4011
4012    let accepted_tos_at = Utc::now();
4013    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4014        .await?;
4015
4016    response.send(proto::AcceptTermsOfServiceResponse {
4017        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4018    })?;
4019    Ok(())
4020}
4021
4022/// The minimum account age an account must have in order to use the LLM service.
4023const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4024
4025async fn get_llm_api_token(
4026    _request: proto::GetLlmToken,
4027    response: Response<proto::GetLlmToken>,
4028    session: Session,
4029) -> Result<()> {
4030    let db = session.db().await;
4031
4032    let flags = db.get_user_flags(session.user_id()).await?;
4033    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4034    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4035    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4036
4037    if !session.is_staff() && !has_language_models_feature_flag {
4038        Err(anyhow!("permission denied"))?
4039    }
4040
4041    let user_id = session.user_id();
4042    let user = db
4043        .get_user_by_id(user_id)
4044        .await?
4045        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4046
4047    if user.accepted_tos_at.is_none() {
4048        Err(anyhow!("terms of service not accepted"))?
4049    }
4050
4051    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4052
4053    let bypass_account_age_check =
4054        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4055    if !bypass_account_age_check {
4056        let mut account_created_at = user.created_at;
4057        if let Some(github_created_at) = user.github_user_created_at {
4058            account_created_at = account_created_at.min(github_created_at);
4059        }
4060        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4061            Err(anyhow!("account too young"))?
4062        }
4063    }
4064
4065    let billing_preferences = db.get_billing_preferences(user.id).await?;
4066
4067    let token = LlmTokenClaims::create(
4068        &user,
4069        session.is_staff(),
4070        billing_preferences,
4071        has_llm_closed_beta_feature_flag,
4072        has_predict_edits_feature_flag,
4073        has_llm_subscription,
4074        session.current_plan(&db).await?,
4075        session.system_id.clone(),
4076        &session.app_state.config,
4077    )?;
4078    response.send(proto::GetLlmTokenResponse { token })?;
4079    Ok(())
4080}
4081
4082fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4083    let message = match message {
4084        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4085        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4086        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4087        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4088        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4089            code: frame.code.into(),
4090            reason: frame.reason,
4091        })),
4092        // We should never receive a frame while reading the message, according
4093        // to the `tungstenite` maintainers:
4094        //
4095        // > It cannot occur when you read messages from the WebSocket, but it
4096        // > can be used when you want to send the raw frames (e.g. you want to
4097        // > send the frames to the WebSocket without composing the full message first).
4098        // >
4099        // > — https://github.com/snapview/tungstenite-rs/issues/268
4100        TungsteniteMessage::Frame(_) => {
4101            bail!("received an unexpected frame while reading the message")
4102        }
4103    };
4104
4105    Ok(message)
4106}
4107
4108fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4109    match message {
4110        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4111        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4112        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4113        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4114        AxumMessage::Close(frame) => {
4115            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4116                code: frame.code.into(),
4117                reason: frame.reason,
4118            }))
4119        }
4120    }
4121}
4122
4123fn notify_membership_updated(
4124    connection_pool: &mut ConnectionPool,
4125    result: MembershipUpdated,
4126    user_id: UserId,
4127    peer: &Peer,
4128) {
4129    for membership in &result.new_channels.channel_memberships {
4130        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4131    }
4132    for channel_id in &result.removed_channels {
4133        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4134    }
4135
4136    let user_channels_update = proto::UpdateUserChannels {
4137        channel_memberships: result
4138            .new_channels
4139            .channel_memberships
4140            .iter()
4141            .map(|cm| proto::ChannelMembership {
4142                channel_id: cm.channel_id.to_proto(),
4143                role: cm.role.into(),
4144            })
4145            .collect(),
4146        ..Default::default()
4147    };
4148
4149    let mut update = build_channels_update(result.new_channels);
4150    update.delete_channels = result
4151        .removed_channels
4152        .into_iter()
4153        .map(|id| id.to_proto())
4154        .collect();
4155    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4156
4157    for connection_id in connection_pool.user_connection_ids(user_id) {
4158        peer.send(connection_id, user_channels_update.clone())
4159            .trace_err();
4160        peer.send(connection_id, update.clone()).trace_err();
4161    }
4162}
4163
4164fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4165    proto::UpdateUserChannels {
4166        channel_memberships: channels
4167            .channel_memberships
4168            .iter()
4169            .map(|m| proto::ChannelMembership {
4170                channel_id: m.channel_id.to_proto(),
4171                role: m.role.into(),
4172            })
4173            .collect(),
4174        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4175        observed_channel_message_id: channels.observed_channel_messages.clone(),
4176    }
4177}
4178
4179fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4180    let mut update = proto::UpdateChannels::default();
4181
4182    for channel in channels.channels {
4183        update.channels.push(channel.to_proto());
4184    }
4185
4186    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4187    update.latest_channel_message_ids = channels.latest_channel_messages;
4188
4189    for (channel_id, participants) in channels.channel_participants {
4190        update
4191            .channel_participants
4192            .push(proto::ChannelParticipants {
4193                channel_id: channel_id.to_proto(),
4194                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4195            });
4196    }
4197
4198    for channel in channels.invited_channels {
4199        update.channel_invitations.push(channel.to_proto());
4200    }
4201
4202    update
4203}
4204
4205fn build_initial_contacts_update(
4206    contacts: Vec<db::Contact>,
4207    pool: &ConnectionPool,
4208) -> proto::UpdateContacts {
4209    let mut update = proto::UpdateContacts::default();
4210
4211    for contact in contacts {
4212        match contact {
4213            db::Contact::Accepted { user_id, busy } => {
4214                update.contacts.push(contact_for_user(user_id, busy, pool));
4215            }
4216            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4217            db::Contact::Incoming { user_id } => {
4218                update
4219                    .incoming_requests
4220                    .push(proto::IncomingContactRequest {
4221                        requester_id: user_id.to_proto(),
4222                    })
4223            }
4224        }
4225    }
4226
4227    update
4228}
4229
4230fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4231    proto::Contact {
4232        user_id: user_id.to_proto(),
4233        online: pool.is_user_online(user_id),
4234        busy,
4235    }
4236}
4237
4238fn room_updated(room: &proto::Room, peer: &Peer) {
4239    broadcast(
4240        None,
4241        room.participants
4242            .iter()
4243            .filter_map(|participant| Some(participant.peer_id?.into())),
4244        |peer_id| {
4245            peer.send(
4246                peer_id,
4247                proto::RoomUpdated {
4248                    room: Some(room.clone()),
4249                },
4250            )
4251        },
4252    );
4253}
4254
4255fn channel_updated(
4256    channel: &db::channel::Model,
4257    room: &proto::Room,
4258    peer: &Peer,
4259    pool: &ConnectionPool,
4260) {
4261    let participants = room
4262        .participants
4263        .iter()
4264        .map(|p| p.user_id)
4265        .collect::<Vec<_>>();
4266
4267    broadcast(
4268        None,
4269        pool.channel_connection_ids(channel.root_id())
4270            .filter_map(|(channel_id, role)| {
4271                role.can_see_channel(channel.visibility)
4272                    .then_some(channel_id)
4273            }),
4274        |peer_id| {
4275            peer.send(
4276                peer_id,
4277                proto::UpdateChannels {
4278                    channel_participants: vec![proto::ChannelParticipants {
4279                        channel_id: channel.id.to_proto(),
4280                        participant_user_ids: participants.clone(),
4281                    }],
4282                    ..Default::default()
4283                },
4284            )
4285        },
4286    );
4287}
4288
4289async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4290    let db = session.db().await;
4291
4292    let contacts = db.get_contacts(user_id).await?;
4293    let busy = db.is_user_busy(user_id).await?;
4294
4295    let pool = session.connection_pool().await;
4296    let updated_contact = contact_for_user(user_id, busy, &pool);
4297    for contact in contacts {
4298        if let db::Contact::Accepted {
4299            user_id: contact_user_id,
4300            ..
4301        } = contact
4302        {
4303            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4304                session
4305                    .peer
4306                    .send(
4307                        contact_conn_id,
4308                        proto::UpdateContacts {
4309                            contacts: vec![updated_contact.clone()],
4310                            remove_contacts: Default::default(),
4311                            incoming_requests: Default::default(),
4312                            remove_incoming_requests: Default::default(),
4313                            outgoing_requests: Default::default(),
4314                            remove_outgoing_requests: Default::default(),
4315                        },
4316                    )
4317                    .trace_err();
4318            }
4319        }
4320    }
4321    Ok(())
4322}
4323
/// Removes `connection_id` from the room it is in and propagates the
/// consequences to other clients.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` without doing anything else. Otherwise it:
/// - notifies the other connections of every project the connection left,
/// - broadcasts the updated room to the remaining participants,
/// - if the room belongs to a channel, broadcasts the channel's new participant list,
/// - sends `CallCanceled` to every connection of each user whose pending call
///   was canceled,
/// - refreshes the contact entry of the leaving user and of every canceled callee,
/// - if a LiveKit client is configured, removes the user from the LiveKit room
///   and also deletes that room when the database reported the room as deleted.
///
/// # Errors
/// Returns an error if the database `leave_room` call fails or if refreshing any
/// contact entry fails. Message-send and LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries must be re-sent once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so the values
    // moved out of `left_room` outlive that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // Temporaries created in this scrutinee (including whatever `session.db()`
    // returns) stay alive until the end of the whole `if let`, so the
    // notifications sent in its body happen before they are released.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself is moved out below, so the `room` broadcast
        // by `room_updated` carries a default `livekit_room` value.
        // NOTE(review): presumably intentional; confirm clients don't rely on it.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection-pool guard is a temporary and is dropped at the end of
        // this statement.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this connection-pool guard is dropped before
    // `update_user_contacts` below, which calls `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` returns on the first failure, so the remaining contacts
    // are not refreshed and the LiveKit cleanup below is skipped.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // Cloned because `livekit_room` is moved into `delete_room` below.
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4397
4398async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4399    let left_channel_buffers = session
4400        .db()
4401        .await
4402        .leave_channel_buffers(session.connection_id)
4403        .await?;
4404
4405    for left_buffer in left_channel_buffers {
4406        channel_buffer_updated(
4407            session.connection_id,
4408            left_buffer.connections,
4409            &proto::UpdateChannelBufferCollaborators {
4410                channel_id: left_buffer.channel_id.to_proto(),
4411                collaborators: left_buffer.collaborators,
4412            },
4413            &session.peer,
4414        );
4415    }
4416
4417    Ok(())
4418}
4419
4420fn project_left(project: &db::LeftProject, session: &Session) {
4421    for connection_id in &project.connection_ids {
4422        if project.should_unshare {
4423            session
4424                .peer
4425                .send(
4426                    *connection_id,
4427                    proto::UnshareProject {
4428                        project_id: project.id.to_proto(),
4429                    },
4430                )
4431                .trace_err();
4432        } else {
4433            session
4434                .peer
4435                .send(
4436                    *connection_id,
4437                    proto::RemoveProjectCollaborator {
4438                        project_id: project.id.to_proto(),
4439                        peer_id: Some(session.connection_id.into()),
4440                    },
4441                )
4442                .trace_err();
4443        }
4444    }
4445}
4446
4447pub trait ResultExt {
4448    type Ok;
4449
4450    fn trace_err(self) -> Option<Self::Ok>;
4451}
4452
4453impl<T, E> ResultExt for Result<T, E>
4454where
4455    E: std::fmt::Debug,
4456{
4457    type Ok = T;
4458
4459    #[track_caller]
4460    fn trace_err(self) -> Option<T> {
4461        match self {
4462            Ok(value) => Some(value),
4463            Err(error) => {
4464                tracing::error!("{:?}", error);
4465                None
4466            }
4467        }
4468    }
4469}