rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state passed to every message and request handler.
#[derive(Clone)]
struct Session {
    /// Who this connection is authenticated as.
    principal: Principal,
    /// This connection's id, as assigned by `Peer::add_connection`.
    connection_id: ConnectionId,
    /// Database access behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the `Server`; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Built in `handle_connection` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // The `allow` silences the unused-field lint, so this is presumably not read yet.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The RPC server. It owns the `Peer`, the shared connection pool, and the table of message
/// handlers.
pub struct Server {
    /// Kept behind a mutex so it can be replaced (e.g. by the test-only `reset`).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`. While it is `true`, `handle_connection` refuses new
    /// connections.
    teardown: watch::Sender<bool>,
}
 244
/// Guard over the locked `ConnectionPool`.
///
/// `_not_send` (a `PhantomData<Rc<()>>`) makes this type `!Send`, so the guard cannot be held
/// across an `.await` inside a `Send` future.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot owns a `ConnectionPoolGuard`, so the connection pool stays locked for as long
/// as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized through its `Deref` target, i.e. the pool itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
    /// Creates a server with id `id` and registers every RPC message and request handler it
    /// serves.
    ///
    /// Each message type may be registered only once; registering a type twice panics (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` would silently wrap a negative `id.0`; confirm ids are
            // always non-negative.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // This handler needs `app_state.config`. Every call clones the `Arc` so the
            // returned future owns its own reference.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            // Every call passes its own clone of the configured OpenAI API key.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 418
    /// Spawns a detached background task that cleans up stale resources for this environment.
    ///
    /// The resources are presumably ones left behind by servers that are no longer running.
    /// After waiting `CLEANUP_TIMEOUT`, the task:
    ///
    /// - looks up stale rooms and channel buffers for `zed_environment` / `server_id`;
    /// - clears their stale participants and collaborators;
    /// - notifies the affected connections and contacts;
    /// - deletes LiveKit rooms that ended up empty;
    /// - finally deletes the stale servers.
    ///
    /// Each fallible step goes through `trace_err()`, which turns failures into `None` or
    /// ignored results (presumably after logging). One failure therefore does not abort the
    /// rest of the cleanup.
    ///
    /// Returns `Ok(())` immediately; the cleanup runs in the background.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Send each stale channel buffer's refreshed collaborator list to the
                    // connections the database still reports for it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These keep their defaults if clearing the room fails. In that case no
                        // calls are canceled, no contacts are updated, and nothing is deleted
                        // from LiveKit.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only rooms left with no participants are deleted from LiveKit.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the `.await`s
                        // below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query the database before taking the pool lock, so the lock is
                            // never held across an `.await`.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Only accepted contacts are sent the updated entry.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // This runs even if the stale-resource lookup above failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 564
 565    pub fn teardown(&self) {
 566        self.peer.teardown();
 567        self.connection_pool.lock().reset();
 568        let _ = self.teardown.send(true);
 569    }
 570
 571    #[cfg(test)]
 572    pub fn reset(&self, id: ServerId) {
 573        self.teardown();
 574        *self.id.lock() = id;
 575        self.peer.reset(id.0 as u32);
 576        let _ = self.teardown.send(false);
 577    }
 578
 579    #[cfg(test)]
 580    pub fn id(&self) -> ServerId {
 581        *self.id.lock()
 582    }
 583
 584    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 585    where
 586        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 587        Fut: 'static + Send + Future<Output = Result<()>>,
 588        M: EnvelopedMessage,
 589    {
 590        let prev_handler = self.handlers.insert(
 591            TypeId::of::<M>(),
 592            Box::new(move |envelope, session| {
 593                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 594                let received_at = envelope.received_at;
 595                tracing::info!("message received");
 596                let start_time = Instant::now();
 597                let future = (handler)(*envelope, session);
 598                async move {
 599                    let result = future.await;
 600                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 601                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 602                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 603                    let payload_type = M::NAME;
 604
 605                    match result {
 606                        Err(error) => {
 607                            tracing::error!(
 608                                ?error,
 609                                total_duration_ms,
 610                                processing_duration_ms,
 611                                queue_duration_ms,
 612                                payload_type,
 613                                "error handling message"
 614                            )
 615                        }
 616                        Ok(()) => tracing::info!(
 617                            total_duration_ms,
 618                            processing_duration_ms,
 619                            queue_duration_ms,
 620                            "finished handling message"
 621                        ),
 622                    }
 623                }
 624                .boxed()
 625            }),
 626        );
 627        if prev_handler.is_some() {
 628            panic!("registered a handler for the same message twice");
 629        }
 630        self
 631    }
 632
 633    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 634    where
 635        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 636        Fut: 'static + Send + Future<Output = Result<()>>,
 637        M: EnvelopedMessage,
 638    {
 639        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 640        self
 641    }
 642
 643    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 644    where
 645        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 646        Fut: Send + Future<Output = Result<()>>,
 647        M: RequestMessage,
 648    {
 649        let handler = Arc::new(handler);
 650        self.add_handler(move |envelope, session| {
 651            let receipt = envelope.receipt();
 652            let handler = handler.clone();
 653            async move {
 654                let peer = session.peer.clone();
 655                let responded = Arc::new(AtomicBool::default());
 656                let response = Response {
 657                    peer: peer.clone(),
 658                    responded: responded.clone(),
 659                    receipt,
 660                };
 661                match (handler)(envelope.payload, response, session).await {
 662                    Ok(()) => {
 663                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 664                            Ok(())
 665                        } else {
 666                            Err(anyhow!("handler did not send a response"))?
 667                        }
 668                    }
 669                    Err(error) => {
 670                        let proto_err = match &error {
 671                            Error::Internal(err) => err.to_proto(),
 672                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 673                        };
 674                        peer.respond_with_error(receipt, proto_err)?;
 675                        Err(error)
 676                    }
 677                }
 678            }
 679        })
 680    }
 681
    /// Drives one client connection from handshake to sign-out.
    ///
    /// The returned future is instrumented with a "handle connection" span and:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records its id on the span;
    /// 3. builds the per-connection `Session` (HTTP client, optional Supermaven
    ///    admin client, and shared db / connection-pool / peer handles);
    /// 4. sends the initial client state via `send_initial_client_update`
    ///    (which also forwards the id to `send_connection_id`, if given);
    /// 5. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or teardown is signalled;
    /// 6. after a close or I/O failure (but not after teardown) cancels pending
    ///    foreground handlers and runs `connection_lost`.
    ///
    /// NOTE(review): the early `return`s after `add_connection` (HTTP client
    /// creation failure, failed initial client update) skip `connection_lost`.
    /// `send_initial_client_update` may already have called
    /// `pool.add_connection` before failing, so that pool entry (and possibly the
    /// peer's connection state) may be left behind. Confirm whether cleanup is
    /// needed on those paths.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later via `record` / `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown signalled in between is not missed.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to start serving a connection once teardown has begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // A Supermaven admin client is only created when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A handler slot is acquired *before* reading the next message, so reading
                // pauses while all 256 slots are held by unfinished handlers.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins, then I/O completion, then progress on in-flight
                // foreground handlers, and only then new messages.
                futures::select_biased! {
                    // Teardown returns directly, skipping `connection_lost` below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is constructed;
                            // the future itself is instrumented with the same span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Release the handler slot only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor; foreground
                                // handlers are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 831
 832    async fn send_initial_client_update(
 833        &self,
 834        connection_id: ConnectionId,
 835        principal: &Principal,
 836        zed_version: ZedVersion,
 837        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 838        session: &Session,
 839    ) -> Result<()> {
 840        self.peer.send(
 841            connection_id,
 842            proto::Hello {
 843                peer_id: Some(connection_id.into()),
 844            },
 845        )?;
 846        tracing::info!("sent hello message");
 847        if let Some(send_connection_id) = send_connection_id.take() {
 848            let _ = send_connection_id.send(connection_id);
 849        }
 850
 851        match principal {
 852            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 853                if !user.connected_once {
 854                    self.peer.send(connection_id, proto::ShowContacts {})?;
 855                    self.app_state
 856                        .db
 857                        .set_user_connected_once(user.id, true)
 858                        .await?;
 859                }
 860
 861                update_user_plan(user.id, session).await?;
 862
 863                let contacts = self.app_state.db.get_contacts(user.id).await?;
 864
 865                {
 866                    let mut pool = self.connection_pool.lock();
 867                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 868                    self.peer.send(
 869                        connection_id,
 870                        build_initial_contacts_update(contacts, &pool),
 871                    )?;
 872                }
 873
 874                if should_auto_subscribe_to_channels(zed_version) {
 875                    subscribe_user_to_channels(user.id, session).await?;
 876                }
 877
 878                if let Some(incoming_call) =
 879                    self.app_state.db.incoming_call_for_user(user.id).await?
 880                {
 881                    self.peer.send(connection_id, incoming_call)?;
 882                }
 883
 884                update_user_contacts(user.id, session).await?;
 885            }
 886        }
 887
 888        Ok(())
 889    }
 890
 891    pub async fn invite_code_redeemed(
 892        self: &Arc<Self>,
 893        inviter_id: UserId,
 894        invitee_id: UserId,
 895    ) -> Result<()> {
 896        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 897            if let Some(code) = &user.invite_code {
 898                let pool = self.connection_pool.lock();
 899                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 900                for connection_id in pool.user_connection_ids(inviter_id) {
 901                    self.peer.send(
 902                        connection_id,
 903                        proto::UpdateContacts {
 904                            contacts: vec![invitee_contact.clone()],
 905                            ..Default::default()
 906                        },
 907                    )?;
 908                    self.peer.send(
 909                        connection_id,
 910                        proto::UpdateInviteInfo {
 911                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 912                            count: user.invite_count as u32,
 913                        },
 914                    )?;
 915                }
 916            }
 917        }
 918        Ok(())
 919    }
 920
 921    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 922        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 923            if let Some(invite_code) = &user.invite_code {
 924                let pool = self.connection_pool.lock();
 925                for connection_id in pool.user_connection_ids(user_id) {
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateInviteInfo {
 929                            url: format!(
 930                                "{}{}",
 931                                self.app_state.config.invite_link_prefix, invite_code
 932                            ),
 933                            count: user.invite_count as u32,
 934                        },
 935                    )?;
 936                }
 937            }
 938        }
 939        Ok(())
 940    }
 941
 942    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 943        let pool = self.connection_pool.lock();
 944        for connection_id in pool.user_connection_ids(user_id) {
 945            self.peer
 946                .send(connection_id, proto::RefreshLlmToken {})
 947                .trace_err();
 948        }
 949    }
 950
 951    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 952        ServerSnapshot {
 953            connection_pool: ConnectionPoolGuard {
 954                guard: self.connection_pool.lock(),
 955                _not_send: PhantomData,
 956            },
 957            peer: &self.peer,
 958        }
 959    }
 960}
 961
 962impl<'a> Deref for ConnectionPoolGuard<'a> {
 963    type Target = ConnectionPool;
 964
 965    fn deref(&self) -> &Self::Target {
 966        &self.guard
 967    }
 968}
 969
 970impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 971    fn deref_mut(&mut self) -> &mut Self::Target {
 972        &mut self.guard
 973    }
 974}
 975
 976impl<'a> Drop for ConnectionPoolGuard<'a> {
 977    fn drop(&mut self) {
 978        #[cfg(test)]
 979        self.check_invariants();
 980    }
 981}
 982
 983fn broadcast<F>(
 984    sender_id: Option<ConnectionId>,
 985    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 986    mut f: F,
 987) where
 988    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 989{
 990    for receiver_id in receiver_ids {
 991        if Some(receiver_id) != sender_id {
 992            if let Err(error) = f(receiver_id) {
 993                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 994            }
 995        }
 996    }
 997}
 998
 999pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002    fn name() -> &'static HeaderName {
1003        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005    }
1006
1007    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008    where
1009        Self: Sized,
1010        I: Iterator<Item = &'i axum::http::HeaderValue>,
1011    {
1012        let version = values
1013            .next()
1014            .ok_or_else(axum::headers::Error::invalid)?
1015            .to_str()
1016            .map_err(|_| axum::headers::Error::invalid())?
1017            .parse()
1018            .map_err(|_| axum::headers::Error::invalid())?;
1019        Ok(Self(version))
1020    }
1021
1022    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023        values.extend([self.0.to_string().parse().unwrap()]);
1024    }
1025}
1026
1027pub struct AppVersionHeader(SemanticVersion);
1028impl Header for AppVersionHeader {
1029    fn name() -> &'static HeaderName {
1030        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032    }
1033
1034    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035    where
1036        Self: Sized,
1037        I: Iterator<Item = &'i axum::http::HeaderValue>,
1038    {
1039        let version = values
1040            .next()
1041            .ok_or_else(axum::headers::Error::invalid)?
1042            .to_str()
1043            .map_err(|_| axum::headers::Error::invalid())?
1044            .parse()
1045            .map_err(|_| axum::headers::Error::invalid())?;
1046        Ok(Self(version))
1047    }
1048
1049    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050        values.extend([self.0.to_string().parse().unwrap()]);
1051    }
1052}
1053
1054pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1055    Router::new()
1056        .route("/rpc", get(handle_websocket_request))
1057        .layer(
1058            ServiceBuilder::new()
1059                .layer(Extension(server.app_state.clone()))
1060                .layer(middleware::from_fn(auth::validate_header)),
1061        )
1062        .route("/metrics", get(handle_metrics))
1063        .layer(Extension(server))
1064}
1065
1066#[allow(clippy::too_many_arguments)]
1067pub async fn handle_websocket_request(
1068    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1069    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1070    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1071    Extension(server): Extension<Arc<Server>>,
1072    Extension(principal): Extension<Principal>,
1073    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1074    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1075    ws: WebSocketUpgrade,
1076) -> axum::response::Response {
1077    if protocol_version != rpc::PROTOCOL_VERSION {
1078        return (
1079            StatusCode::UPGRADE_REQUIRED,
1080            "client must be upgraded".to_string(),
1081        )
1082            .into_response();
1083    }
1084
1085    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1086        return (
1087            StatusCode::UPGRADE_REQUIRED,
1088            "no version header found".to_string(),
1089        )
1090            .into_response();
1091    };
1092
1093    if !version.can_collaborate() {
1094        return (
1095            StatusCode::UPGRADE_REQUIRED,
1096            "client must be upgraded".to_string(),
1097        )
1098            .into_response();
1099    }
1100
1101    let socket_address = socket_address.to_string();
1102    ws.on_upgrade(move |socket| {
1103        let socket = socket
1104            .map_ok(to_tungstenite_message)
1105            .err_into()
1106            .with(|message| async move { to_axum_message(message) });
1107        let connection = Connection::new(Box::pin(socket));
1108        async move {
1109            server
1110                .handle_connection(
1111                    connection,
1112                    socket_address,
1113                    principal,
1114                    version,
1115                    country_code_header.map(|header| header.to_string()),
1116                    system_id_header.map(|header| header.to_string()),
1117                    None,
1118                    Executor::Production,
1119                )
1120                .await;
1121        }
1122    })
1123}
1124
1125pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1126    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1127    let connections_metric = CONNECTIONS_METRIC
1128        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1129
1130    let connections = server
1131        .connection_pool
1132        .lock()
1133        .connections()
1134        .filter(|connection| !connection.admin)
1135        .count();
1136    connections_metric.set(connections as _);
1137
1138    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1139    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1140        register_int_gauge!(
1141            "shared_projects",
1142            "number of open projects with one or more guests"
1143        )
1144        .unwrap()
1145    });
1146
1147    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1148    shared_projects_metric.set(shared_projects as _);
1149
1150    let encoder = prometheus::TextEncoder::new();
1151    let metric_families = prometheus::gather();
1152    let encoded_metrics = encoder
1153        .encode_to_string(&metric_families)
1154        .map_err(|err| anyhow!("{}", err))?;
1155    Ok(encoded_metrics)
1156}
1157
1158#[instrument(err, skip(executor))]
1159async fn connection_lost(
1160    session: Session,
1161    mut teardown: watch::Receiver<bool>,
1162    executor: Executor,
1163) -> Result<()> {
1164    session.peer.disconnect(session.connection_id);
1165    session
1166        .connection_pool()
1167        .await
1168        .remove_connection(session.connection_id)?;
1169
1170    session
1171        .db()
1172        .await
1173        .connection_lost(session.connection_id)
1174        .await
1175        .trace_err();
1176
1177    futures::select_biased! {
1178        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1179
1180            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1181            leave_room_for_session(&session, session.connection_id).await.trace_err();
1182            leave_channel_buffers_for_session(&session)
1183                .await
1184                .trace_err();
1185
1186            if !session
1187                .connection_pool()
1188                .await
1189                .is_user_online(session.user_id())
1190            {
1191                let db = session.db().await;
1192                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1193                    room_updated(&room, &session.peer);
1194                }
1195            }
1196
1197            update_user_contacts(session.user_id(), &session).await?;
1198        },
1199        _ = teardown.changed().fuse() => {}
1200    }
1201
1202    Ok(())
1203}
1204
1205/// Acknowledges a ping from a client, used to keep the connection alive.
1206async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1207    response.send(proto::Ack {})?;
1208    Ok(())
1209}
1210
1211/// Creates a new room for calling (outside of channels)
1212async fn create_room(
1213    _request: proto::CreateRoom,
1214    response: Response<proto::CreateRoom>,
1215    session: Session,
1216) -> Result<()> {
1217    let livekit_room = nanoid::nanoid!(30);
1218
1219    let live_kit_connection_info = util::maybe!(async {
1220        let live_kit = session.app_state.livekit_client.as_ref();
1221        let live_kit = live_kit?;
1222        let user_id = session.user_id().to_string();
1223
1224        let token = live_kit
1225            .room_token(&livekit_room, &user_id.to_string())
1226            .trace_err()?;
1227
1228        Some(proto::LiveKitConnectionInfo {
1229            server_url: live_kit.url().into(),
1230            token,
1231            can_publish: true,
1232        })
1233    })
1234    .await;
1235
1236    let room = session
1237        .db()
1238        .await
1239        .create_room(session.user_id(), session.connection_id, &livekit_room)
1240        .await?;
1241
1242    response.send(proto::CreateRoomResponse {
1243        room: Some(room.clone()),
1244        live_kit_connection_info,
1245    })?;
1246
1247    update_user_contacts(session.user_id(), &session).await?;
1248    Ok(())
1249}
1250
1251/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1252async fn join_room(
1253    request: proto::JoinRoom,
1254    response: Response<proto::JoinRoom>,
1255    session: Session,
1256) -> Result<()> {
1257    let room_id = RoomId::from_proto(request.id);
1258
1259    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1260
1261    if let Some(channel_id) = channel_id {
1262        return join_channel_internal(channel_id, Box::new(response), session).await;
1263    }
1264
1265    let joined_room = {
1266        let room = session
1267            .db()
1268            .await
1269            .join_room(room_id, session.user_id(), session.connection_id)
1270            .await?;
1271        room_updated(&room.room, &session.peer);
1272        room.into_inner()
1273    };
1274
1275    for connection_id in session
1276        .connection_pool()
1277        .await
1278        .user_connection_ids(session.user_id())
1279    {
1280        session
1281            .peer
1282            .send(
1283                connection_id,
1284                proto::CallCanceled {
1285                    room_id: room_id.to_proto(),
1286                },
1287            )
1288            .trace_err();
1289    }
1290
1291    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1292    {
1293        live_kit
1294            .room_token(
1295                &joined_room.room.livekit_room,
1296                &session.user_id().to_string(),
1297            )
1298            .trace_err()
1299            .map(|token| proto::LiveKitConnectionInfo {
1300                server_url: live_kit.url().into(),
1301                token,
1302                can_publish: true,
1303            })
1304    } else {
1305        None
1306    };
1307
1308    response.send(proto::JoinRoomResponse {
1309        room: Some(joined_room.room),
1310        channel_id: None,
1311        live_kit_connection_info,
1312    })?;
1313
1314    update_user_contacts(session.user_id(), &session).await?;
1315    Ok(())
1316}
1317
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The rejoin is recorded in the database for this connection. The handler
/// then:
/// 1. responds with the room and its reshared / rejoined projects;
/// 2. broadcasts the updated room;
/// 3. for each reshared project, tells its collaborators that
///    `old_connection_id` was replaced by this connection, and forwards the
///    project's worktrees to the other collaborators;
/// 4. streams rejoined-project state via `notify_rejoined_projects`;
/// 5. after the database result is released, broadcasts a channel update if
///    the room belongs to a channel;
/// 6. refreshes contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Only `room` and `channel` escape the scope below; the rest of the database
    // result (possibly a guard; released via `into_inner`) stays confined to it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at this connection's new peer id (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed rooms also refresh the channel's participant view.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1409
1410fn notify_rejoined_projects(
1411    rejoined_projects: &mut Vec<RejoinedProject>,
1412    session: &Session,
1413) -> Result<()> {
1414    for project in rejoined_projects.iter() {
1415        for collaborator in &project.collaborators {
1416            session
1417                .peer
1418                .send(
1419                    collaborator.connection_id,
1420                    proto::UpdateProjectCollaborator {
1421                        project_id: project.id.to_proto(),
1422                        old_peer_id: Some(project.old_connection_id.into()),
1423                        new_peer_id: Some(session.connection_id.into()),
1424                    },
1425                )
1426                .trace_err();
1427        }
1428    }
1429
1430    for project in rejoined_projects {
1431        for worktree in mem::take(&mut project.worktrees) {
1432            // Stream this worktree's entries.
1433            let message = proto::UpdateWorktree {
1434                project_id: project.id.to_proto(),
1435                worktree_id: worktree.id,
1436                abs_path: worktree.abs_path.clone(),
1437                root_name: worktree.root_name,
1438                updated_entries: worktree.updated_entries,
1439                removed_entries: worktree.removed_entries,
1440                scan_id: worktree.scan_id,
1441                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1442                updated_repositories: worktree.updated_repositories,
1443                removed_repositories: worktree.removed_repositories,
1444            };
1445            for update in proto::split_worktree_update(message) {
1446                session.peer.send(session.connection_id, update.clone())?;
1447            }
1448
1449            // Stream this worktree's diagnostics.
1450            for summary in worktree.diagnostic_summaries {
1451                session.peer.send(
1452                    session.connection_id,
1453                    proto::UpdateDiagnosticSummary {
1454                        project_id: project.id.to_proto(),
1455                        worktree_id: worktree.id,
1456                        summary: Some(summary),
1457                    },
1458                )?;
1459            }
1460
1461            for settings_file in worktree.settings_files {
1462                session.peer.send(
1463                    session.connection_id,
1464                    proto::UpdateWorktreeSettings {
1465                        project_id: project.id.to_proto(),
1466                        worktree_id: worktree.id,
1467                        path: settings_file.path,
1468                        content: Some(settings_file.content),
1469                        kind: Some(settings_file.kind.to_proto().into()),
1470                    },
1471                )?;
1472            }
1473        }
1474
1475        for language_server in &project.language_servers {
1476            session.peer.send(
1477                session.connection_id,
1478                proto::UpdateLanguageServer {
1479                    project_id: project.id.to_proto(),
1480                    language_server_id: language_server.id,
1481                    variant: Some(
1482                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1483                            proto::LspDiskBasedDiagnosticsUpdated {},
1484                        ),
1485                    ),
1486                },
1487            )?;
1488        }
1489    }
1490    Ok(())
1491}
1492
1493/// leave room disconnects from the room.
1494async fn leave_room(
1495    _: proto::LeaveRoom,
1496    response: Response<proto::LeaveRoom>,
1497    session: Session,
1498) -> Result<()> {
1499    leave_room_for_session(&session, session.connection_id).await?;
1500    response.send(proto::Ack {})?;
1501    Ok(())
1502}
1503
1504/// Updates the permissions of someone else in the room.
1505async fn set_room_participant_role(
1506    request: proto::SetRoomParticipantRole,
1507    response: Response<proto::SetRoomParticipantRole>,
1508    session: Session,
1509) -> Result<()> {
1510    let user_id = UserId::from_proto(request.user_id);
1511    let role = ChannelRole::from(request.role());
1512
1513    let (livekit_room, can_publish) = {
1514        let room = session
1515            .db()
1516            .await
1517            .set_room_participant_role(
1518                session.user_id(),
1519                RoomId::from_proto(request.room_id),
1520                user_id,
1521                role,
1522            )
1523            .await?;
1524
1525        let livekit_room = room.livekit_room.clone();
1526        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1527        room_updated(&room, &session.peer);
1528        (livekit_room, can_publish)
1529    };
1530
1531    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1532        live_kit
1533            .update_participant(
1534                livekit_room.clone(),
1535                request.user_id.to_string(),
1536                livekit_server::proto::ParticipantPermission {
1537                    can_subscribe: true,
1538                    can_publish,
1539                    can_publish_data: can_publish,
1540                    hidden: false,
1541                    recorder: false,
1542                },
1543            )
1544            .await
1545            .trace_err();
1546    }
1547
1548    response.send(proto::Ack {})?;
1549    Ok(())
1550}
1551
/// Call someone else into the current room.
///
/// The caller and the called user must be contacts, otherwise an error is
/// returned. The call is recorded in the database (with the optional
/// `initial_project_id`), the room is broadcast, and the called user's contacts
/// are refreshed. An incoming-call request is then sent to every connection of
/// the called user. The first one that answers with `Ok` causes this request to
/// be acknowledged.
///
/// If no connection answers successfully:
/// - the call is marked as failed in the database;
/// - the room is broadcast again;
/// - the called user's contacts are refreshed again;
/// - the error "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `db.call` is dropped before the contacts
    // update and the outgoing ring requests below.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the incoming-call message out, leaving a default value behind.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Acknowledge on the first successful answer; log each failed one.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: record the failure and broadcast the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1620
1621/// Cancel an outgoing call.
1622async fn cancel_call(
1623    request: proto::CancelCall,
1624    response: Response<proto::CancelCall>,
1625    session: Session,
1626) -> Result<()> {
1627    let called_user_id = UserId::from_proto(request.called_user_id);
1628    let room_id = RoomId::from_proto(request.room_id);
1629    {
1630        let room = session
1631            .db()
1632            .await
1633            .cancel_call(room_id, session.connection_id, called_user_id)
1634            .await?;
1635        room_updated(&room, &session.peer);
1636    }
1637
1638    for connection_id in session
1639        .connection_pool()
1640        .await
1641        .user_connection_ids(called_user_id)
1642    {
1643        session
1644            .peer
1645            .send(
1646                connection_id,
1647                proto::CallCanceled {
1648                    room_id: room_id.to_proto(),
1649                },
1650            )
1651            .trace_err();
1652    }
1653    response.send(proto::Ack {})?;
1654
1655    update_user_contacts(called_user_id, &session).await?;
1656    Ok(())
1657}
1658
1659/// Decline an incoming call.
1660async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1661    let room_id = RoomId::from_proto(message.room_id);
1662    {
1663        let room = session
1664            .db()
1665            .await
1666            .decline_call(Some(room_id), session.user_id())
1667            .await?
1668            .ok_or_else(|| anyhow!("failed to decline call"))?;
1669        room_updated(&room, &session.peer);
1670    }
1671
1672    for connection_id in session
1673        .connection_pool()
1674        .await
1675        .user_connection_ids(session.user_id())
1676    {
1677        session
1678            .peer
1679            .send(
1680                connection_id,
1681                proto::CallCanceled {
1682                    room_id: room_id.to_proto(),
1683                },
1684            )
1685            .trace_err();
1686    }
1687    update_user_contacts(session.user_id(), &session).await?;
1688    Ok(())
1689}
1690
1691/// Updates other participants in the room with your current location.
1692async fn update_participant_location(
1693    request: proto::UpdateParticipantLocation,
1694    response: Response<proto::UpdateParticipantLocation>,
1695    session: Session,
1696) -> Result<()> {
1697    let room_id = RoomId::from_proto(request.room_id);
1698    let location = request
1699        .location
1700        .ok_or_else(|| anyhow!("invalid location"))?;
1701
1702    let db = session.db().await;
1703    let room = db
1704        .update_room_participant_location(room_id, session.connection_id, location)
1705        .await?;
1706
1707    room_updated(&room, &session.peer);
1708    response.send(proto::Ack {})?;
1709    Ok(())
1710}
1711
1712/// Share a project into the room.
1713async fn share_project(
1714    request: proto::ShareProject,
1715    response: Response<proto::ShareProject>,
1716    session: Session,
1717) -> Result<()> {
1718    let (project_id, room) = &*session
1719        .db()
1720        .await
1721        .share_project(
1722            RoomId::from_proto(request.room_id),
1723            session.connection_id,
1724            &request.worktrees,
1725            request.is_ssh_project,
1726        )
1727        .await?;
1728    response.send(proto::ShareProjectResponse {
1729        project_id: project_id.to_proto(),
1730    })?;
1731    room_updated(room, &session.peer);
1732
1733    Ok(())
1734}
1735
1736/// Unshare a project from the room.
1737async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1738    let project_id = ProjectId::from_proto(message.project_id);
1739    unshare_project_internal(project_id, session.connection_id, &session).await
1740}
1741
/// Unshares `project_id` on behalf of `connection_id`.
///
/// `UnshareProject` is sent to the guest connections returned by the database,
/// excluding `connection_id`. If the project belonged to a room, the room is
/// broadcast. When the database's unshare result says so, the project is then
/// deleted.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so the value returned by `unshare_project` is dropped before the
    // database is accessed again for `delete_project` below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so the guard can be dropped at the end of this block.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1779
/// Join someone else's shared project.
///
/// Registers the caller's connection and user with the project in the database,
/// then passes the returned project and replica id to `join_project_internal`.
/// That function sends the reply and streams the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before sending. The joined project data stays
    // alive through the value returned by `join_project`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1798
/// A responder that can deliver a `JoinProjectResponse`.
///
/// `join_project_internal` accepts any implementor. Only the
/// `Response<proto::JoinProject>` impl is visible here.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a regular `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path call to `Response`'s inherent `send`, not this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1807
1808fn join_project_internal(
1809    response: impl JoinProjectInternalResponse,
1810    session: Session,
1811    project: &mut Project,
1812    replica_id: &ReplicaId,
1813) -> Result<()> {
1814    let collaborators = project
1815        .collaborators
1816        .iter()
1817        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1818        .map(|collaborator| collaborator.to_proto())
1819        .collect::<Vec<_>>();
1820    let project_id = project.id;
1821    let guest_user_id = session.user_id();
1822
1823    let worktrees = project
1824        .worktrees
1825        .iter()
1826        .map(|(id, worktree)| proto::WorktreeMetadata {
1827            id: *id,
1828            root_name: worktree.root_name.clone(),
1829            visible: worktree.visible,
1830            abs_path: worktree.abs_path.clone(),
1831        })
1832        .collect::<Vec<_>>();
1833
1834    let add_project_collaborator = proto::AddProjectCollaborator {
1835        project_id: project_id.to_proto(),
1836        collaborator: Some(proto::Collaborator {
1837            peer_id: Some(session.connection_id.into()),
1838            replica_id: replica_id.0 as u32,
1839            user_id: guest_user_id.to_proto(),
1840            is_host: false,
1841        }),
1842    };
1843
1844    for collaborator in &collaborators {
1845        session
1846            .peer
1847            .send(
1848                collaborator.peer_id.unwrap().into(),
1849                add_project_collaborator.clone(),
1850            )
1851            .trace_err();
1852    }
1853
1854    // First, we send the metadata associated with each worktree.
1855    response.send(proto::JoinProjectResponse {
1856        project_id: project.id.0 as u64,
1857        worktrees: worktrees.clone(),
1858        replica_id: replica_id.0 as u32,
1859        collaborators: collaborators.clone(),
1860        language_servers: project.language_servers.clone(),
1861        role: project.role.into(),
1862    })?;
1863
1864    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1865        // Stream this worktree's entries.
1866        let message = proto::UpdateWorktree {
1867            project_id: project_id.to_proto(),
1868            worktree_id,
1869            abs_path: worktree.abs_path.clone(),
1870            root_name: worktree.root_name,
1871            updated_entries: worktree.entries,
1872            removed_entries: Default::default(),
1873            scan_id: worktree.scan_id,
1874            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1875            updated_repositories: worktree.repository_entries.into_values().collect(),
1876            removed_repositories: Default::default(),
1877        };
1878        for update in proto::split_worktree_update(message) {
1879            session.peer.send(session.connection_id, update.clone())?;
1880        }
1881
1882        // Stream this worktree's diagnostics.
1883        for summary in worktree.diagnostic_summaries {
1884            session.peer.send(
1885                session.connection_id,
1886                proto::UpdateDiagnosticSummary {
1887                    project_id: project_id.to_proto(),
1888                    worktree_id: worktree.id,
1889                    summary: Some(summary),
1890                },
1891            )?;
1892        }
1893
1894        for settings_file in worktree.settings_files {
1895            session.peer.send(
1896                session.connection_id,
1897                proto::UpdateWorktreeSettings {
1898                    project_id: project_id.to_proto(),
1899                    worktree_id: worktree.id,
1900                    path: settings_file.path,
1901                    content: Some(settings_file.content),
1902                    kind: Some(settings_file.kind.to_proto() as i32),
1903                },
1904            )?;
1905        }
1906    }
1907
1908    for language_server in &project.language_servers {
1909        session.peer.send(
1910            session.connection_id,
1911            proto::UpdateLanguageServer {
1912                project_id: project_id.to_proto(),
1913                language_server_id: language_server.id,
1914                variant: Some(
1915                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1916                        proto::LspDiskBasedDiagnosticsUpdated {},
1917                    ),
1918                ),
1919            },
1920        )?;
1921    }
1922
1923    Ok(())
1924}
1925
1926/// Leave someone elses shared project.
1927async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1928    let sender_id = session.connection_id;
1929    let project_id = ProjectId::from_proto(request.project_id);
1930    let db = session.db().await;
1931
1932    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1933    tracing::info!(
1934        %project_id,
1935        "leave project"
1936    );
1937
1938    project_left(project, &session);
1939    if let Some(room) = room {
1940        room_updated(room, &session.peer);
1941    }
1942
1943    Ok(())
1944}
1945
1946/// Updates other participants with changes to the project
1947async fn update_project(
1948    request: proto::UpdateProject,
1949    response: Response<proto::UpdateProject>,
1950    session: Session,
1951) -> Result<()> {
1952    let project_id = ProjectId::from_proto(request.project_id);
1953    let (room, guest_connection_ids) = &*session
1954        .db()
1955        .await
1956        .update_project(project_id, session.connection_id, &request.worktrees)
1957        .await?;
1958    broadcast(
1959        Some(session.connection_id),
1960        guest_connection_ids.iter().copied(),
1961        |connection_id| {
1962            session
1963                .peer
1964                .forward_send(session.connection_id, connection_id, request.clone())
1965        },
1966    );
1967    if let Some(room) = room {
1968        room_updated(room, &session.peer);
1969    }
1970    response.send(proto::Ack {})?;
1971
1972    Ok(())
1973}
1974
1975/// Updates other participants with changes to the worktree
1976async fn update_worktree(
1977    request: proto::UpdateWorktree,
1978    response: Response<proto::UpdateWorktree>,
1979    session: Session,
1980) -> Result<()> {
1981    let guest_connection_ids = session
1982        .db()
1983        .await
1984        .update_worktree(&request, session.connection_id)
1985        .await?;
1986
1987    broadcast(
1988        Some(session.connection_id),
1989        guest_connection_ids.iter().copied(),
1990        |connection_id| {
1991            session
1992                .peer
1993                .forward_send(session.connection_id, connection_id, request.clone())
1994        },
1995    );
1996    response.send(proto::Ack {})?;
1997    Ok(())
1998}
1999
2000/// Updates other participants with changes to the diagnostics
2001async fn update_diagnostic_summary(
2002    message: proto::UpdateDiagnosticSummary,
2003    session: Session,
2004) -> Result<()> {
2005    let guest_connection_ids = session
2006        .db()
2007        .await
2008        .update_diagnostic_summary(&message, session.connection_id)
2009        .await?;
2010
2011    broadcast(
2012        Some(session.connection_id),
2013        guest_connection_ids.iter().copied(),
2014        |connection_id| {
2015            session
2016                .peer
2017                .forward_send(session.connection_id, connection_id, message.clone())
2018        },
2019    );
2020
2021    Ok(())
2022}
2023
2024/// Updates other participants with changes to the worktree settings
2025async fn update_worktree_settings(
2026    message: proto::UpdateWorktreeSettings,
2027    session: Session,
2028) -> Result<()> {
2029    let guest_connection_ids = session
2030        .db()
2031        .await
2032        .update_worktree_settings(&message, session.connection_id)
2033        .await?;
2034
2035    broadcast(
2036        Some(session.connection_id),
2037        guest_connection_ids.iter().copied(),
2038        |connection_id| {
2039            session
2040                .peer
2041                .forward_send(session.connection_id, connection_id, message.clone())
2042        },
2043    );
2044
2045    Ok(())
2046}
2047
2048/// Notify other participants that a  language server has started.
2049async fn start_language_server(
2050    request: proto::StartLanguageServer,
2051    session: Session,
2052) -> Result<()> {
2053    let guest_connection_ids = session
2054        .db()
2055        .await
2056        .start_language_server(&request, session.connection_id)
2057        .await?;
2058
2059    broadcast(
2060        Some(session.connection_id),
2061        guest_connection_ids.iter().copied(),
2062        |connection_id| {
2063            session
2064                .peer
2065                .forward_send(session.connection_id, connection_id, request.clone())
2066        },
2067    );
2068    Ok(())
2069}
2070
2071/// Notify other participants that a language server has changed.
2072async fn update_language_server(
2073    request: proto::UpdateLanguageServer,
2074    session: Session,
2075) -> Result<()> {
2076    let project_id = ProjectId::from_proto(request.project_id);
2077    let project_connection_ids = session
2078        .db()
2079        .await
2080        .project_connection_ids(project_id, session.connection_id, true)
2081        .await?;
2082    broadcast(
2083        Some(session.connection_id),
2084        project_connection_ids.iter().copied(),
2085        |connection_id| {
2086            session
2087                .peer
2088                .forward_send(session.connection_id, connection_id, request.clone())
2089        },
2090    );
2091    Ok(())
2092}
2093
2094/// forward a project request to the host. These requests should be read only
2095/// as guests are allowed to send them.
2096async fn forward_read_only_project_request<T>(
2097    request: T,
2098    response: Response<T>,
2099    session: Session,
2100) -> Result<()>
2101where
2102    T: EntityMessage + RequestMessage,
2103{
2104    let project_id = ProjectId::from_proto(request.remote_entity_id());
2105    let host_connection_id = session
2106        .db()
2107        .await
2108        .host_for_read_only_project_request(project_id, session.connection_id)
2109        .await?;
2110    let payload = session
2111        .peer
2112        .forward_request(session.connection_id, host_connection_id, request)
2113        .await?;
2114    response.send(payload)?;
2115    Ok(())
2116}
2117
2118async fn forward_find_search_candidates_request(
2119    request: proto::FindSearchCandidates,
2120    response: Response<proto::FindSearchCandidates>,
2121    session: Session,
2122) -> Result<()> {
2123    let project_id = ProjectId::from_proto(request.remote_entity_id());
2124    let host_connection_id = session
2125        .db()
2126        .await
2127        .host_for_read_only_project_request(project_id, session.connection_id)
2128        .await?;
2129    let payload = session
2130        .peer
2131        .forward_request(session.connection_id, host_connection_id, request)
2132        .await?;
2133    response.send(payload)?;
2134    Ok(())
2135}
2136
2137/// forward a project request to the host. These requests are disallowed
2138/// for guests.
2139async fn forward_mutating_project_request<T>(
2140    request: T,
2141    response: Response<T>,
2142    session: Session,
2143) -> Result<()>
2144where
2145    T: EntityMessage + RequestMessage,
2146{
2147    let project_id = ProjectId::from_proto(request.remote_entity_id());
2148
2149    let host_connection_id = session
2150        .db()
2151        .await
2152        .host_for_mutating_project_request(project_id, session.connection_id)
2153        .await?;
2154    let payload = session
2155        .peer
2156        .forward_request(session.connection_id, host_connection_id, request)
2157        .await?;
2158    response.send(payload)?;
2159    Ok(())
2160}
2161
2162/// Notify other participants that a new buffer has been created
2163async fn create_buffer_for_peer(
2164    request: proto::CreateBufferForPeer,
2165    session: Session,
2166) -> Result<()> {
2167    session
2168        .db()
2169        .await
2170        .check_user_is_project_host(
2171            ProjectId::from_proto(request.project_id),
2172            session.connection_id,
2173        )
2174        .await?;
2175    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2176    session
2177        .peer
2178        .forward_send(session.connection_id, peer_id.into(), request)?;
2179    Ok(())
2180}
2181
2182/// Notify other participants that a buffer has been updated. This is
2183/// allowed for guests as long as the update is limited to selections.
2184async fn update_buffer(
2185    request: proto::UpdateBuffer,
2186    response: Response<proto::UpdateBuffer>,
2187    session: Session,
2188) -> Result<()> {
2189    let project_id = ProjectId::from_proto(request.project_id);
2190    let mut capability = Capability::ReadOnly;
2191
2192    for op in request.operations.iter() {
2193        match op.variant {
2194            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2195            Some(_) => capability = Capability::ReadWrite,
2196        }
2197    }
2198
2199    let host = {
2200        let guard = session
2201            .db()
2202            .await
2203            .connections_for_buffer_update(project_id, session.connection_id, capability)
2204            .await?;
2205
2206        let (host, guests) = &*guard;
2207
2208        broadcast(
2209            Some(session.connection_id),
2210            guests.clone(),
2211            |connection_id| {
2212                session
2213                    .peer
2214                    .forward_send(session.connection_id, connection_id, request.clone())
2215            },
2216        );
2217
2218        *host
2219    };
2220
2221    if host != session.connection_id {
2222        session
2223            .peer
2224            .forward_request(session.connection_id, host, request.clone())
2225            .await?;
2226    }
2227
2228    response.send(proto::Ack {})?;
2229    Ok(())
2230}
2231
2232async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2233    let project_id = ProjectId::from_proto(message.project_id);
2234
2235    let operation = message.operation.as_ref().context("invalid operation")?;
2236    let capability = match operation.variant.as_ref() {
2237        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2238            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2239                match buffer_op.variant {
2240                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2241                        Capability::ReadOnly
2242                    }
2243                    _ => Capability::ReadWrite,
2244                }
2245            } else {
2246                Capability::ReadWrite
2247            }
2248        }
2249        Some(_) => Capability::ReadWrite,
2250        None => Capability::ReadOnly,
2251    };
2252
2253    let guard = session
2254        .db()
2255        .await
2256        .connections_for_buffer_update(project_id, session.connection_id, capability)
2257        .await?;
2258
2259    let (host, guests) = &*guard;
2260
2261    broadcast(
2262        Some(session.connection_id),
2263        guests.iter().chain([host]).copied(),
2264        |connection_id| {
2265            session
2266                .peer
2267                .forward_send(session.connection_id, connection_id, message.clone())
2268        },
2269    );
2270
2271    Ok(())
2272}
2273
2274/// Notify other participants that a project has been updated.
2275async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2276    request: T,
2277    session: Session,
2278) -> Result<()> {
2279    let project_id = ProjectId::from_proto(request.remote_entity_id());
2280    let project_connection_ids = session
2281        .db()
2282        .await
2283        .project_connection_ids(project_id, session.connection_id, false)
2284        .await?;
2285
2286    broadcast(
2287        Some(session.connection_id),
2288        project_connection_ids.iter().copied(),
2289        |connection_id| {
2290            session
2291                .peer
2292                .forward_send(session.connection_id, connection_id, request.clone())
2293        },
2294    );
2295    Ok(())
2296}
2297
2298/// Start following another user in a call.
2299async fn follow(
2300    request: proto::Follow,
2301    response: Response<proto::Follow>,
2302    session: Session,
2303) -> Result<()> {
2304    let room_id = RoomId::from_proto(request.room_id);
2305    let project_id = request.project_id.map(ProjectId::from_proto);
2306    let leader_id = request
2307        .leader_id
2308        .ok_or_else(|| anyhow!("invalid leader id"))?
2309        .into();
2310    let follower_id = session.connection_id;
2311
2312    session
2313        .db()
2314        .await
2315        .check_room_participants(room_id, leader_id, session.connection_id)
2316        .await?;
2317
2318    let response_payload = session
2319        .peer
2320        .forward_request(session.connection_id, leader_id, request)
2321        .await?;
2322    response.send(response_payload)?;
2323
2324    if let Some(project_id) = project_id {
2325        let room = session
2326            .db()
2327            .await
2328            .follow(room_id, project_id, leader_id, follower_id)
2329            .await?;
2330        room_updated(&room, &session.peer);
2331    }
2332
2333    Ok(())
2334}
2335
2336/// Stop following another user in a call.
2337async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2338    let room_id = RoomId::from_proto(request.room_id);
2339    let project_id = request.project_id.map(ProjectId::from_proto);
2340    let leader_id = request
2341        .leader_id
2342        .ok_or_else(|| anyhow!("invalid leader id"))?
2343        .into();
2344    let follower_id = session.connection_id;
2345
2346    session
2347        .db()
2348        .await
2349        .check_room_participants(room_id, leader_id, session.connection_id)
2350        .await?;
2351
2352    session
2353        .peer
2354        .forward_send(session.connection_id, leader_id, request)?;
2355
2356    if let Some(project_id) = project_id {
2357        let room = session
2358            .db()
2359            .await
2360            .unfollow(room_id, project_id, leader_id, follower_id)
2361            .await?;
2362        room_updated(&room, &session.peer);
2363    }
2364
2365    Ok(())
2366}
2367
2368/// Notify everyone following you of your current location.
2369async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2370    let room_id = RoomId::from_proto(request.room_id);
2371    let database = session.db.lock().await;
2372
2373    let connection_ids = if let Some(project_id) = request.project_id {
2374        let project_id = ProjectId::from_proto(project_id);
2375        database
2376            .project_connection_ids(project_id, session.connection_id, true)
2377            .await?
2378    } else {
2379        database
2380            .room_connection_ids(room_id, session.connection_id)
2381            .await?
2382    };
2383
2384    // For now, don't send view update messages back to that view's current leader.
2385    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2386        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2387        _ => None,
2388    });
2389
2390    for connection_id in connection_ids.iter().cloned() {
2391        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2392            session
2393                .peer
2394                .forward_send(session.connection_id, connection_id, request.clone())?;
2395        }
2396    }
2397    Ok(())
2398}
2399
2400/// Get public data about users.
2401async fn get_users(
2402    request: proto::GetUsers,
2403    response: Response<proto::GetUsers>,
2404    session: Session,
2405) -> Result<()> {
2406    let user_ids = request
2407        .user_ids
2408        .into_iter()
2409        .map(UserId::from_proto)
2410        .collect();
2411    let users = session
2412        .db()
2413        .await
2414        .get_users_by_ids(user_ids)
2415        .await?
2416        .into_iter()
2417        .map(|user| proto::User {
2418            id: user.id.to_proto(),
2419            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2420            github_login: user.github_login,
2421        })
2422        .collect();
2423    response.send(proto::UsersResponse { users })?;
2424    Ok(())
2425}
2426
2427/// Search for users (to invite) buy Github login
2428async fn fuzzy_search_users(
2429    request: proto::FuzzySearchUsers,
2430    response: Response<proto::FuzzySearchUsers>,
2431    session: Session,
2432) -> Result<()> {
2433    let query = request.query;
2434    let users = match query.len() {
2435        0 => vec![],
2436        1 | 2 => session
2437            .db()
2438            .await
2439            .get_user_by_github_login(&query)
2440            .await?
2441            .into_iter()
2442            .collect(),
2443        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2444    };
2445    let users = users
2446        .into_iter()
2447        .filter(|user| user.id != session.user_id())
2448        .map(|user| proto::User {
2449            id: user.id.to_proto(),
2450            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2451            github_login: user.github_login,
2452        })
2453        .collect();
2454    response.send(proto::UsersResponse { users })?;
2455    Ok(())
2456}
2457
2458/// Send a contact request to another user.
2459async fn request_contact(
2460    request: proto::RequestContact,
2461    response: Response<proto::RequestContact>,
2462    session: Session,
2463) -> Result<()> {
2464    let requester_id = session.user_id();
2465    let responder_id = UserId::from_proto(request.responder_id);
2466    if requester_id == responder_id {
2467        return Err(anyhow!("cannot add yourself as a contact"))?;
2468    }
2469
2470    let notifications = session
2471        .db()
2472        .await
2473        .send_contact_request(requester_id, responder_id)
2474        .await?;
2475
2476    // Update outgoing contact requests of requester
2477    let mut update = proto::UpdateContacts::default();
2478    update.outgoing_requests.push(responder_id.to_proto());
2479    for connection_id in session
2480        .connection_pool()
2481        .await
2482        .user_connection_ids(requester_id)
2483    {
2484        session.peer.send(connection_id, update.clone())?;
2485    }
2486
2487    // Update incoming contact requests of responder
2488    let mut update = proto::UpdateContacts::default();
2489    update
2490        .incoming_requests
2491        .push(proto::IncomingContactRequest {
2492            requester_id: requester_id.to_proto(),
2493        });
2494    let connection_pool = session.connection_pool().await;
2495    for connection_id in connection_pool.user_connection_ids(responder_id) {
2496        session.peer.send(connection_id, update.clone())?;
2497    }
2498
2499    send_notifications(&connection_pool, &session.peer, notifications);
2500
2501    response.send(proto::Ack {})?;
2502    Ok(())
2503}
2504
2505/// Accept or decline a contact request
2506async fn respond_to_contact_request(
2507    request: proto::RespondToContactRequest,
2508    response: Response<proto::RespondToContactRequest>,
2509    session: Session,
2510) -> Result<()> {
2511    let responder_id = session.user_id();
2512    let requester_id = UserId::from_proto(request.requester_id);
2513    let db = session.db().await;
2514    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2515        db.dismiss_contact_notification(responder_id, requester_id)
2516            .await?;
2517    } else {
2518        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2519
2520        let notifications = db
2521            .respond_to_contact_request(responder_id, requester_id, accept)
2522            .await?;
2523        let requester_busy = db.is_user_busy(requester_id).await?;
2524        let responder_busy = db.is_user_busy(responder_id).await?;
2525
2526        let pool = session.connection_pool().await;
2527        // Update responder with new contact
2528        let mut update = proto::UpdateContacts::default();
2529        if accept {
2530            update
2531                .contacts
2532                .push(contact_for_user(requester_id, requester_busy, &pool));
2533        }
2534        update
2535            .remove_incoming_requests
2536            .push(requester_id.to_proto());
2537        for connection_id in pool.user_connection_ids(responder_id) {
2538            session.peer.send(connection_id, update.clone())?;
2539        }
2540
2541        // Update requester with new contact
2542        let mut update = proto::UpdateContacts::default();
2543        if accept {
2544            update
2545                .contacts
2546                .push(contact_for_user(responder_id, responder_busy, &pool));
2547        }
2548        update
2549            .remove_outgoing_requests
2550            .push(responder_id.to_proto());
2551
2552        for connection_id in pool.user_connection_ids(requester_id) {
2553            session.peer.send(connection_id, update.clone())?;
2554        }
2555
2556        send_notifications(&pool, &session.peer, notifications);
2557    }
2558
2559    response.send(proto::Ack {})?;
2560    Ok(())
2561}
2562
2563/// Remove a contact.
2564async fn remove_contact(
2565    request: proto::RemoveContact,
2566    response: Response<proto::RemoveContact>,
2567    session: Session,
2568) -> Result<()> {
2569    let requester_id = session.user_id();
2570    let responder_id = UserId::from_proto(request.user_id);
2571    let db = session.db().await;
2572    let (contact_accepted, deleted_notification_id) =
2573        db.remove_contact(requester_id, responder_id).await?;
2574
2575    let pool = session.connection_pool().await;
2576    // Update outgoing contact requests of requester
2577    let mut update = proto::UpdateContacts::default();
2578    if contact_accepted {
2579        update.remove_contacts.push(responder_id.to_proto());
2580    } else {
2581        update
2582            .remove_outgoing_requests
2583            .push(responder_id.to_proto());
2584    }
2585    for connection_id in pool.user_connection_ids(requester_id) {
2586        session.peer.send(connection_id, update.clone())?;
2587    }
2588
2589    // Update incoming contact requests of responder
2590    let mut update = proto::UpdateContacts::default();
2591    if contact_accepted {
2592        update.remove_contacts.push(requester_id.to_proto());
2593    } else {
2594        update
2595            .remove_incoming_requests
2596            .push(requester_id.to_proto());
2597    }
2598    for connection_id in pool.user_connection_ids(responder_id) {
2599        session.peer.send(connection_id, update.clone())?;
2600        if let Some(notification_id) = deleted_notification_id {
2601            session.peer.send(
2602                connection_id,
2603                proto::DeleteNotification {
2604                    notification_id: notification_id.to_proto(),
2605                },
2606            )?;
2607        }
2608    }
2609
2610    response.send(proto::Ack {})?;
2611    Ok(())
2612}
2613
2614fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2615    version.0.minor() < 139
2616}
2617
2618async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2619    let plan = session.current_plan(&session.db().await).await?;
2620
2621    session
2622        .peer
2623        .send(
2624            session.connection_id,
2625            proto::UpdateUserPlan { plan: plan.into() },
2626        )
2627        .trace_err();
2628
2629    Ok(())
2630}
2631
2632async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2633    subscribe_user_to_channels(session.user_id(), &session).await?;
2634    Ok(())
2635}
2636
2637async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2638    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2639    let mut pool = session.connection_pool().await;
2640    for membership in &channels_for_user.channel_memberships {
2641        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2642    }
2643    session.peer.send(
2644        session.connection_id,
2645        build_update_user_channels(&channels_for_user),
2646    )?;
2647    session.peer.send(
2648        session.connection_id,
2649        build_channels_update(channels_for_user),
2650    )?;
2651    Ok(())
2652}
2653
2654/// Creates a new channel.
2655async fn create_channel(
2656    request: proto::CreateChannel,
2657    response: Response<proto::CreateChannel>,
2658    session: Session,
2659) -> Result<()> {
2660    let db = session.db().await;
2661
2662    let parent_id = request.parent_id.map(ChannelId::from_proto);
2663    let (channel, membership) = db
2664        .create_channel(&request.name, parent_id, session.user_id())
2665        .await?;
2666
2667    let root_id = channel.root_id();
2668    let channel = Channel::from_model(channel);
2669
2670    response.send(proto::CreateChannelResponse {
2671        channel: Some(channel.to_proto()),
2672        parent_id: request.parent_id,
2673    })?;
2674
2675    let mut connection_pool = session.connection_pool().await;
2676    if let Some(membership) = membership {
2677        connection_pool.subscribe_to_channel(
2678            membership.user_id,
2679            membership.channel_id,
2680            membership.role,
2681        );
2682        let update = proto::UpdateUserChannels {
2683            channel_memberships: vec![proto::ChannelMembership {
2684                channel_id: membership.channel_id.to_proto(),
2685                role: membership.role.into(),
2686            }],
2687            ..Default::default()
2688        };
2689        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2690            session.peer.send(connection_id, update.clone())?;
2691        }
2692    }
2693
2694    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2695        if !role.can_see_channel(channel.visibility) {
2696            continue;
2697        }
2698
2699        let update = proto::UpdateChannels {
2700            channels: vec![channel.to_proto()],
2701            ..Default::default()
2702        };
2703        session.peer.send(connection_id, update.clone())?;
2704    }
2705
2706    Ok(())
2707}
2708
2709/// Delete a channel
2710async fn delete_channel(
2711    request: proto::DeleteChannel,
2712    response: Response<proto::DeleteChannel>,
2713    session: Session,
2714) -> Result<()> {
2715    let db = session.db().await;
2716
2717    let channel_id = request.channel_id;
2718    let (root_channel, removed_channels) = db
2719        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2720        .await?;
2721    response.send(proto::Ack {})?;
2722
2723    // Notify members of removed channels
2724    let mut update = proto::UpdateChannels::default();
2725    update
2726        .delete_channels
2727        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2728
2729    let connection_pool = session.connection_pool().await;
2730    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2731        session.peer.send(connection_id, update.clone())?;
2732    }
2733
2734    Ok(())
2735}
2736
2737/// Invite someone to join a channel.
2738async fn invite_channel_member(
2739    request: proto::InviteChannelMember,
2740    response: Response<proto::InviteChannelMember>,
2741    session: Session,
2742) -> Result<()> {
2743    let db = session.db().await;
2744    let channel_id = ChannelId::from_proto(request.channel_id);
2745    let invitee_id = UserId::from_proto(request.user_id);
2746    let InviteMemberResult {
2747        channel,
2748        notifications,
2749    } = db
2750        .invite_channel_member(
2751            channel_id,
2752            invitee_id,
2753            session.user_id(),
2754            request.role().into(),
2755        )
2756        .await?;
2757
2758    let update = proto::UpdateChannels {
2759        channel_invitations: vec![channel.to_proto()],
2760        ..Default::default()
2761    };
2762
2763    let connection_pool = session.connection_pool().await;
2764    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2765        session.peer.send(connection_id, update.clone())?;
2766    }
2767
2768    send_notifications(&connection_pool, &session.peer, notifications);
2769
2770    response.send(proto::Ack {})?;
2771    Ok(())
2772}
2773
2774/// remove someone from a channel
2775async fn remove_channel_member(
2776    request: proto::RemoveChannelMember,
2777    response: Response<proto::RemoveChannelMember>,
2778    session: Session,
2779) -> Result<()> {
2780    let db = session.db().await;
2781    let channel_id = ChannelId::from_proto(request.channel_id);
2782    let member_id = UserId::from_proto(request.user_id);
2783
2784    let RemoveChannelMemberResult {
2785        membership_update,
2786        notification_id,
2787    } = db
2788        .remove_channel_member(channel_id, member_id, session.user_id())
2789        .await?;
2790
2791    let mut connection_pool = session.connection_pool().await;
2792    notify_membership_updated(
2793        &mut connection_pool,
2794        membership_update,
2795        member_id,
2796        &session.peer,
2797    );
2798    for connection_id in connection_pool.user_connection_ids(member_id) {
2799        if let Some(notification_id) = notification_id {
2800            session
2801                .peer
2802                .send(
2803                    connection_id,
2804                    proto::DeleteNotification {
2805                        notification_id: notification_id.to_proto(),
2806                    },
2807                )
2808                .trace_err();
2809        }
2810    }
2811
2812    response.send(proto::Ack {})?;
2813    Ok(())
2814}
2815
2816/// Toggle the channel between public and private.
2817/// Care is taken to maintain the invariant that public channels only descend from public channels,
2818/// (though members-only channels can appear at any point in the hierarchy).
2819async fn set_channel_visibility(
2820    request: proto::SetChannelVisibility,
2821    response: Response<proto::SetChannelVisibility>,
2822    session: Session,
2823) -> Result<()> {
2824    let db = session.db().await;
2825    let channel_id = ChannelId::from_proto(request.channel_id);
2826    let visibility = request.visibility().into();
2827
2828    let channel_model = db
2829        .set_channel_visibility(channel_id, visibility, session.user_id())
2830        .await?;
2831    let root_id = channel_model.root_id();
2832    let channel = Channel::from_model(channel_model);
2833
2834    let mut connection_pool = session.connection_pool().await;
2835    for (user_id, role) in connection_pool
2836        .channel_user_ids(root_id)
2837        .collect::<Vec<_>>()
2838        .into_iter()
2839    {
2840        let update = if role.can_see_channel(channel.visibility) {
2841            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2842            proto::UpdateChannels {
2843                channels: vec![channel.to_proto()],
2844                ..Default::default()
2845            }
2846        } else {
2847            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2848            proto::UpdateChannels {
2849                delete_channels: vec![channel.id.to_proto()],
2850                ..Default::default()
2851            }
2852        };
2853
2854        for connection_id in connection_pool.user_connection_ids(user_id) {
2855            session.peer.send(connection_id, update.clone())?;
2856        }
2857    }
2858
2859    response.send(proto::Ack {})?;
2860    Ok(())
2861}
2862
2863/// Alter the role for a user in the channel.
2864async fn set_channel_member_role(
2865    request: proto::SetChannelMemberRole,
2866    response: Response<proto::SetChannelMemberRole>,
2867    session: Session,
2868) -> Result<()> {
2869    let db = session.db().await;
2870    let channel_id = ChannelId::from_proto(request.channel_id);
2871    let member_id = UserId::from_proto(request.user_id);
2872    let result = db
2873        .set_channel_member_role(
2874            channel_id,
2875            session.user_id(),
2876            member_id,
2877            request.role().into(),
2878        )
2879        .await?;
2880
2881    match result {
2882        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2883            let mut connection_pool = session.connection_pool().await;
2884            notify_membership_updated(
2885                &mut connection_pool,
2886                membership_update,
2887                member_id,
2888                &session.peer,
2889            )
2890        }
2891        db::SetMemberRoleResult::InviteUpdated(channel) => {
2892            let update = proto::UpdateChannels {
2893                channel_invitations: vec![channel.to_proto()],
2894                ..Default::default()
2895            };
2896
2897            for connection_id in session
2898                .connection_pool()
2899                .await
2900                .user_connection_ids(member_id)
2901            {
2902                session.peer.send(connection_id, update.clone())?;
2903            }
2904        }
2905    }
2906
2907    response.send(proto::Ack {})?;
2908    Ok(())
2909}
2910
2911/// Change the name of a channel
2912async fn rename_channel(
2913    request: proto::RenameChannel,
2914    response: Response<proto::RenameChannel>,
2915    session: Session,
2916) -> Result<()> {
2917    let db = session.db().await;
2918    let channel_id = ChannelId::from_proto(request.channel_id);
2919    let channel_model = db
2920        .rename_channel(channel_id, session.user_id(), &request.name)
2921        .await?;
2922    let root_id = channel_model.root_id();
2923    let channel = Channel::from_model(channel_model);
2924
2925    response.send(proto::RenameChannelResponse {
2926        channel: Some(channel.to_proto()),
2927    })?;
2928
2929    let connection_pool = session.connection_pool().await;
2930    let update = proto::UpdateChannels {
2931        channels: vec![channel.to_proto()],
2932        ..Default::default()
2933    };
2934    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2935        if role.can_see_channel(channel.visibility) {
2936            session.peer.send(connection_id, update.clone())?;
2937        }
2938    }
2939
2940    Ok(())
2941}
2942
2943/// Move a channel to a new parent.
2944async fn move_channel(
2945    request: proto::MoveChannel,
2946    response: Response<proto::MoveChannel>,
2947    session: Session,
2948) -> Result<()> {
2949    let channel_id = ChannelId::from_proto(request.channel_id);
2950    let to = ChannelId::from_proto(request.to);
2951
2952    let (root_id, channels) = session
2953        .db()
2954        .await
2955        .move_channel(channel_id, to, session.user_id())
2956        .await?;
2957
2958    let connection_pool = session.connection_pool().await;
2959    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2960        let channels = channels
2961            .iter()
2962            .filter_map(|channel| {
2963                if role.can_see_channel(channel.visibility) {
2964                    Some(channel.to_proto())
2965                } else {
2966                    None
2967                }
2968            })
2969            .collect::<Vec<_>>();
2970        if channels.is_empty() {
2971            continue;
2972        }
2973
2974        let update = proto::UpdateChannels {
2975            channels,
2976            ..Default::default()
2977        };
2978
2979        session.peer.send(connection_id, update.clone())?;
2980    }
2981
2982    response.send(Ack {})?;
2983    Ok(())
2984}
2985
2986/// Get the list of channel members
2987async fn get_channel_members(
2988    request: proto::GetChannelMembers,
2989    response: Response<proto::GetChannelMembers>,
2990    session: Session,
2991) -> Result<()> {
2992    let db = session.db().await;
2993    let channel_id = ChannelId::from_proto(request.channel_id);
2994    let limit = if request.limit == 0 {
2995        u16::MAX as u64
2996    } else {
2997        request.limit
2998    };
2999    let (members, users) = db
3000        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3001        .await?;
3002    response.send(proto::GetChannelMembersResponse { members, users })?;
3003    Ok(())
3004}
3005
3006/// Accept or decline a channel invitation.
3007async fn respond_to_channel_invite(
3008    request: proto::RespondToChannelInvite,
3009    response: Response<proto::RespondToChannelInvite>,
3010    session: Session,
3011) -> Result<()> {
3012    let db = session.db().await;
3013    let channel_id = ChannelId::from_proto(request.channel_id);
3014    let RespondToChannelInvite {
3015        membership_update,
3016        notifications,
3017    } = db
3018        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3019        .await?;
3020
3021    let mut connection_pool = session.connection_pool().await;
3022    if let Some(membership_update) = membership_update {
3023        notify_membership_updated(
3024            &mut connection_pool,
3025            membership_update,
3026            session.user_id(),
3027            &session.peer,
3028        );
3029    } else {
3030        let update = proto::UpdateChannels {
3031            remove_channel_invitations: vec![channel_id.to_proto()],
3032            ..Default::default()
3033        };
3034
3035        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3036            session.peer.send(connection_id, update.clone())?;
3037        }
3038    };
3039
3040    send_notifications(&connection_pool, &session.peer, notifications);
3041
3042    response.send(proto::Ack {})?;
3043
3044    Ok(())
3045}
3046
3047/// Join the channels' room
3048async fn join_channel(
3049    request: proto::JoinChannel,
3050    response: Response<proto::JoinChannel>,
3051    session: Session,
3052) -> Result<()> {
3053    let channel_id = ChannelId::from_proto(request.channel_id);
3054    join_channel_internal(channel_id, Box::new(response), session).await
3055}
3056
3057trait JoinChannelInternalResponse {
3058    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3059}
/// Lets a `JoinChannel` request be completed by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so this resolves to the inherent `Response::send`
        // (inherent items take precedence over trait items) instead of
        // recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` request be completed by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so this resolves to the inherent `Response::send`
        // (inherent items take precedence over trait items) instead of
        // recursing into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3070
/// Shared implementation for joining a channel's room, usable with any
/// `JoinChannelInternalResponse`.
///
/// 1. If the database reports a stale room connection for this user, that room
///    is left first.
/// 2. The channel's room is joined in the database, yielding the user's role.
/// 3. If a LiveKit client is configured, a token is minted: a guest token with
///    `can_publish: false` for `ChannelRole::Guest`, otherwise a room token with
///    `can_publish: true`. Token errors go through `trace_err` and leave the
///    connection info as `None`.
/// 4. The client is answered, any membership change is propagated, and the
///    room update is broadcast.
/// 5. With the db and pool guards released, the channel update is broadcast and
///    the user's contacts are refreshed.
///
/// Returns an error if the joined room has no channel; at that point the
/// response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The db and connection-pool guards are scoped to this block so both are
    // released before the pool is locked again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room and re-acquire it
            // afterwards; presumably `leave_room_for_session` locks the db itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a token that cannot publish; every other role can publish.
        // `.trace_err()?` turns a token error into `None` for the whole closure.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3166
3167/// Start editing the channel notes
3168async fn join_channel_buffer(
3169    request: proto::JoinChannelBuffer,
3170    response: Response<proto::JoinChannelBuffer>,
3171    session: Session,
3172) -> Result<()> {
3173    let db = session.db().await;
3174    let channel_id = ChannelId::from_proto(request.channel_id);
3175
3176    let open_response = db
3177        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3178        .await?;
3179
3180    let collaborators = open_response.collaborators.clone();
3181    response.send(open_response)?;
3182
3183    let update = UpdateChannelBufferCollaborators {
3184        channel_id: channel_id.to_proto(),
3185        collaborators: collaborators.clone(),
3186    };
3187    channel_buffer_updated(
3188        session.connection_id,
3189        collaborators
3190            .iter()
3191            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3192        &update,
3193        &session.peer,
3194    );
3195
3196    Ok(())
3197}
3198
3199/// Edit the channel notes
3200async fn update_channel_buffer(
3201    request: proto::UpdateChannelBuffer,
3202    session: Session,
3203) -> Result<()> {
3204    let db = session.db().await;
3205    let channel_id = ChannelId::from_proto(request.channel_id);
3206
3207    let (collaborators, epoch, version) = db
3208        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3209        .await?;
3210
3211    channel_buffer_updated(
3212        session.connection_id,
3213        collaborators.clone(),
3214        &proto::UpdateChannelBuffer {
3215            channel_id: channel_id.to_proto(),
3216            operations: request.operations,
3217        },
3218        &session.peer,
3219    );
3220
3221    let pool = &*session.connection_pool().await;
3222
3223    let non_collaborators =
3224        pool.channel_connection_ids(channel_id)
3225            .filter_map(|(connection_id, _)| {
3226                if collaborators.contains(&connection_id) {
3227                    None
3228                } else {
3229                    Some(connection_id)
3230                }
3231            });
3232
3233    broadcast(None, non_collaborators, |peer_id| {
3234        session.peer.send(
3235            peer_id,
3236            proto::UpdateChannels {
3237                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3238                    channel_id: channel_id.to_proto(),
3239                    epoch: epoch as u64,
3240                    version: version.clone(),
3241                }],
3242                ..Default::default()
3243            },
3244        )
3245    });
3246
3247    Ok(())
3248}
3249
3250/// Rejoin the channel notes after a connection blip
3251async fn rejoin_channel_buffers(
3252    request: proto::RejoinChannelBuffers,
3253    response: Response<proto::RejoinChannelBuffers>,
3254    session: Session,
3255) -> Result<()> {
3256    let db = session.db().await;
3257    let buffers = db
3258        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3259        .await?;
3260
3261    for rejoined_buffer in &buffers {
3262        let collaborators_to_notify = rejoined_buffer
3263            .buffer
3264            .collaborators
3265            .iter()
3266            .filter_map(|c| Some(c.peer_id?.into()));
3267        channel_buffer_updated(
3268            session.connection_id,
3269            collaborators_to_notify,
3270            &proto::UpdateChannelBufferCollaborators {
3271                channel_id: rejoined_buffer.buffer.channel_id,
3272                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3273            },
3274            &session.peer,
3275        );
3276    }
3277
3278    response.send(proto::RejoinChannelBuffersResponse {
3279        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3280    })?;
3281
3282    Ok(())
3283}
3284
3285/// Stop editing the channel notes
3286async fn leave_channel_buffer(
3287    request: proto::LeaveChannelBuffer,
3288    response: Response<proto::LeaveChannelBuffer>,
3289    session: Session,
3290) -> Result<()> {
3291    let db = session.db().await;
3292    let channel_id = ChannelId::from_proto(request.channel_id);
3293
3294    let left_buffer = db
3295        .leave_channel_buffer(channel_id, session.connection_id)
3296        .await?;
3297
3298    response.send(Ack {})?;
3299
3300    channel_buffer_updated(
3301        session.connection_id,
3302        left_buffer.connections,
3303        &proto::UpdateChannelBufferCollaborators {
3304            channel_id: channel_id.to_proto(),
3305            collaborators: left_buffer.collaborators,
3306        },
3307        &session.peer,
3308    );
3309
3310    Ok(())
3311}
3312
3313fn channel_buffer_updated<T: EnvelopedMessage>(
3314    sender_id: ConnectionId,
3315    collaborators: impl IntoIterator<Item = ConnectionId>,
3316    message: &T,
3317    peer: &Peer,
3318) {
3319    broadcast(Some(sender_id), collaborators, |peer_id| {
3320        peer.send(peer_id, message.clone())
3321    });
3322}
3323
3324fn send_notifications(
3325    connection_pool: &ConnectionPool,
3326    peer: &Peer,
3327    notifications: db::NotificationBatch,
3328) {
3329    for (user_id, notification) in notifications {
3330        for connection_id in connection_pool.user_connection_ids(user_id) {
3331            if let Err(error) = peer.send(
3332                connection_id,
3333                proto::AddNotification {
3334                    notification: Some(notification.clone()),
3335                },
3336            ) {
3337                tracing::error!(
3338                    "failed to send notification to {:?} {}",
3339                    connection_id,
3340                    error
3341                );
3342            }
3343        }
3344    }
3345}
3346
3347/// Send a message to the channel
3348async fn send_channel_message(
3349    request: proto::SendChannelMessage,
3350    response: Response<proto::SendChannelMessage>,
3351    session: Session,
3352) -> Result<()> {
3353    // Validate the message body.
3354    let body = request.body.trim().to_string();
3355    if body.len() > MAX_MESSAGE_LEN {
3356        return Err(anyhow!("message is too long"))?;
3357    }
3358    if body.is_empty() {
3359        return Err(anyhow!("message can't be blank"))?;
3360    }
3361
3362    // TODO: adjust mentions if body is trimmed
3363
3364    let timestamp = OffsetDateTime::now_utc();
3365    let nonce = request
3366        .nonce
3367        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3368
3369    let channel_id = ChannelId::from_proto(request.channel_id);
3370    let CreatedChannelMessage {
3371        message_id,
3372        participant_connection_ids,
3373        notifications,
3374    } = session
3375        .db()
3376        .await
3377        .create_channel_message(
3378            channel_id,
3379            session.user_id(),
3380            &body,
3381            &request.mentions,
3382            timestamp,
3383            nonce.clone().into(),
3384            request.reply_to_message_id.map(MessageId::from_proto),
3385        )
3386        .await?;
3387
3388    let message = proto::ChannelMessage {
3389        sender_id: session.user_id().to_proto(),
3390        id: message_id.to_proto(),
3391        body,
3392        mentions: request.mentions,
3393        timestamp: timestamp.unix_timestamp() as u64,
3394        nonce: Some(nonce),
3395        reply_to_message_id: request.reply_to_message_id,
3396        edited_at: None,
3397    };
3398    broadcast(
3399        Some(session.connection_id),
3400        participant_connection_ids.clone(),
3401        |connection| {
3402            session.peer.send(
3403                connection,
3404                proto::ChannelMessageSent {
3405                    channel_id: channel_id.to_proto(),
3406                    message: Some(message.clone()),
3407                },
3408            )
3409        },
3410    );
3411    response.send(proto::SendChannelMessageResponse {
3412        message: Some(message),
3413    })?;
3414
3415    let pool = &*session.connection_pool().await;
3416    let non_participants =
3417        pool.channel_connection_ids(channel_id)
3418            .filter_map(|(connection_id, _)| {
3419                if participant_connection_ids.contains(&connection_id) {
3420                    None
3421                } else {
3422                    Some(connection_id)
3423                }
3424            });
3425    broadcast(None, non_participants, |peer_id| {
3426        session.peer.send(
3427            peer_id,
3428            proto::UpdateChannels {
3429                latest_channel_message_ids: vec![proto::ChannelMessageId {
3430                    channel_id: channel_id.to_proto(),
3431                    message_id: message_id.to_proto(),
3432                }],
3433                ..Default::default()
3434            },
3435        )
3436    });
3437    send_notifications(pool, &session.peer, notifications);
3438
3439    Ok(())
3440}
3441
3442/// Delete a channel message
3443async fn remove_channel_message(
3444    request: proto::RemoveChannelMessage,
3445    response: Response<proto::RemoveChannelMessage>,
3446    session: Session,
3447) -> Result<()> {
3448    let channel_id = ChannelId::from_proto(request.channel_id);
3449    let message_id = MessageId::from_proto(request.message_id);
3450    let (connection_ids, existing_notification_ids) = session
3451        .db()
3452        .await
3453        .remove_channel_message(channel_id, message_id, session.user_id())
3454        .await?;
3455
3456    broadcast(
3457        Some(session.connection_id),
3458        connection_ids,
3459        move |connection| {
3460            session.peer.send(connection, request.clone())?;
3461
3462            for notification_id in &existing_notification_ids {
3463                session.peer.send(
3464                    connection,
3465                    proto::DeleteNotification {
3466                        notification_id: (*notification_id).to_proto(),
3467                    },
3468                )?;
3469            }
3470
3471            Ok(())
3472        },
3473    );
3474    response.send(proto::Ack {})?;
3475    Ok(())
3476}
3477
3478async fn update_channel_message(
3479    request: proto::UpdateChannelMessage,
3480    response: Response<proto::UpdateChannelMessage>,
3481    session: Session,
3482) -> Result<()> {
3483    let channel_id = ChannelId::from_proto(request.channel_id);
3484    let message_id = MessageId::from_proto(request.message_id);
3485    let updated_at = OffsetDateTime::now_utc();
3486    let UpdatedChannelMessage {
3487        message_id,
3488        participant_connection_ids,
3489        notifications,
3490        reply_to_message_id,
3491        timestamp,
3492        deleted_mention_notification_ids,
3493        updated_mention_notifications,
3494    } = session
3495        .db()
3496        .await
3497        .update_channel_message(
3498            channel_id,
3499            message_id,
3500            session.user_id(),
3501            request.body.as_str(),
3502            &request.mentions,
3503            updated_at,
3504        )
3505        .await?;
3506
3507    let nonce = request
3508        .nonce
3509        .clone()
3510        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3511
3512    let message = proto::ChannelMessage {
3513        sender_id: session.user_id().to_proto(),
3514        id: message_id.to_proto(),
3515        body: request.body.clone(),
3516        mentions: request.mentions.clone(),
3517        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3518        nonce: Some(nonce),
3519        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3520        edited_at: Some(updated_at.unix_timestamp() as u64),
3521    };
3522
3523    response.send(proto::Ack {})?;
3524
3525    let pool = &*session.connection_pool().await;
3526    broadcast(
3527        Some(session.connection_id),
3528        participant_connection_ids,
3529        |connection| {
3530            session.peer.send(
3531                connection,
3532                proto::ChannelMessageUpdate {
3533                    channel_id: channel_id.to_proto(),
3534                    message: Some(message.clone()),
3535                },
3536            )?;
3537
3538            for notification_id in &deleted_mention_notification_ids {
3539                session.peer.send(
3540                    connection,
3541                    proto::DeleteNotification {
3542                        notification_id: (*notification_id).to_proto(),
3543                    },
3544                )?;
3545            }
3546
3547            for notification in &updated_mention_notifications {
3548                session.peer.send(
3549                    connection,
3550                    proto::UpdateNotification {
3551                        notification: Some(notification.clone()),
3552                    },
3553                )?;
3554            }
3555
3556            Ok(())
3557        },
3558    );
3559
3560    send_notifications(pool, &session.peer, notifications);
3561
3562    Ok(())
3563}
3564
3565/// Mark a channel message as read
3566async fn acknowledge_channel_message(
3567    request: proto::AckChannelMessage,
3568    session: Session,
3569) -> Result<()> {
3570    let channel_id = ChannelId::from_proto(request.channel_id);
3571    let message_id = MessageId::from_proto(request.message_id);
3572    let notifications = session
3573        .db()
3574        .await
3575        .observe_channel_message(channel_id, session.user_id(), message_id)
3576        .await?;
3577    send_notifications(
3578        &*session.connection_pool().await,
3579        &session.peer,
3580        notifications,
3581    );
3582    Ok(())
3583}
3584
3585/// Mark a buffer version as synced
3586async fn acknowledge_buffer_version(
3587    request: proto::AckBufferOperation,
3588    session: Session,
3589) -> Result<()> {
3590    let buffer_id = BufferId::from_proto(request.buffer_id);
3591    session
3592        .db()
3593        .await
3594        .observe_buffer_version(
3595            buffer_id,
3596            session.user_id(),
3597            request.epoch as i32,
3598            &request.version,
3599        )
3600        .await?;
3601    Ok(())
3602}
3603
3604async fn count_language_model_tokens(
3605    request: proto::CountLanguageModelTokens,
3606    response: Response<proto::CountLanguageModelTokens>,
3607    session: Session,
3608    config: &Config,
3609) -> Result<()> {
3610    authorize_access_to_legacy_llm_endpoints(&session).await?;
3611
3612    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3613        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3614        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3615    };
3616
3617    session
3618        .app_state
3619        .rate_limiter
3620        .check(&*rate_limit, session.user_id())
3621        .await?;
3622
3623    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3624        Some(proto::LanguageModelProvider::Google) => {
3625            let api_key = config
3626                .google_ai_api_key
3627                .as_ref()
3628                .context("no Google AI API key configured on the server")?;
3629            google_ai::count_tokens(
3630                session.http_client.as_ref(),
3631                google_ai::API_URL,
3632                api_key,
3633                serde_json::from_str(&request.request)?,
3634            )
3635            .await?
3636        }
3637        _ => return Err(anyhow!("unsupported provider"))?,
3638    };
3639
3640    response.send(proto::CountLanguageModelTokensResponse {
3641        token_count: result.total_tokens as u32,
3642    })?;
3643
3644    Ok(())
3645}
3646
3647struct ZedProCountLanguageModelTokensRateLimit;
3648
3649impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3650    fn capacity(&self) -> usize {
3651        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3652            .ok()
3653            .and_then(|v| v.parse().ok())
3654            .unwrap_or(600) // Picked arbitrarily
3655    }
3656
3657    fn refill_duration(&self) -> chrono::Duration {
3658        chrono::Duration::hours(1)
3659    }
3660
3661    fn db_name(&self) -> &'static str {
3662        "zed-pro:count-language-model-tokens"
3663    }
3664}
3665
3666struct FreeCountLanguageModelTokensRateLimit;
3667
3668impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3669    fn capacity(&self) -> usize {
3670        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3671            .ok()
3672            .and_then(|v| v.parse().ok())
3673            .unwrap_or(600 / 10) // Picked arbitrarily
3674    }
3675
3676    fn refill_duration(&self) -> chrono::Duration {
3677        chrono::Duration::hours(1)
3678    }
3679
3680    fn db_name(&self) -> &'static str {
3681        "free:count-language-model-tokens"
3682    }
3683}
3684
3685struct ZedProComputeEmbeddingsRateLimit;
3686
3687impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3688    fn capacity(&self) -> usize {
3689        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3690            .ok()
3691            .and_then(|v| v.parse().ok())
3692            .unwrap_or(5000) // Picked arbitrarily
3693    }
3694
3695    fn refill_duration(&self) -> chrono::Duration {
3696        chrono::Duration::hours(1)
3697    }
3698
3699    fn db_name(&self) -> &'static str {
3700        "zed-pro:compute-embeddings"
3701    }
3702}
3703
3704struct FreeComputeEmbeddingsRateLimit;
3705
3706impl RateLimit for FreeComputeEmbeddingsRateLimit {
3707    fn capacity(&self) -> usize {
3708        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3709            .ok()
3710            .and_then(|v| v.parse().ok())
3711            .unwrap_or(5000 / 10) // Picked arbitrarily
3712    }
3713
3714    fn refill_duration(&self) -> chrono::Duration {
3715        chrono::Duration::hours(1)
3716    }
3717
3718    fn db_name(&self) -> &'static str {
3719        "free:compute-embeddings"
3720    }
3721}
3722
3723async fn compute_embeddings(
3724    request: proto::ComputeEmbeddings,
3725    response: Response<proto::ComputeEmbeddings>,
3726    session: Session,
3727    api_key: Option<Arc<str>>,
3728) -> Result<()> {
3729    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3730    authorize_access_to_legacy_llm_endpoints(&session).await?;
3731
3732    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3733        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3734        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3735    };
3736
3737    session
3738        .app_state
3739        .rate_limiter
3740        .check(&*rate_limit, session.user_id())
3741        .await?;
3742
3743    let embeddings = match request.model.as_str() {
3744        "openai/text-embedding-3-small" => {
3745            open_ai::embed(
3746                session.http_client.as_ref(),
3747                OPEN_AI_API_URL,
3748                &api_key,
3749                OpenAiEmbeddingModel::TextEmbedding3Small,
3750                request.texts.iter().map(|text| text.as_str()),
3751            )
3752            .await?
3753        }
3754        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3755    };
3756
3757    let embeddings = request
3758        .texts
3759        .iter()
3760        .map(|text| {
3761            let mut hasher = sha2::Sha256::new();
3762            hasher.update(text.as_bytes());
3763            let result = hasher.finalize();
3764            result.to_vec()
3765        })
3766        .zip(
3767            embeddings
3768                .data
3769                .into_iter()
3770                .map(|embedding| embedding.embedding),
3771        )
3772        .collect::<HashMap<_, _>>();
3773
3774    let db = session.db().await;
3775    db.save_embeddings(&request.model, &embeddings)
3776        .await
3777        .context("failed to save embeddings")
3778        .trace_err();
3779
3780    response.send(proto::ComputeEmbeddingsResponse {
3781        embeddings: embeddings
3782            .into_iter()
3783            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3784            .collect(),
3785    })?;
3786    Ok(())
3787}
3788
3789async fn get_cached_embeddings(
3790    request: proto::GetCachedEmbeddings,
3791    response: Response<proto::GetCachedEmbeddings>,
3792    session: Session,
3793) -> Result<()> {
3794    authorize_access_to_legacy_llm_endpoints(&session).await?;
3795
3796    let db = session.db().await;
3797    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3798
3799    response.send(proto::GetCachedEmbeddingsResponse {
3800        embeddings: embeddings
3801            .into_iter()
3802            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3803            .collect(),
3804    })?;
3805    Ok(())
3806}
3807
3808/// This is leftover from before the LLM service.
3809///
3810/// The endpoints protected by this check will be moved there eventually.
3811async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3812    if session.is_staff() {
3813        Ok(())
3814    } else {
3815        Err(anyhow!("permission denied"))?
3816    }
3817}
3818
3819/// Get a Supermaven API key for the user
3820async fn get_supermaven_api_key(
3821    _request: proto::GetSupermavenApiKey,
3822    response: Response<proto::GetSupermavenApiKey>,
3823    session: Session,
3824) -> Result<()> {
3825    let user_id: String = session.user_id().to_string();
3826    if !session.is_staff() {
3827        return Err(anyhow!("supermaven not enabled for this account"))?;
3828    }
3829
3830    let email = session
3831        .email()
3832        .ok_or_else(|| anyhow!("user must have an email"))?;
3833
3834    let supermaven_admin_api = session
3835        .supermaven_client
3836        .as_ref()
3837        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3838
3839    let result = supermaven_admin_api
3840        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3841        .await?;
3842
3843    response.send(proto::GetSupermavenApiKeyResponse {
3844        api_key: result.api_key,
3845    })?;
3846
3847    Ok(())
3848}
3849
3850/// Start receiving chat updates for a channel
3851async fn join_channel_chat(
3852    request: proto::JoinChannelChat,
3853    response: Response<proto::JoinChannelChat>,
3854    session: Session,
3855) -> Result<()> {
3856    let channel_id = ChannelId::from_proto(request.channel_id);
3857
3858    let db = session.db().await;
3859    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3860        .await?;
3861    let messages = db
3862        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3863        .await?;
3864    response.send(proto::JoinChannelChatResponse {
3865        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3866        messages,
3867    })?;
3868    Ok(())
3869}
3870
3871/// Stop receiving chat updates for a channel
3872async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3873    let channel_id = ChannelId::from_proto(request.channel_id);
3874    session
3875        .db()
3876        .await
3877        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3878        .await?;
3879    Ok(())
3880}
3881
3882/// Retrieve the chat history for a channel
3883async fn get_channel_messages(
3884    request: proto::GetChannelMessages,
3885    response: Response<proto::GetChannelMessages>,
3886    session: Session,
3887) -> Result<()> {
3888    let channel_id = ChannelId::from_proto(request.channel_id);
3889    let messages = session
3890        .db()
3891        .await
3892        .get_channel_messages(
3893            channel_id,
3894            session.user_id(),
3895            MESSAGE_COUNT_PER_PAGE,
3896            Some(MessageId::from_proto(request.before_message_id)),
3897        )
3898        .await?;
3899    response.send(proto::GetChannelMessagesResponse {
3900        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3901        messages,
3902    })?;
3903    Ok(())
3904}
3905
3906/// Retrieve specific chat messages
3907async fn get_channel_messages_by_id(
3908    request: proto::GetChannelMessagesById,
3909    response: Response<proto::GetChannelMessagesById>,
3910    session: Session,
3911) -> Result<()> {
3912    let message_ids = request
3913        .message_ids
3914        .iter()
3915        .map(|id| MessageId::from_proto(*id))
3916        .collect::<Vec<_>>();
3917    let messages = session
3918        .db()
3919        .await
3920        .get_channel_messages_by_id(session.user_id(), &message_ids)
3921        .await?;
3922    response.send(proto::GetChannelMessagesResponse {
3923        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3924        messages,
3925    })?;
3926    Ok(())
3927}
3928
3929/// Retrieve the current users notifications
3930async fn get_notifications(
3931    request: proto::GetNotifications,
3932    response: Response<proto::GetNotifications>,
3933    session: Session,
3934) -> Result<()> {
3935    let notifications = session
3936        .db()
3937        .await
3938        .get_notifications(
3939            session.user_id(),
3940            NOTIFICATION_COUNT_PER_PAGE,
3941            request.before_id.map(db::NotificationId::from_proto),
3942        )
3943        .await?;
3944    response.send(proto::GetNotificationsResponse {
3945        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3946        notifications,
3947    })?;
3948    Ok(())
3949}
3950
3951/// Mark notifications as read
3952async fn mark_notification_as_read(
3953    request: proto::MarkNotificationRead,
3954    response: Response<proto::MarkNotificationRead>,
3955    session: Session,
3956) -> Result<()> {
3957    let database = &session.db().await;
3958    let notifications = database
3959        .mark_notification_as_read_by_id(
3960            session.user_id(),
3961            NotificationId::from_proto(request.notification_id),
3962        )
3963        .await?;
3964    send_notifications(
3965        &*session.connection_pool().await,
3966        &session.peer,
3967        notifications,
3968    );
3969    response.send(proto::Ack {})?;
3970    Ok(())
3971}
3972
3973/// Get the current users information
3974async fn get_private_user_info(
3975    _request: proto::GetPrivateUserInfo,
3976    response: Response<proto::GetPrivateUserInfo>,
3977    session: Session,
3978) -> Result<()> {
3979    let db = session.db().await;
3980
3981    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3982    let user = db
3983        .get_user_by_id(session.user_id())
3984        .await?
3985        .ok_or_else(|| anyhow!("user not found"))?;
3986    let flags = db.get_user_flags(session.user_id()).await?;
3987
3988    response.send(proto::GetPrivateUserInfoResponse {
3989        metrics_id,
3990        staff: user.admin,
3991        flags,
3992        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3993    })?;
3994    Ok(())
3995}
3996
3997/// Accept the terms of service (tos) on behalf of the current user
3998async fn accept_terms_of_service(
3999    _request: proto::AcceptTermsOfService,
4000    response: Response<proto::AcceptTermsOfService>,
4001    session: Session,
4002) -> Result<()> {
4003    let db = session.db().await;
4004
4005    let accepted_tos_at = Utc::now();
4006    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4007        .await?;
4008
4009    response.send(proto::AcceptTermsOfServiceResponse {
4010        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4011    })?;
4012    Ok(())
4013}
4014
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the age of the older of the Zed
/// account and the recorded GitHub account creation time. Users with an LLM
/// subscription or the `bypass-account-age-check` flag are exempt.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4017
/// Issue an LLM service token for the current user.
///
/// The user must:
/// - be staff or have the `language-models` feature flag;
/// - have accepted the terms of service;
/// - unless they have an LLM subscription or the `bypass-account-age-check`
///   flag, have an account at least `MIN_ACCOUNT_AGE_FOR_LLM_USE` old.
///
/// The token is built by `LlmTokenClaims::create` from the user record,
/// staff status, billing preferences, the `llm-closed-beta` and
/// `predict-edits` flags, subscription status, current plan, and the session's
/// system ID.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");

    // Only staff and users with the `language-models` flag may obtain a token.
    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

    // Subscribers and users with the `bypass-account-age-check` flag skip the
    // minimum account age requirement.
    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        // Measure age from the earlier of the Zed account creation time and,
        // when recorded, the GitHub user creation time.
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_predict_edits_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4074
4075fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4076    let message = match message {
4077        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4078        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4079        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4080        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4081        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4082            code: frame.code.into(),
4083            reason: frame.reason,
4084        })),
4085        // We should never receive a frame while reading the message, according
4086        // to the `tungstenite` maintainers:
4087        //
4088        // > It cannot occur when you read messages from the WebSocket, but it
4089        // > can be used when you want to send the raw frames (e.g. you want to
4090        // > send the frames to the WebSocket without composing the full message first).
4091        // >
4092        // > — https://github.com/snapview/tungstenite-rs/issues/268
4093        TungsteniteMessage::Frame(_) => {
4094            bail!("received an unexpected frame while reading the message")
4095        }
4096    };
4097
4098    Ok(message)
4099}
4100
4101fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4102    match message {
4103        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4104        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4105        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4106        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4107        AxumMessage::Close(frame) => {
4108            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4109                code: frame.code.into(),
4110                reason: frame.reason,
4111            }))
4112        }
4113    }
4114}
4115
4116fn notify_membership_updated(
4117    connection_pool: &mut ConnectionPool,
4118    result: MembershipUpdated,
4119    user_id: UserId,
4120    peer: &Peer,
4121) {
4122    for membership in &result.new_channels.channel_memberships {
4123        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4124    }
4125    for channel_id in &result.removed_channels {
4126        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4127    }
4128
4129    let user_channels_update = proto::UpdateUserChannels {
4130        channel_memberships: result
4131            .new_channels
4132            .channel_memberships
4133            .iter()
4134            .map(|cm| proto::ChannelMembership {
4135                channel_id: cm.channel_id.to_proto(),
4136                role: cm.role.into(),
4137            })
4138            .collect(),
4139        ..Default::default()
4140    };
4141
4142    let mut update = build_channels_update(result.new_channels);
4143    update.delete_channels = result
4144        .removed_channels
4145        .into_iter()
4146        .map(|id| id.to_proto())
4147        .collect();
4148    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4149
4150    for connection_id in connection_pool.user_connection_ids(user_id) {
4151        peer.send(connection_id, user_channels_update.clone())
4152            .trace_err();
4153        peer.send(connection_id, update.clone()).trace_err();
4154    }
4155}
4156
4157fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4158    proto::UpdateUserChannels {
4159        channel_memberships: channels
4160            .channel_memberships
4161            .iter()
4162            .map(|m| proto::ChannelMembership {
4163                channel_id: m.channel_id.to_proto(),
4164                role: m.role.into(),
4165            })
4166            .collect(),
4167        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4168        observed_channel_message_id: channels.observed_channel_messages.clone(),
4169    }
4170}
4171
4172fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4173    let mut update = proto::UpdateChannels::default();
4174
4175    for channel in channels.channels {
4176        update.channels.push(channel.to_proto());
4177    }
4178
4179    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4180    update.latest_channel_message_ids = channels.latest_channel_messages;
4181
4182    for (channel_id, participants) in channels.channel_participants {
4183        update
4184            .channel_participants
4185            .push(proto::ChannelParticipants {
4186                channel_id: channel_id.to_proto(),
4187                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4188            });
4189    }
4190
4191    for channel in channels.invited_channels {
4192        update.channel_invitations.push(channel.to_proto());
4193    }
4194
4195    update
4196}
4197
4198fn build_initial_contacts_update(
4199    contacts: Vec<db::Contact>,
4200    pool: &ConnectionPool,
4201) -> proto::UpdateContacts {
4202    let mut update = proto::UpdateContacts::default();
4203
4204    for contact in contacts {
4205        match contact {
4206            db::Contact::Accepted { user_id, busy } => {
4207                update.contacts.push(contact_for_user(user_id, busy, pool));
4208            }
4209            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4210            db::Contact::Incoming { user_id } => {
4211                update
4212                    .incoming_requests
4213                    .push(proto::IncomingContactRequest {
4214                        requester_id: user_id.to_proto(),
4215                    })
4216            }
4217        }
4218    }
4219
4220    update
4221}
4222
4223fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4224    proto::Contact {
4225        user_id: user_id.to_proto(),
4226        online: pool.is_user_online(user_id),
4227        busy,
4228    }
4229}
4230
4231fn room_updated(room: &proto::Room, peer: &Peer) {
4232    broadcast(
4233        None,
4234        room.participants
4235            .iter()
4236            .filter_map(|participant| Some(participant.peer_id?.into())),
4237        |peer_id| {
4238            peer.send(
4239                peer_id,
4240                proto::RoomUpdated {
4241                    room: Some(room.clone()),
4242                },
4243            )
4244        },
4245    );
4246}
4247
4248fn channel_updated(
4249    channel: &db::channel::Model,
4250    room: &proto::Room,
4251    peer: &Peer,
4252    pool: &ConnectionPool,
4253) {
4254    let participants = room
4255        .participants
4256        .iter()
4257        .map(|p| p.user_id)
4258        .collect::<Vec<_>>();
4259
4260    broadcast(
4261        None,
4262        pool.channel_connection_ids(channel.root_id())
4263            .filter_map(|(channel_id, role)| {
4264                role.can_see_channel(channel.visibility)
4265                    .then_some(channel_id)
4266            }),
4267        |peer_id| {
4268            peer.send(
4269                peer_id,
4270                proto::UpdateChannels {
4271                    channel_participants: vec![proto::ChannelParticipants {
4272                        channel_id: channel.id.to_proto(),
4273                        participant_user_ids: participants.clone(),
4274                    }],
4275                    ..Default::default()
4276                },
4277            )
4278        },
4279    );
4280}
4281
4282async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4283    let db = session.db().await;
4284
4285    let contacts = db.get_contacts(user_id).await?;
4286    let busy = db.is_user_busy(user_id).await?;
4287
4288    let pool = session.connection_pool().await;
4289    let updated_contact = contact_for_user(user_id, busy, &pool);
4290    for contact in contacts {
4291        if let db::Contact::Accepted {
4292            user_id: contact_user_id,
4293            ..
4294        } = contact
4295        {
4296            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4297                session
4298                    .peer
4299                    .send(
4300                        contact_conn_id,
4301                        proto::UpdateContacts {
4302                            contacts: vec![updated_contact.clone()],
4303                            remove_contacts: Default::default(),
4304                            incoming_requests: Default::default(),
4305                            remove_incoming_requests: Default::default(),
4306                            outgoing_requests: Default::default(),
4307                            remove_outgoing_requests: Default::default(),
4308                        },
4309                    )
4310                    .trace_err();
4311            }
4312        }
4313    }
4314    Ok(())
4315}
4316
/// Removes `connection_id` from the room it is in, then notifies affected peers and
/// cleans up the LiveKit room.
///
/// In order, this:
/// 1. calls `leave_room` on the database and returns `Ok(())` early if it yields `None`;
/// 2. runs `project_left` for every project that was left and broadcasts the updated
///    room via `room_updated`;
/// 3. if the left room belongs to a channel, broadcasts the channel's new participant
///    list via `channel_updated`;
/// 4. sends `CallCanceled` to every connection of each user whose call was canceled;
/// 5. refreshes contact state (`update_user_contacts`) for this session's user and for
///    every canceled callee;
/// 6. if a LiveKit client is configured, removes this user from the LiveKit room and,
///    when the database reported the room as deleted, deletes the LiveKit room.
///
/// # Errors
///
/// Returns any error from `leave_room` or `update_user_contacts`. Failures of the peer
/// sends and of the LiveKit calls are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values outlive it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): whatever `session.db().await` returns is a temporary of this `if let`
    // scrutinee, so it stays alive until the end of the `if let` body (if it is a lock
    // guard, it is held across `project_left` and `room_updated`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `room` passed to `room_updated`
        // carries a defaulted (emptied) `livekit_room` field.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // `leave_room` returned `None`; presumably this connection was not in a room.
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection-pool guard is a temporary that lives only for this call.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the `.await`s below
    // (presumably because it must not be held across an await point).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Shadows the `connection_id` parameter: these are the canceled callee's connections.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` here returns early and skips the LiveKit cleanup below if any
    // contact update fails — confirm that is acceptable.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // Cloned because `livekit_room` may still be needed for `delete_room` below.
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4390
4391async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4392    let left_channel_buffers = session
4393        .db()
4394        .await
4395        .leave_channel_buffers(session.connection_id)
4396        .await?;
4397
4398    for left_buffer in left_channel_buffers {
4399        channel_buffer_updated(
4400            session.connection_id,
4401            left_buffer.connections,
4402            &proto::UpdateChannelBufferCollaborators {
4403                channel_id: left_buffer.channel_id.to_proto(),
4404                collaborators: left_buffer.collaborators,
4405            },
4406            &session.peer,
4407        );
4408    }
4409
4410    Ok(())
4411}
4412
4413fn project_left(project: &db::LeftProject, session: &Session) {
4414    for connection_id in &project.connection_ids {
4415        if project.should_unshare {
4416            session
4417                .peer
4418                .send(
4419                    *connection_id,
4420                    proto::UnshareProject {
4421                        project_id: project.id.to_proto(),
4422                    },
4423                )
4424                .trace_err();
4425        } else {
4426            session
4427                .peer
4428                .send(
4429                    *connection_id,
4430                    proto::RemoveProjectCollaborator {
4431                        project_id: project.id.to_proto(),
4432                        peer_id: Some(session.connection_id.into()),
4433                    },
4434                )
4435                .trace_err();
4436        }
4437    }
4438}
4439
4440pub trait ResultExt {
4441    type Ok;
4442
4443    fn trace_err(self) -> Option<Self::Ok>;
4444}
4445
4446impl<T, E> ResultExt for Result<T, E>
4447where
4448    E: std::fmt::Debug,
4449{
4450    type Ok = T;
4451
4452    #[track_caller]
4453    fn trace_err(self) -> Option<T> {
4454        match self {
4455            Ok(value) => Some(value),
4456            Err(error) => {
4457                tracing::error!("{:?}", error);
4458                None
4459            }
4460        }
4461    }
4462}