rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// Reconnect grace period. Not referenced within this chunk — presumably how long a
/// dropped client may take to reconnect before its state is discarded (confirm).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left behind by
/// stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone,
/// we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the limits below are not referenced in this chunk; their roles
// (channel-message page size, maximum chat message length, notification page size)
// are inferred from the names — confirm against their call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
/// Type-erased handler stored in `Server::handlers`.
///
/// Takes the boxed, type-erased envelope plus the connection's `Session` (by value)
/// and returns a `'static` boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
/// Reply handle given to request handlers registered via `add_request_handler`.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set by `send`; checked after the handler returns to detect a missing reply.
    responded: Arc<AtomicBool>,
}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
/// The identity an authenticated connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting on behalf of `user`. Such sessions are always treated as
    /// staff (see `Session::is_staff`).
    Impersonated { user: User, admin: User },
}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state passed (by value) to every RPC handler.
#[derive(Clone)]
struct Session {
    /// Identity the connection acts as (possibly an admin impersonating a user).
    principal: Principal,
    /// This connection's id within `peer`.
    connection_id: ConnectionId,
    /// Database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the owning `Server`; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in this chunk, hence the `allow`; confirm it
    // is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Client-supplied system id, if any (presumably from the `SystemIdHeader` — confirm).
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
/// Newtype over the shared `Database`.
///
/// It derefs to `Database`, so database methods can be called straight through the
/// guard returned by `Session::db`.
struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The collaboration RPC server.
///
/// It owns the peer, the connection pool, and the message-handler dispatch table.
pub struct Server {
    /// This server's id. Behind a mutex because the test-only `reset` rewrites it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Dispatch table keyed by the `TypeId` of each message type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag. `handle_connection` subscribes to it and refuses the
    /// connection when it reads `true`.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard for the connection pool, returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the guard's deref target, via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
    /// Creates a server with the given id and registers every RPC handler.
    ///
    /// Two kinds of handler are registered:
    /// - Request handlers (`add_request_handler`) get a `Response` they must reply
    ///   through.
    /// - Message handlers (`add_message_handler`) get only the payload and send no
    ///   response.
    ///
    /// # Panics
    ///
    /// Panics (inside `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates ids above u32::MAX — confirm ids stay small.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created later via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            // Connectivity, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree/language-server state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through `forward_read_only_project_request`.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            // Project requests relayed through `forward_mutating_project_request`.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer replication.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            // Messages relayed through `broadcast_project_message_from_host`.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            // Account, tokens and acknowledgements.
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model token counting and embeddings. These closures capture
            // `app_state` so the handlers can read its configuration.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 418
    /// Spawns a detached background task that cleans up after stale servers.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT`, then asks the database for the room
    /// and channel ids `stale_server_resource_ids` reports for this environment and
    /// server id. Then:
    /// - For each stale channel buffer, it clears stale collaborators and sends the
    ///   refreshed collaborator list to the remaining connections.
    /// - For each stale room, it clears stale participants and rebroadcasts the room
    ///   (and its channel, if any). It notifies users whose calls were canceled,
    ///   pushes refreshed contact info for affected users to their accepted
    ///   contacts, and deletes the LiveKit room if no participants remain.
    ///
    /// Finally it calls `delete_stale_servers`.
    ///
    /// Always returns `Ok(())` immediately. Failures inside the task go through
    /// `trace_err()` (presumably logged) and do not abort the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls
                            // were canceled need their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when the DB room ended up empty.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 564
 565    pub fn teardown(&self) {
 566        self.peer.teardown();
 567        self.connection_pool.lock().reset();
 568        let _ = self.teardown.send(true);
 569    }
 570
 571    #[cfg(test)]
 572    pub fn reset(&self, id: ServerId) {
 573        self.teardown();
 574        *self.id.lock() = id;
 575        self.peer.reset(id.0 as u32);
 576        let _ = self.teardown.send(false);
 577    }
 578
 579    #[cfg(test)]
 580    pub fn id(&self) -> ServerId {
 581        *self.id.lock()
 582    }
 583
 584    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 585    where
 586        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 587        Fut: 'static + Send + Future<Output = Result<()>>,
 588        M: EnvelopedMessage,
 589    {
 590        let prev_handler = self.handlers.insert(
 591            TypeId::of::<M>(),
 592            Box::new(move |envelope, session| {
 593                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 594                let received_at = envelope.received_at;
 595                tracing::info!("message received");
 596                let start_time = Instant::now();
 597                let future = (handler)(*envelope, session);
 598                async move {
 599                    let result = future.await;
 600                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 601                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 602                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 603                    let payload_type = M::NAME;
 604
 605                    match result {
 606                        Err(error) => {
 607                            tracing::error!(
 608                                ?error,
 609                                total_duration_ms,
 610                                processing_duration_ms,
 611                                queue_duration_ms,
 612                                payload_type,
 613                                "error handling message"
 614                            )
 615                        }
 616                        Ok(()) => tracing::info!(
 617                            total_duration_ms,
 618                            processing_duration_ms,
 619                            queue_duration_ms,
 620                            "finished handling message"
 621                        ),
 622                    }
 623                }
 624                .boxed()
 625            }),
 626        );
 627        if prev_handler.is_some() {
 628            panic!("registered a handler for the same message twice");
 629        }
 630        self
 631    }
 632
 633    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 634    where
 635        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 636        Fut: 'static + Send + Future<Output = Result<()>>,
 637        M: EnvelopedMessage,
 638    {
 639        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 640        self
 641    }
 642
 643    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 644    where
 645        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 646        Fut: Send + Future<Output = Result<()>>,
 647        M: RequestMessage,
 648    {
 649        let handler = Arc::new(handler);
 650        self.add_handler(move |envelope, session| {
 651            let receipt = envelope.receipt();
 652            let handler = handler.clone();
 653            async move {
 654                let peer = session.peer.clone();
 655                let responded = Arc::new(AtomicBool::default());
 656                let response = Response {
 657                    peer: peer.clone(),
 658                    responded: responded.clone(),
 659                    receipt,
 660                };
 661                match (handler)(envelope.payload, response, session).await {
 662                    Ok(()) => {
 663                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 664                            Ok(())
 665                        } else {
 666                            Err(anyhow!("handler did not send a response"))?
 667                        }
 668                    }
 669                    Err(error) => {
 670                        let proto_err = match &error {
 671                            Error::Internal(err) => err.to_proto(),
 672                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 673                        };
 674                        peer.respond_with_error(receipt, proto_err)?;
 675                        Err(error)
 676                    }
 677                }
 678            }
 679        })
 680    }
 681
    /// Drives one client connection from the initial handshake until sign-out.
    ///
    /// The returned future is instrumented with a `handle connection` span that
    /// records the client address, the principal's fields, the GeoIP country code
    /// and, once assigned, the connection id. When polled it:
    ///
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer, obtaining its `ConnectionId`, the
    ///    I/O future and the incoming-message stream;
    /// 3. builds an HTTP client (plus a Supermaven admin client when an admin API
    ///    key is configured) and the per-connection [`Session`];
    /// 4. calls `send_initial_client_update`, which sends `Hello`, reports the id
    ///    through `send_connection_id` if one was provided, and sets up the user;
    /// 5. dispatches incoming messages to the registered handlers until the I/O
    ///    future completes, the incoming stream ends, or teardown is signaled;
    /// 6. unless teardown was signaled, calls `connection_lost` to sign out.
    ///
    /// Failures in steps 3-4 are logged and end the future early.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed before the future is created so a teardown signaled from now
        // on is observed by the message loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            // NOTE(review): this early return (and the one after
            // `send_initial_client_update`) skips `connection_lost`, so the
            // connection registered above is not explicitly disconnected or
            // removed from the pool here. Confirm that dropping `handle_io` is
            // sufficient cleanup.
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this
            // connection at 256: a permit is acquired before reading each message
            // and released when that message's handler future finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are polled in the order written. Teardown wins,
                // then I/O completion, then finishing foreground handlers, and only
                // then a new incoming message.
                futures::select_biased! {
                    // Teardown returns without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so it is
                                // held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // Unknown message type: logged and dropped. Its
                                // permit is released when this branch ends.
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancels any foreground handlers still in flight. Background handlers
            // were spawned detached and are unaffected.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 831
 832    async fn send_initial_client_update(
 833        &self,
 834        connection_id: ConnectionId,
 835        principal: &Principal,
 836        zed_version: ZedVersion,
 837        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 838        session: &Session,
 839    ) -> Result<()> {
 840        self.peer.send(
 841            connection_id,
 842            proto::Hello {
 843                peer_id: Some(connection_id.into()),
 844            },
 845        )?;
 846        tracing::info!("sent hello message");
 847        if let Some(send_connection_id) = send_connection_id.take() {
 848            let _ = send_connection_id.send(connection_id);
 849        }
 850
 851        match principal {
 852            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 853                if !user.connected_once {
 854                    self.peer.send(connection_id, proto::ShowContacts {})?;
 855                    self.app_state
 856                        .db
 857                        .set_user_connected_once(user.id, true)
 858                        .await?;
 859                }
 860
 861                update_user_plan(user.id, session).await?;
 862
 863                let contacts = self.app_state.db.get_contacts(user.id).await?;
 864
 865                {
 866                    let mut pool = self.connection_pool.lock();
 867                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 868                    self.peer.send(
 869                        connection_id,
 870                        build_initial_contacts_update(contacts, &pool),
 871                    )?;
 872                }
 873
 874                if should_auto_subscribe_to_channels(zed_version) {
 875                    subscribe_user_to_channels(user.id, session).await?;
 876                }
 877
 878                if let Some(incoming_call) =
 879                    self.app_state.db.incoming_call_for_user(user.id).await?
 880                {
 881                    self.peer.send(connection_id, incoming_call)?;
 882                }
 883
 884                update_user_contacts(user.id, session).await?;
 885            }
 886        }
 887
 888        Ok(())
 889    }
 890
 891    pub async fn invite_code_redeemed(
 892        self: &Arc<Self>,
 893        inviter_id: UserId,
 894        invitee_id: UserId,
 895    ) -> Result<()> {
 896        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 897            if let Some(code) = &user.invite_code {
 898                let pool = self.connection_pool.lock();
 899                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 900                for connection_id in pool.user_connection_ids(inviter_id) {
 901                    self.peer.send(
 902                        connection_id,
 903                        proto::UpdateContacts {
 904                            contacts: vec![invitee_contact.clone()],
 905                            ..Default::default()
 906                        },
 907                    )?;
 908                    self.peer.send(
 909                        connection_id,
 910                        proto::UpdateInviteInfo {
 911                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 912                            count: user.invite_count as u32,
 913                        },
 914                    )?;
 915                }
 916            }
 917        }
 918        Ok(())
 919    }
 920
 921    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 922        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 923            if let Some(invite_code) = &user.invite_code {
 924                let pool = self.connection_pool.lock();
 925                for connection_id in pool.user_connection_ids(user_id) {
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateInviteInfo {
 929                            url: format!(
 930                                "{}{}",
 931                                self.app_state.config.invite_link_prefix, invite_code
 932                            ),
 933                            count: user.invite_count as u32,
 934                        },
 935                    )?;
 936                }
 937            }
 938        }
 939        Ok(())
 940    }
 941
 942    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 943        let pool = self.connection_pool.lock();
 944        for connection_id in pool.user_connection_ids(user_id) {
 945            self.peer
 946                .send(connection_id, proto::RefreshLlmToken {})
 947                .trace_err();
 948        }
 949    }
 950
 951    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 952        ServerSnapshot {
 953            connection_pool: ConnectionPoolGuard {
 954                guard: self.connection_pool.lock(),
 955                _not_send: PhantomData,
 956            },
 957            peer: &self.peer,
 958        }
 959    }
 960}
 961
 962impl<'a> Deref for ConnectionPoolGuard<'a> {
 963    type Target = ConnectionPool;
 964
 965    fn deref(&self) -> &Self::Target {
 966        &self.guard
 967    }
 968}
 969
 970impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 971    fn deref_mut(&mut self) -> &mut Self::Target {
 972        &mut self.guard
 973    }
 974}
 975
 976impl<'a> Drop for ConnectionPoolGuard<'a> {
 977    fn drop(&mut self) {
 978        #[cfg(test)]
 979        self.check_invariants();
 980    }
 981}
 982
 983fn broadcast<F>(
 984    sender_id: Option<ConnectionId>,
 985    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 986    mut f: F,
 987) where
 988    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 989{
 990    for receiver_id in receiver_ids {
 991        if Some(receiver_id) != sender_id {
 992            if let Err(error) = f(receiver_id) {
 993                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 994            }
 995        }
 996    }
 997}
 998
 999pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002    fn name() -> &'static HeaderName {
1003        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005    }
1006
1007    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008    where
1009        Self: Sized,
1010        I: Iterator<Item = &'i axum::http::HeaderValue>,
1011    {
1012        let version = values
1013            .next()
1014            .ok_or_else(axum::headers::Error::invalid)?
1015            .to_str()
1016            .map_err(|_| axum::headers::Error::invalid())?
1017            .parse()
1018            .map_err(|_| axum::headers::Error::invalid())?;
1019        Ok(Self(version))
1020    }
1021
1022    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023        values.extend([self.0.to_string().parse().unwrap()]);
1024    }
1025}
1026
1027pub struct AppVersionHeader(SemanticVersion);
1028impl Header for AppVersionHeader {
1029    fn name() -> &'static HeaderName {
1030        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032    }
1033
1034    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035    where
1036        Self: Sized,
1037        I: Iterator<Item = &'i axum::http::HeaderValue>,
1038    {
1039        let version = values
1040            .next()
1041            .ok_or_else(axum::headers::Error::invalid)?
1042            .to_str()
1043            .map_err(|_| axum::headers::Error::invalid())?
1044            .parse()
1045            .map_err(|_| axum::headers::Error::invalid())?;
1046        Ok(Self(version))
1047    }
1048
1049    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050        values.extend([self.0.to_string().parse().unwrap()]);
1051    }
1052}
1053
1054pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1055    Router::new()
1056        .route("/rpc", get(handle_websocket_request))
1057        .layer(
1058            ServiceBuilder::new()
1059                .layer(Extension(server.app_state.clone()))
1060                .layer(middleware::from_fn(auth::validate_header)),
1061        )
1062        .route("/metrics", get(handle_metrics))
1063        .layer(Extension(server))
1064}
1065
1066#[allow(clippy::too_many_arguments)]
1067pub async fn handle_websocket_request(
1068    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1069    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1070    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1071    Extension(server): Extension<Arc<Server>>,
1072    Extension(principal): Extension<Principal>,
1073    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1074    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1075    ws: WebSocketUpgrade,
1076) -> axum::response::Response {
1077    if protocol_version != rpc::PROTOCOL_VERSION {
1078        return (
1079            StatusCode::UPGRADE_REQUIRED,
1080            "client must be upgraded".to_string(),
1081        )
1082            .into_response();
1083    }
1084
1085    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1086        return (
1087            StatusCode::UPGRADE_REQUIRED,
1088            "no version header found".to_string(),
1089        )
1090            .into_response();
1091    };
1092
1093    if !version.can_collaborate() {
1094        return (
1095            StatusCode::UPGRADE_REQUIRED,
1096            "client must be upgraded".to_string(),
1097        )
1098            .into_response();
1099    }
1100
1101    let socket_address = socket_address.to_string();
1102    ws.on_upgrade(move |socket| {
1103        let socket = socket
1104            .map_ok(to_tungstenite_message)
1105            .err_into()
1106            .with(|message| async move { to_axum_message(message) });
1107        let connection = Connection::new(Box::pin(socket));
1108        async move {
1109            server
1110                .handle_connection(
1111                    connection,
1112                    socket_address,
1113                    principal,
1114                    version,
1115                    country_code_header.map(|header| header.to_string()),
1116                    system_id_header.map(|header| header.to_string()),
1117                    None,
1118                    Executor::Production,
1119                )
1120                .await;
1121        }
1122    })
1123}
1124
1125pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1126    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1127    let connections_metric = CONNECTIONS_METRIC
1128        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1129
1130    let connections = server
1131        .connection_pool
1132        .lock()
1133        .connections()
1134        .filter(|connection| !connection.admin)
1135        .count();
1136    connections_metric.set(connections as _);
1137
1138    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1139    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1140        register_int_gauge!(
1141            "shared_projects",
1142            "number of open projects with one or more guests"
1143        )
1144        .unwrap()
1145    });
1146
1147    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1148    shared_projects_metric.set(shared_projects as _);
1149
1150    let encoder = prometheus::TextEncoder::new();
1151    let metric_families = prometheus::gather();
1152    let encoded_metrics = encoder
1153        .encode_to_string(&metric_families)
1154        .map_err(|err| anyhow!("{}", err))?;
1155    Ok(encoded_metrics)
1156}
1157
1158#[instrument(err, skip(executor))]
1159async fn connection_lost(
1160    session: Session,
1161    mut teardown: watch::Receiver<bool>,
1162    executor: Executor,
1163) -> Result<()> {
1164    session.peer.disconnect(session.connection_id);
1165    session
1166        .connection_pool()
1167        .await
1168        .remove_connection(session.connection_id)?;
1169
1170    session
1171        .db()
1172        .await
1173        .connection_lost(session.connection_id)
1174        .await
1175        .trace_err();
1176
1177    futures::select_biased! {
1178        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1179
1180            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1181            leave_room_for_session(&session, session.connection_id).await.trace_err();
1182            leave_channel_buffers_for_session(&session)
1183                .await
1184                .trace_err();
1185
1186            if !session
1187                .connection_pool()
1188                .await
1189                .is_user_online(session.user_id())
1190            {
1191                let db = session.db().await;
1192                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1193                    room_updated(&room, &session.peer);
1194                }
1195            }
1196
1197            update_user_contacts(session.user_id(), &session).await?;
1198        },
1199        _ = teardown.changed().fuse() => {}
1200    }
1201
1202    Ok(())
1203}
1204
1205/// Acknowledges a ping from a client, used to keep the connection alive.
1206async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1207    response.send(proto::Ack {})?;
1208    Ok(())
1209}
1210
1211/// Creates a new room for calling (outside of channels)
1212async fn create_room(
1213    _request: proto::CreateRoom,
1214    response: Response<proto::CreateRoom>,
1215    session: Session,
1216) -> Result<()> {
1217    let livekit_room = nanoid::nanoid!(30);
1218
1219    let live_kit_connection_info = util::maybe!(async {
1220        let live_kit = session.app_state.livekit_client.as_ref();
1221        let live_kit = live_kit?;
1222        let user_id = session.user_id().to_string();
1223
1224        let token = live_kit
1225            .room_token(&livekit_room, &user_id.to_string())
1226            .trace_err()?;
1227
1228        Some(proto::LiveKitConnectionInfo {
1229            server_url: live_kit.url().into(),
1230            token,
1231            can_publish: true,
1232        })
1233    })
1234    .await;
1235
1236    let room = session
1237        .db()
1238        .await
1239        .create_room(session.user_id(), session.connection_id, &livekit_room)
1240        .await?;
1241
1242    response.send(proto::CreateRoomResponse {
1243        room: Some(room.clone()),
1244        live_kit_connection_info,
1245    })?;
1246
1247    update_user_contacts(session.user_id(), &session).await?;
1248    Ok(())
1249}
1250
1251/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1252async fn join_room(
1253    request: proto::JoinRoom,
1254    response: Response<proto::JoinRoom>,
1255    session: Session,
1256) -> Result<()> {
1257    let room_id = RoomId::from_proto(request.id);
1258
1259    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1260
1261    if let Some(channel_id) = channel_id {
1262        return join_channel_internal(channel_id, Box::new(response), session).await;
1263    }
1264
1265    let joined_room = {
1266        let room = session
1267            .db()
1268            .await
1269            .join_room(room_id, session.user_id(), session.connection_id)
1270            .await?;
1271        room_updated(&room.room, &session.peer);
1272        room.into_inner()
1273    };
1274
1275    for connection_id in session
1276        .connection_pool()
1277        .await
1278        .user_connection_ids(session.user_id())
1279    {
1280        session
1281            .peer
1282            .send(
1283                connection_id,
1284                proto::CallCanceled {
1285                    room_id: room_id.to_proto(),
1286                },
1287            )
1288            .trace_err();
1289    }
1290
1291    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1292    {
1293        live_kit
1294            .room_token(
1295                &joined_room.room.livekit_room,
1296                &session.user_id().to_string(),
1297            )
1298            .trace_err()
1299            .map(|token| proto::LiveKitConnectionInfo {
1300                server_url: live_kit.url().into(),
1301                token,
1302                can_publish: true,
1303            })
1304    } else {
1305        None
1306    };
1307
1308    response.send(proto::JoinRoomResponse {
1309        room: Some(joined_room.room),
1310        channel_id: None,
1311        live_kit_connection_info,
1312    })?;
1313
1314    update_user_contacts(session.user_id(), &session).await?;
1315    Ok(())
1316}
1317
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Asks the database to rejoin the room for this connection and immediately
/// responds with the room, the reshared projects (with their collaborators) and
/// the rejoined projects. It then:
/// - notifies room members;
/// - updates collaborators of reshared projects;
/// - streams rejoined project state back to this connection via
///   `notify_rejoined_projects`.
///
/// Finally, if the room belongs to a channel, it calls `channel_updated`, and
/// then `update_user_contacts`.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the scope below; the value returned by the database is
    // confined to that scope (it exposes `into_inner`, presumably a
    // transaction/lock guard — confirm) so it is released before the channel
    // and contact updates further down.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project: tell every collaborator that this
        // connection replaced the project's old connection id, then forward an
        // `UpdateProject` with the project's worktrees to every collaborator
        // except this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes `&mut` because it moves worktree data out of the rejoined projects.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1409
1410fn notify_rejoined_projects(
1411    rejoined_projects: &mut Vec<RejoinedProject>,
1412    session: &Session,
1413) -> Result<()> {
1414    for project in rejoined_projects.iter() {
1415        for collaborator in &project.collaborators {
1416            session
1417                .peer
1418                .send(
1419                    collaborator.connection_id,
1420                    proto::UpdateProjectCollaborator {
1421                        project_id: project.id.to_proto(),
1422                        old_peer_id: Some(project.old_connection_id.into()),
1423                        new_peer_id: Some(session.connection_id.into()),
1424                    },
1425                )
1426                .trace_err();
1427        }
1428    }
1429
1430    for project in rejoined_projects {
1431        for worktree in mem::take(&mut project.worktrees) {
1432            // Stream this worktree's entries.
1433            let message = proto::UpdateWorktree {
1434                project_id: project.id.to_proto(),
1435                worktree_id: worktree.id,
1436                abs_path: worktree.abs_path.clone(),
1437                root_name: worktree.root_name,
1438                updated_entries: worktree.updated_entries,
1439                removed_entries: worktree.removed_entries,
1440                scan_id: worktree.scan_id,
1441                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1442                updated_repositories: worktree.updated_repositories,
1443                removed_repositories: worktree.removed_repositories,
1444            };
1445            for update in proto::split_worktree_update(message) {
1446                session.peer.send(session.connection_id, update.clone())?;
1447            }
1448
1449            // Stream this worktree's diagnostics.
1450            for summary in worktree.diagnostic_summaries {
1451                session.peer.send(
1452                    session.connection_id,
1453                    proto::UpdateDiagnosticSummary {
1454                        project_id: project.id.to_proto(),
1455                        worktree_id: worktree.id,
1456                        summary: Some(summary),
1457                    },
1458                )?;
1459            }
1460
1461            for settings_file in worktree.settings_files {
1462                session.peer.send(
1463                    session.connection_id,
1464                    proto::UpdateWorktreeSettings {
1465                        project_id: project.id.to_proto(),
1466                        worktree_id: worktree.id,
1467                        path: settings_file.path,
1468                        content: Some(settings_file.content),
1469                        kind: Some(settings_file.kind.to_proto().into()),
1470                    },
1471                )?;
1472            }
1473        }
1474
1475        for language_server in &project.language_servers {
1476            session.peer.send(
1477                session.connection_id,
1478                proto::UpdateLanguageServer {
1479                    project_id: project.id.to_proto(),
1480                    language_server_id: language_server.id,
1481                    variant: Some(
1482                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1483                            proto::LspDiskBasedDiagnosticsUpdated {},
1484                        ),
1485                    ),
1486                },
1487            )?;
1488        }
1489    }
1490    Ok(())
1491}
1492
1493/// leave room disconnects from the room.
1494async fn leave_room(
1495    _: proto::LeaveRoom,
1496    response: Response<proto::LeaveRoom>,
1497    session: Session,
1498) -> Result<()> {
1499    leave_room_for_session(&session, session.connection_id).await?;
1500    response.send(proto::Ack {})?;
1501    Ok(())
1502}
1503
1504/// Updates the permissions of someone else in the room.
1505async fn set_room_participant_role(
1506    request: proto::SetRoomParticipantRole,
1507    response: Response<proto::SetRoomParticipantRole>,
1508    session: Session,
1509) -> Result<()> {
1510    let user_id = UserId::from_proto(request.user_id);
1511    let role = ChannelRole::from(request.role());
1512
1513    let (livekit_room, can_publish) = {
1514        let room = session
1515            .db()
1516            .await
1517            .set_room_participant_role(
1518                session.user_id(),
1519                RoomId::from_proto(request.room_id),
1520                user_id,
1521                role,
1522            )
1523            .await?;
1524
1525        let livekit_room = room.livekit_room.clone();
1526        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1527        room_updated(&room, &session.peer);
1528        (livekit_room, can_publish)
1529    };
1530
1531    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1532        live_kit
1533            .update_participant(
1534                livekit_room.clone(),
1535                request.user_id.to_string(),
1536                livekit_server::proto::ParticipantPermission {
1537                    can_subscribe: true,
1538                    can_publish,
1539                    can_publish_data: can_publish,
1540                    hidden: false,
1541                    recorder: false,
1542                },
1543            )
1544            .await
1545            .trace_err();
1546    }
1547
1548    response.send(proto::Ack {})?;
1549    Ok(())
1550}
1551
1552/// Call someone else into the current room
1553async fn call(
1554    request: proto::Call,
1555    response: Response<proto::Call>,
1556    session: Session,
1557) -> Result<()> {
1558    let room_id = RoomId::from_proto(request.room_id);
1559    let calling_user_id = session.user_id();
1560    let calling_connection_id = session.connection_id;
1561    let called_user_id = UserId::from_proto(request.called_user_id);
1562    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1563    if !session
1564        .db()
1565        .await
1566        .has_contact(calling_user_id, called_user_id)
1567        .await?
1568    {
1569        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1570    }
1571
1572    let incoming_call = {
1573        let (room, incoming_call) = &mut *session
1574            .db()
1575            .await
1576            .call(
1577                room_id,
1578                calling_user_id,
1579                calling_connection_id,
1580                called_user_id,
1581                initial_project_id,
1582            )
1583            .await?;
1584        room_updated(room, &session.peer);
1585        mem::take(incoming_call)
1586    };
1587    update_user_contacts(called_user_id, &session).await?;
1588
1589    let mut calls = session
1590        .connection_pool()
1591        .await
1592        .user_connection_ids(called_user_id)
1593        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1594        .collect::<FuturesUnordered<_>>();
1595
1596    while let Some(call_response) = calls.next().await {
1597        match call_response.as_ref() {
1598            Ok(_) => {
1599                response.send(proto::Ack {})?;
1600                return Ok(());
1601            }
1602            Err(_) => {
1603                call_response.trace_err();
1604            }
1605        }
1606    }
1607
1608    {
1609        let room = session
1610            .db()
1611            .await
1612            .call_failed(room_id, called_user_id)
1613            .await?;
1614        room_updated(&room, &session.peer);
1615    }
1616    update_user_contacts(called_user_id, &session).await?;
1617
1618    Err(anyhow!("failed to ring user"))?
1619}
1620
1621/// Cancel an outgoing call.
1622async fn cancel_call(
1623    request: proto::CancelCall,
1624    response: Response<proto::CancelCall>,
1625    session: Session,
1626) -> Result<()> {
1627    let called_user_id = UserId::from_proto(request.called_user_id);
1628    let room_id = RoomId::from_proto(request.room_id);
1629    {
1630        let room = session
1631            .db()
1632            .await
1633            .cancel_call(room_id, session.connection_id, called_user_id)
1634            .await?;
1635        room_updated(&room, &session.peer);
1636    }
1637
1638    for connection_id in session
1639        .connection_pool()
1640        .await
1641        .user_connection_ids(called_user_id)
1642    {
1643        session
1644            .peer
1645            .send(
1646                connection_id,
1647                proto::CallCanceled {
1648                    room_id: room_id.to_proto(),
1649                },
1650            )
1651            .trace_err();
1652    }
1653    response.send(proto::Ack {})?;
1654
1655    update_user_contacts(called_user_id, &session).await?;
1656    Ok(())
1657}
1658
1659/// Decline an incoming call.
1660async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1661    let room_id = RoomId::from_proto(message.room_id);
1662    {
1663        let room = session
1664            .db()
1665            .await
1666            .decline_call(Some(room_id), session.user_id())
1667            .await?
1668            .ok_or_else(|| anyhow!("failed to decline call"))?;
1669        room_updated(&room, &session.peer);
1670    }
1671
1672    for connection_id in session
1673        .connection_pool()
1674        .await
1675        .user_connection_ids(session.user_id())
1676    {
1677        session
1678            .peer
1679            .send(
1680                connection_id,
1681                proto::CallCanceled {
1682                    room_id: room_id.to_proto(),
1683                },
1684            )
1685            .trace_err();
1686    }
1687    update_user_contacts(session.user_id(), &session).await?;
1688    Ok(())
1689}
1690
1691/// Updates other participants in the room with your current location.
1692async fn update_participant_location(
1693    request: proto::UpdateParticipantLocation,
1694    response: Response<proto::UpdateParticipantLocation>,
1695    session: Session,
1696) -> Result<()> {
1697    let room_id = RoomId::from_proto(request.room_id);
1698    let location = request
1699        .location
1700        .ok_or_else(|| anyhow!("invalid location"))?;
1701
1702    let db = session.db().await;
1703    let room = db
1704        .update_room_participant_location(room_id, session.connection_id, location)
1705        .await?;
1706
1707    room_updated(&room, &session.peer);
1708    response.send(proto::Ack {})?;
1709    Ok(())
1710}
1711
1712/// Share a project into the room.
1713async fn share_project(
1714    request: proto::ShareProject,
1715    response: Response<proto::ShareProject>,
1716    session: Session,
1717) -> Result<()> {
1718    let (project_id, room) = &*session
1719        .db()
1720        .await
1721        .share_project(
1722            RoomId::from_proto(request.room_id),
1723            session.connection_id,
1724            &request.worktrees,
1725            request.is_ssh_project,
1726        )
1727        .await?;
1728    response.send(proto::ShareProjectResponse {
1729        project_id: project_id.to_proto(),
1730    })?;
1731    room_updated(room, &session.peer);
1732
1733    Ok(())
1734}
1735
1736/// Unshare a project from the room.
1737async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1738    let project_id = ProjectId::from_proto(message.project_id);
1739    unshare_project_internal(project_id, session.connection_id, &session).await
1740}
1741
/// Unshares `project_id` on behalf of `connection_id`.
///
/// `connection_id` is passed explicitly rather than read from `session`, so
/// callers can unshare for a connection other than the session's own.
/// An `UnshareProject` message is broadcast to the guest connections the
/// database returns (with `connection_id` given as the sender), and the room
/// is re-broadcast if there is one. If the database also reports that the
/// project should be deleted, `delete_project` is called afterwards.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The inner block scopes `room_guard`: it is dropped at the end of the
    // block, before the database handle is acquired again for
    // `delete_project` below (presumably to avoid holding it during deletion).
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        // `delete` tells us whether the project row should be removed;
        // `room` is present when the project was shared into a room.
        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1779
1780/// Join someone elses shared project.
1781async fn join_project(
1782    request: proto::JoinProject,
1783    response: Response<proto::JoinProject>,
1784    session: Session,
1785) -> Result<()> {
1786    let project_id = ProjectId::from_proto(request.project_id);
1787
1788    tracing::info!(%project_id, "join project");
1789
1790    let db = session.db().await;
1791    let (project, replica_id) = &mut *db
1792        .join_project(project_id, session.connection_id, session.user_id())
1793        .await?;
1794    drop(db);
1795    tracing::info!(%project_id, "join remote project");
1796    join_project_internal(response, session, project, replica_id)
1797}
1798
/// Something that can deliver the reply to a project-join request.
///
/// `join_project_internal` takes `impl JoinProjectInternalResponse` rather
/// than a concrete `Response`. Presumably this lets other join paths reuse it
/// with a different response type; only the `Response<proto::JoinProject>`
/// impl is visible here.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming `self`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Answers an ordinary `JoinProject` RPC.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Call `Response::send` by explicit path. Writing `self.send(..)`
        // here would read like a recursive call into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1807
1808fn join_project_internal(
1809    response: impl JoinProjectInternalResponse,
1810    session: Session,
1811    project: &mut Project,
1812    replica_id: &ReplicaId,
1813) -> Result<()> {
1814    let collaborators = project
1815        .collaborators
1816        .iter()
1817        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1818        .map(|collaborator| collaborator.to_proto())
1819        .collect::<Vec<_>>();
1820    let project_id = project.id;
1821    let guest_user_id = session.user_id();
1822
1823    let worktrees = project
1824        .worktrees
1825        .iter()
1826        .map(|(id, worktree)| proto::WorktreeMetadata {
1827            id: *id,
1828            root_name: worktree.root_name.clone(),
1829            visible: worktree.visible,
1830            abs_path: worktree.abs_path.clone(),
1831        })
1832        .collect::<Vec<_>>();
1833
1834    let add_project_collaborator = proto::AddProjectCollaborator {
1835        project_id: project_id.to_proto(),
1836        collaborator: Some(proto::Collaborator {
1837            peer_id: Some(session.connection_id.into()),
1838            replica_id: replica_id.0 as u32,
1839            user_id: guest_user_id.to_proto(),
1840            is_host: false,
1841        }),
1842    };
1843
1844    for collaborator in &collaborators {
1845        session
1846            .peer
1847            .send(
1848                collaborator.peer_id.unwrap().into(),
1849                add_project_collaborator.clone(),
1850            )
1851            .trace_err();
1852    }
1853
1854    // First, we send the metadata associated with each worktree.
1855    response.send(proto::JoinProjectResponse {
1856        project_id: project.id.0 as u64,
1857        worktrees: worktrees.clone(),
1858        replica_id: replica_id.0 as u32,
1859        collaborators: collaborators.clone(),
1860        language_servers: project.language_servers.clone(),
1861        role: project.role.into(),
1862    })?;
1863
1864    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1865        // Stream this worktree's entries.
1866        let message = proto::UpdateWorktree {
1867            project_id: project_id.to_proto(),
1868            worktree_id,
1869            abs_path: worktree.abs_path.clone(),
1870            root_name: worktree.root_name,
1871            updated_entries: worktree.entries,
1872            removed_entries: Default::default(),
1873            scan_id: worktree.scan_id,
1874            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1875            updated_repositories: worktree.repository_entries.into_values().collect(),
1876            removed_repositories: Default::default(),
1877        };
1878        for update in proto::split_worktree_update(message) {
1879            session.peer.send(session.connection_id, update.clone())?;
1880        }
1881
1882        // Stream this worktree's diagnostics.
1883        for summary in worktree.diagnostic_summaries {
1884            session.peer.send(
1885                session.connection_id,
1886                proto::UpdateDiagnosticSummary {
1887                    project_id: project_id.to_proto(),
1888                    worktree_id: worktree.id,
1889                    summary: Some(summary),
1890                },
1891            )?;
1892        }
1893
1894        for settings_file in worktree.settings_files {
1895            session.peer.send(
1896                session.connection_id,
1897                proto::UpdateWorktreeSettings {
1898                    project_id: project_id.to_proto(),
1899                    worktree_id: worktree.id,
1900                    path: settings_file.path,
1901                    content: Some(settings_file.content),
1902                    kind: Some(settings_file.kind.to_proto() as i32),
1903                },
1904            )?;
1905        }
1906    }
1907
1908    for language_server in &project.language_servers {
1909        session.peer.send(
1910            session.connection_id,
1911            proto::UpdateLanguageServer {
1912                project_id: project_id.to_proto(),
1913                language_server_id: language_server.id,
1914                variant: Some(
1915                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1916                        proto::LspDiskBasedDiagnosticsUpdated {},
1917                    ),
1918                ),
1919            },
1920        )?;
1921    }
1922
1923    Ok(())
1924}
1925
1926/// Leave someone elses shared project.
1927async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1928    let sender_id = session.connection_id;
1929    let project_id = ProjectId::from_proto(request.project_id);
1930    let db = session.db().await;
1931
1932    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1933    tracing::info!(
1934        %project_id,
1935        "leave project"
1936    );
1937
1938    project_left(project, &session);
1939    if let Some(room) = room {
1940        room_updated(room, &session.peer);
1941    }
1942
1943    Ok(())
1944}
1945
1946/// Updates other participants with changes to the project
1947async fn update_project(
1948    request: proto::UpdateProject,
1949    response: Response<proto::UpdateProject>,
1950    session: Session,
1951) -> Result<()> {
1952    let project_id = ProjectId::from_proto(request.project_id);
1953    let (room, guest_connection_ids) = &*session
1954        .db()
1955        .await
1956        .update_project(project_id, session.connection_id, &request.worktrees)
1957        .await?;
1958    broadcast(
1959        Some(session.connection_id),
1960        guest_connection_ids.iter().copied(),
1961        |connection_id| {
1962            session
1963                .peer
1964                .forward_send(session.connection_id, connection_id, request.clone())
1965        },
1966    );
1967    if let Some(room) = room {
1968        room_updated(room, &session.peer);
1969    }
1970    response.send(proto::Ack {})?;
1971
1972    Ok(())
1973}
1974
1975/// Updates other participants with changes to the worktree
1976async fn update_worktree(
1977    request: proto::UpdateWorktree,
1978    response: Response<proto::UpdateWorktree>,
1979    session: Session,
1980) -> Result<()> {
1981    let guest_connection_ids = session
1982        .db()
1983        .await
1984        .update_worktree(&request, session.connection_id)
1985        .await?;
1986
1987    broadcast(
1988        Some(session.connection_id),
1989        guest_connection_ids.iter().copied(),
1990        |connection_id| {
1991            session
1992                .peer
1993                .forward_send(session.connection_id, connection_id, request.clone())
1994        },
1995    );
1996    response.send(proto::Ack {})?;
1997    Ok(())
1998}
1999
2000/// Updates other participants with changes to the diagnostics
2001async fn update_diagnostic_summary(
2002    message: proto::UpdateDiagnosticSummary,
2003    session: Session,
2004) -> Result<()> {
2005    let guest_connection_ids = session
2006        .db()
2007        .await
2008        .update_diagnostic_summary(&message, session.connection_id)
2009        .await?;
2010
2011    broadcast(
2012        Some(session.connection_id),
2013        guest_connection_ids.iter().copied(),
2014        |connection_id| {
2015            session
2016                .peer
2017                .forward_send(session.connection_id, connection_id, message.clone())
2018        },
2019    );
2020
2021    Ok(())
2022}
2023
2024/// Updates other participants with changes to the worktree settings
2025async fn update_worktree_settings(
2026    message: proto::UpdateWorktreeSettings,
2027    session: Session,
2028) -> Result<()> {
2029    let guest_connection_ids = session
2030        .db()
2031        .await
2032        .update_worktree_settings(&message, session.connection_id)
2033        .await?;
2034
2035    broadcast(
2036        Some(session.connection_id),
2037        guest_connection_ids.iter().copied(),
2038        |connection_id| {
2039            session
2040                .peer
2041                .forward_send(session.connection_id, connection_id, message.clone())
2042        },
2043    );
2044
2045    Ok(())
2046}
2047
2048/// Notify other participants that a  language server has started.
2049async fn start_language_server(
2050    request: proto::StartLanguageServer,
2051    session: Session,
2052) -> Result<()> {
2053    let guest_connection_ids = session
2054        .db()
2055        .await
2056        .start_language_server(&request, session.connection_id)
2057        .await?;
2058
2059    broadcast(
2060        Some(session.connection_id),
2061        guest_connection_ids.iter().copied(),
2062        |connection_id| {
2063            session
2064                .peer
2065                .forward_send(session.connection_id, connection_id, request.clone())
2066        },
2067    );
2068    Ok(())
2069}
2070
2071/// Notify other participants that a language server has changed.
2072async fn update_language_server(
2073    request: proto::UpdateLanguageServer,
2074    session: Session,
2075) -> Result<()> {
2076    let project_id = ProjectId::from_proto(request.project_id);
2077    let project_connection_ids = session
2078        .db()
2079        .await
2080        .project_connection_ids(project_id, session.connection_id, true)
2081        .await?;
2082    broadcast(
2083        Some(session.connection_id),
2084        project_connection_ids.iter().copied(),
2085        |connection_id| {
2086            session
2087                .peer
2088                .forward_send(session.connection_id, connection_id, request.clone())
2089        },
2090    );
2091    Ok(())
2092}
2093
2094/// forward a project request to the host. These requests should be read only
2095/// as guests are allowed to send them.
2096async fn forward_read_only_project_request<T>(
2097    request: T,
2098    response: Response<T>,
2099    session: Session,
2100) -> Result<()>
2101where
2102    T: EntityMessage + RequestMessage,
2103{
2104    let project_id = ProjectId::from_proto(request.remote_entity_id());
2105    let host_connection_id = session
2106        .db()
2107        .await
2108        .host_for_read_only_project_request(project_id, session.connection_id)
2109        .await?;
2110    let payload = session
2111        .peer
2112        .forward_request(session.connection_id, host_connection_id, request)
2113        .await?;
2114    response.send(payload)?;
2115    Ok(())
2116}
2117
2118async fn forward_find_search_candidates_request(
2119    request: proto::FindSearchCandidates,
2120    response: Response<proto::FindSearchCandidates>,
2121    session: Session,
2122) -> Result<()> {
2123    let project_id = ProjectId::from_proto(request.remote_entity_id());
2124    let host_connection_id = session
2125        .db()
2126        .await
2127        .host_for_read_only_project_request(project_id, session.connection_id)
2128        .await?;
2129    let payload = session
2130        .peer
2131        .forward_request(session.connection_id, host_connection_id, request)
2132        .await?;
2133    response.send(payload)?;
2134    Ok(())
2135}
2136
2137/// forward a project request to the host. These requests are disallowed
2138/// for guests.
2139async fn forward_mutating_project_request<T>(
2140    request: T,
2141    response: Response<T>,
2142    session: Session,
2143) -> Result<()>
2144where
2145    T: EntityMessage + RequestMessage,
2146{
2147    let project_id = ProjectId::from_proto(request.remote_entity_id());
2148
2149    let host_connection_id = session
2150        .db()
2151        .await
2152        .host_for_mutating_project_request(project_id, session.connection_id)
2153        .await?;
2154    let payload = session
2155        .peer
2156        .forward_request(session.connection_id, host_connection_id, request)
2157        .await?;
2158    response.send(payload)?;
2159    Ok(())
2160}
2161
2162/// Notify other participants that a new buffer has been created
2163async fn create_buffer_for_peer(
2164    request: proto::CreateBufferForPeer,
2165    session: Session,
2166) -> Result<()> {
2167    session
2168        .db()
2169        .await
2170        .check_user_is_project_host(
2171            ProjectId::from_proto(request.project_id),
2172            session.connection_id,
2173        )
2174        .await?;
2175    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2176    session
2177        .peer
2178        .forward_send(session.connection_id, peer_id.into(), request)?;
2179    Ok(())
2180}
2181
2182/// Notify other participants that a buffer has been updated. This is
2183/// allowed for guests as long as the update is limited to selections.
2184async fn update_buffer(
2185    request: proto::UpdateBuffer,
2186    response: Response<proto::UpdateBuffer>,
2187    session: Session,
2188) -> Result<()> {
2189    let project_id = ProjectId::from_proto(request.project_id);
2190    let mut capability = Capability::ReadOnly;
2191
2192    for op in request.operations.iter() {
2193        match op.variant {
2194            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2195            Some(_) => capability = Capability::ReadWrite,
2196        }
2197    }
2198
2199    let host = {
2200        let guard = session
2201            .db()
2202            .await
2203            .connections_for_buffer_update(project_id, session.connection_id, capability)
2204            .await?;
2205
2206        let (host, guests) = &*guard;
2207
2208        broadcast(
2209            Some(session.connection_id),
2210            guests.clone(),
2211            |connection_id| {
2212                session
2213                    .peer
2214                    .forward_send(session.connection_id, connection_id, request.clone())
2215            },
2216        );
2217
2218        *host
2219    };
2220
2221    if host != session.connection_id {
2222        session
2223            .peer
2224            .forward_request(session.connection_id, host, request.clone())
2225            .await?;
2226    }
2227
2228    response.send(proto::Ack {})?;
2229    Ok(())
2230}
2231
2232async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2233    let project_id = ProjectId::from_proto(message.project_id);
2234
2235    let operation = message.operation.as_ref().context("invalid operation")?;
2236    let capability = match operation.variant.as_ref() {
2237        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2238            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2239                match buffer_op.variant {
2240                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2241                        Capability::ReadOnly
2242                    }
2243                    _ => Capability::ReadWrite,
2244                }
2245            } else {
2246                Capability::ReadWrite
2247            }
2248        }
2249        Some(_) => Capability::ReadWrite,
2250        None => Capability::ReadOnly,
2251    };
2252
2253    let guard = session
2254        .db()
2255        .await
2256        .connections_for_buffer_update(project_id, session.connection_id, capability)
2257        .await?;
2258
2259    let (host, guests) = &*guard;
2260
2261    broadcast(
2262        Some(session.connection_id),
2263        guests.iter().chain([host]).copied(),
2264        |connection_id| {
2265            session
2266                .peer
2267                .forward_send(session.connection_id, connection_id, message.clone())
2268        },
2269    );
2270
2271    Ok(())
2272}
2273
2274/// Notify other participants that a project has been updated.
2275async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2276    request: T,
2277    session: Session,
2278) -> Result<()> {
2279    let project_id = ProjectId::from_proto(request.remote_entity_id());
2280    let project_connection_ids = session
2281        .db()
2282        .await
2283        .project_connection_ids(project_id, session.connection_id, false)
2284        .await?;
2285
2286    broadcast(
2287        Some(session.connection_id),
2288        project_connection_ids.iter().copied(),
2289        |connection_id| {
2290            session
2291                .peer
2292                .forward_send(session.connection_id, connection_id, request.clone())
2293        },
2294    );
2295    Ok(())
2296}
2297
2298/// Start following another user in a call.
2299async fn follow(
2300    request: proto::Follow,
2301    response: Response<proto::Follow>,
2302    session: Session,
2303) -> Result<()> {
2304    let room_id = RoomId::from_proto(request.room_id);
2305    let project_id = request.project_id.map(ProjectId::from_proto);
2306    let leader_id = request
2307        .leader_id
2308        .ok_or_else(|| anyhow!("invalid leader id"))?
2309        .into();
2310    let follower_id = session.connection_id;
2311
2312    session
2313        .db()
2314        .await
2315        .check_room_participants(room_id, leader_id, session.connection_id)
2316        .await?;
2317
2318    let response_payload = session
2319        .peer
2320        .forward_request(session.connection_id, leader_id, request)
2321        .await?;
2322    response.send(response_payload)?;
2323
2324    if let Some(project_id) = project_id {
2325        let room = session
2326            .db()
2327            .await
2328            .follow(room_id, project_id, leader_id, follower_id)
2329            .await?;
2330        room_updated(&room, &session.peer);
2331    }
2332
2333    Ok(())
2334}
2335
2336/// Stop following another user in a call.
2337async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2338    let room_id = RoomId::from_proto(request.room_id);
2339    let project_id = request.project_id.map(ProjectId::from_proto);
2340    let leader_id = request
2341        .leader_id
2342        .ok_or_else(|| anyhow!("invalid leader id"))?
2343        .into();
2344    let follower_id = session.connection_id;
2345
2346    session
2347        .db()
2348        .await
2349        .check_room_participants(room_id, leader_id, session.connection_id)
2350        .await?;
2351
2352    session
2353        .peer
2354        .forward_send(session.connection_id, leader_id, request)?;
2355
2356    if let Some(project_id) = project_id {
2357        let room = session
2358            .db()
2359            .await
2360            .unfollow(room_id, project_id, leader_id, follower_id)
2361            .await?;
2362        room_updated(&room, &session.peer);
2363    }
2364
2365    Ok(())
2366}
2367
2368/// Notify everyone following you of your current location.
2369async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2370    let room_id = RoomId::from_proto(request.room_id);
2371    let database = session.db.lock().await;
2372
2373    let connection_ids = if let Some(project_id) = request.project_id {
2374        let project_id = ProjectId::from_proto(project_id);
2375        database
2376            .project_connection_ids(project_id, session.connection_id, true)
2377            .await?
2378    } else {
2379        database
2380            .room_connection_ids(room_id, session.connection_id)
2381            .await?
2382    };
2383
2384    // For now, don't send view update messages back to that view's current leader.
2385    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2386        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2387        _ => None,
2388    });
2389
2390    for connection_id in connection_ids.iter().cloned() {
2391        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2392            session
2393                .peer
2394                .forward_send(session.connection_id, connection_id, request.clone())?;
2395        }
2396    }
2397    Ok(())
2398}
2399
2400/// Get public data about users.
2401async fn get_users(
2402    request: proto::GetUsers,
2403    response: Response<proto::GetUsers>,
2404    session: Session,
2405) -> Result<()> {
2406    let user_ids = request
2407        .user_ids
2408        .into_iter()
2409        .map(UserId::from_proto)
2410        .collect();
2411    let users = session
2412        .db()
2413        .await
2414        .get_users_by_ids(user_ids)
2415        .await?
2416        .into_iter()
2417        .map(|user| proto::User {
2418            id: user.id.to_proto(),
2419            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2420            github_login: user.github_login,
2421            email: user.email_address,
2422            name: user.name,
2423        })
2424        .collect();
2425    response.send(proto::UsersResponse { users })?;
2426    Ok(())
2427}
2428
2429/// Search for users (to invite) buy Github login
2430async fn fuzzy_search_users(
2431    request: proto::FuzzySearchUsers,
2432    response: Response<proto::FuzzySearchUsers>,
2433    session: Session,
2434) -> Result<()> {
2435    let query = request.query;
2436    let users = match query.len() {
2437        0 => vec![],
2438        1 | 2 => session
2439            .db()
2440            .await
2441            .get_user_by_github_login(&query)
2442            .await?
2443            .into_iter()
2444            .collect(),
2445        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2446    };
2447    let users = users
2448        .into_iter()
2449        .filter(|user| user.id != session.user_id())
2450        .map(|user| proto::User {
2451            id: user.id.to_proto(),
2452            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2453            github_login: user.github_login,
2454            name: user.name,
2455            email: user.email_address,
2456        })
2457        .collect();
2458    response.send(proto::UsersResponse { users })?;
2459    Ok(())
2460}
2461
2462/// Send a contact request to another user.
2463async fn request_contact(
2464    request: proto::RequestContact,
2465    response: Response<proto::RequestContact>,
2466    session: Session,
2467) -> Result<()> {
2468    let requester_id = session.user_id();
2469    let responder_id = UserId::from_proto(request.responder_id);
2470    if requester_id == responder_id {
2471        return Err(anyhow!("cannot add yourself as a contact"))?;
2472    }
2473
2474    let notifications = session
2475        .db()
2476        .await
2477        .send_contact_request(requester_id, responder_id)
2478        .await?;
2479
2480    // Update outgoing contact requests of requester
2481    let mut update = proto::UpdateContacts::default();
2482    update.outgoing_requests.push(responder_id.to_proto());
2483    for connection_id in session
2484        .connection_pool()
2485        .await
2486        .user_connection_ids(requester_id)
2487    {
2488        session.peer.send(connection_id, update.clone())?;
2489    }
2490
2491    // Update incoming contact requests of responder
2492    let mut update = proto::UpdateContacts::default();
2493    update
2494        .incoming_requests
2495        .push(proto::IncomingContactRequest {
2496            requester_id: requester_id.to_proto(),
2497        });
2498    let connection_pool = session.connection_pool().await;
2499    for connection_id in connection_pool.user_connection_ids(responder_id) {
2500        session.peer.send(connection_id, update.clone())?;
2501    }
2502
2503    send_notifications(&connection_pool, &session.peer, notifications);
2504
2505    response.send(proto::Ack {})?;
2506    Ok(())
2507}
2508
2509/// Accept or decline a contact request
2510async fn respond_to_contact_request(
2511    request: proto::RespondToContactRequest,
2512    response: Response<proto::RespondToContactRequest>,
2513    session: Session,
2514) -> Result<()> {
2515    let responder_id = session.user_id();
2516    let requester_id = UserId::from_proto(request.requester_id);
2517    let db = session.db().await;
2518    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2519        db.dismiss_contact_notification(responder_id, requester_id)
2520            .await?;
2521    } else {
2522        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2523
2524        let notifications = db
2525            .respond_to_contact_request(responder_id, requester_id, accept)
2526            .await?;
2527        let requester_busy = db.is_user_busy(requester_id).await?;
2528        let responder_busy = db.is_user_busy(responder_id).await?;
2529
2530        let pool = session.connection_pool().await;
2531        // Update responder with new contact
2532        let mut update = proto::UpdateContacts::default();
2533        if accept {
2534            update
2535                .contacts
2536                .push(contact_for_user(requester_id, requester_busy, &pool));
2537        }
2538        update
2539            .remove_incoming_requests
2540            .push(requester_id.to_proto());
2541        for connection_id in pool.user_connection_ids(responder_id) {
2542            session.peer.send(connection_id, update.clone())?;
2543        }
2544
2545        // Update requester with new contact
2546        let mut update = proto::UpdateContacts::default();
2547        if accept {
2548            update
2549                .contacts
2550                .push(contact_for_user(responder_id, responder_busy, &pool));
2551        }
2552        update
2553            .remove_outgoing_requests
2554            .push(responder_id.to_proto());
2555
2556        for connection_id in pool.user_connection_ids(requester_id) {
2557            session.peer.send(connection_id, update.clone())?;
2558        }
2559
2560        send_notifications(&pool, &session.peer, notifications);
2561    }
2562
2563    response.send(proto::Ack {})?;
2564    Ok(())
2565}
2566
2567/// Remove a contact.
2568async fn remove_contact(
2569    request: proto::RemoveContact,
2570    response: Response<proto::RemoveContact>,
2571    session: Session,
2572) -> Result<()> {
2573    let requester_id = session.user_id();
2574    let responder_id = UserId::from_proto(request.user_id);
2575    let db = session.db().await;
2576    let (contact_accepted, deleted_notification_id) =
2577        db.remove_contact(requester_id, responder_id).await?;
2578
2579    let pool = session.connection_pool().await;
2580    // Update outgoing contact requests of requester
2581    let mut update = proto::UpdateContacts::default();
2582    if contact_accepted {
2583        update.remove_contacts.push(responder_id.to_proto());
2584    } else {
2585        update
2586            .remove_outgoing_requests
2587            .push(responder_id.to_proto());
2588    }
2589    for connection_id in pool.user_connection_ids(requester_id) {
2590        session.peer.send(connection_id, update.clone())?;
2591    }
2592
2593    // Update incoming contact requests of responder
2594    let mut update = proto::UpdateContacts::default();
2595    if contact_accepted {
2596        update.remove_contacts.push(requester_id.to_proto());
2597    } else {
2598        update
2599            .remove_incoming_requests
2600            .push(requester_id.to_proto());
2601    }
2602    for connection_id in pool.user_connection_ids(responder_id) {
2603        session.peer.send(connection_id, update.clone())?;
2604        if let Some(notification_id) = deleted_notification_id {
2605            session.peer.send(
2606                connection_id,
2607                proto::DeleteNotification {
2608                    notification_id: notification_id.to_proto(),
2609                },
2610            )?;
2611        }
2612    }
2613
2614    response.send(proto::Ack {})?;
2615    Ok(())
2616}
2617
2618fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2619    version.0.minor() < 139
2620}
2621
2622async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2623    let plan = session.current_plan(&session.db().await).await?;
2624
2625    session
2626        .peer
2627        .send(
2628            session.connection_id,
2629            proto::UpdateUserPlan { plan: plan.into() },
2630        )
2631        .trace_err();
2632
2633    Ok(())
2634}
2635
2636async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2637    subscribe_user_to_channels(session.user_id(), &session).await?;
2638    Ok(())
2639}
2640
2641async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2642    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2643    let mut pool = session.connection_pool().await;
2644    for membership in &channels_for_user.channel_memberships {
2645        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2646    }
2647    session.peer.send(
2648        session.connection_id,
2649        build_update_user_channels(&channels_for_user),
2650    )?;
2651    session.peer.send(
2652        session.connection_id,
2653        build_channels_update(channels_for_user),
2654    )?;
2655    Ok(())
2656}
2657
2658/// Creates a new channel.
2659async fn create_channel(
2660    request: proto::CreateChannel,
2661    response: Response<proto::CreateChannel>,
2662    session: Session,
2663) -> Result<()> {
2664    let db = session.db().await;
2665
2666    let parent_id = request.parent_id.map(ChannelId::from_proto);
2667    let (channel, membership) = db
2668        .create_channel(&request.name, parent_id, session.user_id())
2669        .await?;
2670
2671    let root_id = channel.root_id();
2672    let channel = Channel::from_model(channel);
2673
2674    response.send(proto::CreateChannelResponse {
2675        channel: Some(channel.to_proto()),
2676        parent_id: request.parent_id,
2677    })?;
2678
2679    let mut connection_pool = session.connection_pool().await;
2680    if let Some(membership) = membership {
2681        connection_pool.subscribe_to_channel(
2682            membership.user_id,
2683            membership.channel_id,
2684            membership.role,
2685        );
2686        let update = proto::UpdateUserChannels {
2687            channel_memberships: vec![proto::ChannelMembership {
2688                channel_id: membership.channel_id.to_proto(),
2689                role: membership.role.into(),
2690            }],
2691            ..Default::default()
2692        };
2693        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2694            session.peer.send(connection_id, update.clone())?;
2695        }
2696    }
2697
2698    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2699        if !role.can_see_channel(channel.visibility) {
2700            continue;
2701        }
2702
2703        let update = proto::UpdateChannels {
2704            channels: vec![channel.to_proto()],
2705            ..Default::default()
2706        };
2707        session.peer.send(connection_id, update.clone())?;
2708    }
2709
2710    Ok(())
2711}
2712
2713/// Delete a channel
2714async fn delete_channel(
2715    request: proto::DeleteChannel,
2716    response: Response<proto::DeleteChannel>,
2717    session: Session,
2718) -> Result<()> {
2719    let db = session.db().await;
2720
2721    let channel_id = request.channel_id;
2722    let (root_channel, removed_channels) = db
2723        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2724        .await?;
2725    response.send(proto::Ack {})?;
2726
2727    // Notify members of removed channels
2728    let mut update = proto::UpdateChannels::default();
2729    update
2730        .delete_channels
2731        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2732
2733    let connection_pool = session.connection_pool().await;
2734    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2735        session.peer.send(connection_id, update.clone())?;
2736    }
2737
2738    Ok(())
2739}
2740
2741/// Invite someone to join a channel.
2742async fn invite_channel_member(
2743    request: proto::InviteChannelMember,
2744    response: Response<proto::InviteChannelMember>,
2745    session: Session,
2746) -> Result<()> {
2747    let db = session.db().await;
2748    let channel_id = ChannelId::from_proto(request.channel_id);
2749    let invitee_id = UserId::from_proto(request.user_id);
2750    let InviteMemberResult {
2751        channel,
2752        notifications,
2753    } = db
2754        .invite_channel_member(
2755            channel_id,
2756            invitee_id,
2757            session.user_id(),
2758            request.role().into(),
2759        )
2760        .await?;
2761
2762    let update = proto::UpdateChannels {
2763        channel_invitations: vec![channel.to_proto()],
2764        ..Default::default()
2765    };
2766
2767    let connection_pool = session.connection_pool().await;
2768    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2769        session.peer.send(connection_id, update.clone())?;
2770    }
2771
2772    send_notifications(&connection_pool, &session.peer, notifications);
2773
2774    response.send(proto::Ack {})?;
2775    Ok(())
2776}
2777
2778/// remove someone from a channel
2779async fn remove_channel_member(
2780    request: proto::RemoveChannelMember,
2781    response: Response<proto::RemoveChannelMember>,
2782    session: Session,
2783) -> Result<()> {
2784    let db = session.db().await;
2785    let channel_id = ChannelId::from_proto(request.channel_id);
2786    let member_id = UserId::from_proto(request.user_id);
2787
2788    let RemoveChannelMemberResult {
2789        membership_update,
2790        notification_id,
2791    } = db
2792        .remove_channel_member(channel_id, member_id, session.user_id())
2793        .await?;
2794
2795    let mut connection_pool = session.connection_pool().await;
2796    notify_membership_updated(
2797        &mut connection_pool,
2798        membership_update,
2799        member_id,
2800        &session.peer,
2801    );
2802    for connection_id in connection_pool.user_connection_ids(member_id) {
2803        if let Some(notification_id) = notification_id {
2804            session
2805                .peer
2806                .send(
2807                    connection_id,
2808                    proto::DeleteNotification {
2809                        notification_id: notification_id.to_proto(),
2810                    },
2811                )
2812                .trace_err();
2813        }
2814    }
2815
2816    response.send(proto::Ack {})?;
2817    Ok(())
2818}
2819
2820/// Toggle the channel between public and private.
2821/// Care is taken to maintain the invariant that public channels only descend from public channels,
2822/// (though members-only channels can appear at any point in the hierarchy).
2823async fn set_channel_visibility(
2824    request: proto::SetChannelVisibility,
2825    response: Response<proto::SetChannelVisibility>,
2826    session: Session,
2827) -> Result<()> {
2828    let db = session.db().await;
2829    let channel_id = ChannelId::from_proto(request.channel_id);
2830    let visibility = request.visibility().into();
2831
2832    let channel_model = db
2833        .set_channel_visibility(channel_id, visibility, session.user_id())
2834        .await?;
2835    let root_id = channel_model.root_id();
2836    let channel = Channel::from_model(channel_model);
2837
2838    let mut connection_pool = session.connection_pool().await;
2839    for (user_id, role) in connection_pool
2840        .channel_user_ids(root_id)
2841        .collect::<Vec<_>>()
2842        .into_iter()
2843    {
2844        let update = if role.can_see_channel(channel.visibility) {
2845            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2846            proto::UpdateChannels {
2847                channels: vec![channel.to_proto()],
2848                ..Default::default()
2849            }
2850        } else {
2851            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2852            proto::UpdateChannels {
2853                delete_channels: vec![channel.id.to_proto()],
2854                ..Default::default()
2855            }
2856        };
2857
2858        for connection_id in connection_pool.user_connection_ids(user_id) {
2859            session.peer.send(connection_id, update.clone())?;
2860        }
2861    }
2862
2863    response.send(proto::Ack {})?;
2864    Ok(())
2865}
2866
2867/// Alter the role for a user in the channel.
2868async fn set_channel_member_role(
2869    request: proto::SetChannelMemberRole,
2870    response: Response<proto::SetChannelMemberRole>,
2871    session: Session,
2872) -> Result<()> {
2873    let db = session.db().await;
2874    let channel_id = ChannelId::from_proto(request.channel_id);
2875    let member_id = UserId::from_proto(request.user_id);
2876    let result = db
2877        .set_channel_member_role(
2878            channel_id,
2879            session.user_id(),
2880            member_id,
2881            request.role().into(),
2882        )
2883        .await?;
2884
2885    match result {
2886        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2887            let mut connection_pool = session.connection_pool().await;
2888            notify_membership_updated(
2889                &mut connection_pool,
2890                membership_update,
2891                member_id,
2892                &session.peer,
2893            )
2894        }
2895        db::SetMemberRoleResult::InviteUpdated(channel) => {
2896            let update = proto::UpdateChannels {
2897                channel_invitations: vec![channel.to_proto()],
2898                ..Default::default()
2899            };
2900
2901            for connection_id in session
2902                .connection_pool()
2903                .await
2904                .user_connection_ids(member_id)
2905            {
2906                session.peer.send(connection_id, update.clone())?;
2907            }
2908        }
2909    }
2910
2911    response.send(proto::Ack {})?;
2912    Ok(())
2913}
2914
2915/// Change the name of a channel
2916async fn rename_channel(
2917    request: proto::RenameChannel,
2918    response: Response<proto::RenameChannel>,
2919    session: Session,
2920) -> Result<()> {
2921    let db = session.db().await;
2922    let channel_id = ChannelId::from_proto(request.channel_id);
2923    let channel_model = db
2924        .rename_channel(channel_id, session.user_id(), &request.name)
2925        .await?;
2926    let root_id = channel_model.root_id();
2927    let channel = Channel::from_model(channel_model);
2928
2929    response.send(proto::RenameChannelResponse {
2930        channel: Some(channel.to_proto()),
2931    })?;
2932
2933    let connection_pool = session.connection_pool().await;
2934    let update = proto::UpdateChannels {
2935        channels: vec![channel.to_proto()],
2936        ..Default::default()
2937    };
2938    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2939        if role.can_see_channel(channel.visibility) {
2940            session.peer.send(connection_id, update.clone())?;
2941        }
2942    }
2943
2944    Ok(())
2945}
2946
2947/// Move a channel to a new parent.
2948async fn move_channel(
2949    request: proto::MoveChannel,
2950    response: Response<proto::MoveChannel>,
2951    session: Session,
2952) -> Result<()> {
2953    let channel_id = ChannelId::from_proto(request.channel_id);
2954    let to = ChannelId::from_proto(request.to);
2955
2956    let (root_id, channels) = session
2957        .db()
2958        .await
2959        .move_channel(channel_id, to, session.user_id())
2960        .await?;
2961
2962    let connection_pool = session.connection_pool().await;
2963    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2964        let channels = channels
2965            .iter()
2966            .filter_map(|channel| {
2967                if role.can_see_channel(channel.visibility) {
2968                    Some(channel.to_proto())
2969                } else {
2970                    None
2971                }
2972            })
2973            .collect::<Vec<_>>();
2974        if channels.is_empty() {
2975            continue;
2976        }
2977
2978        let update = proto::UpdateChannels {
2979            channels,
2980            ..Default::default()
2981        };
2982
2983        session.peer.send(connection_id, update.clone())?;
2984    }
2985
2986    response.send(Ack {})?;
2987    Ok(())
2988}
2989
2990/// Get the list of channel members
2991async fn get_channel_members(
2992    request: proto::GetChannelMembers,
2993    response: Response<proto::GetChannelMembers>,
2994    session: Session,
2995) -> Result<()> {
2996    let db = session.db().await;
2997    let channel_id = ChannelId::from_proto(request.channel_id);
2998    let limit = if request.limit == 0 {
2999        u16::MAX as u64
3000    } else {
3001        request.limit
3002    };
3003    let (members, users) = db
3004        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3005        .await?;
3006    response.send(proto::GetChannelMembersResponse { members, users })?;
3007    Ok(())
3008}
3009
3010/// Accept or decline a channel invitation.
3011async fn respond_to_channel_invite(
3012    request: proto::RespondToChannelInvite,
3013    response: Response<proto::RespondToChannelInvite>,
3014    session: Session,
3015) -> Result<()> {
3016    let db = session.db().await;
3017    let channel_id = ChannelId::from_proto(request.channel_id);
3018    let RespondToChannelInvite {
3019        membership_update,
3020        notifications,
3021    } = db
3022        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3023        .await?;
3024
3025    let mut connection_pool = session.connection_pool().await;
3026    if let Some(membership_update) = membership_update {
3027        notify_membership_updated(
3028            &mut connection_pool,
3029            membership_update,
3030            session.user_id(),
3031            &session.peer,
3032        );
3033    } else {
3034        let update = proto::UpdateChannels {
3035            remove_channel_invitations: vec![channel_id.to_proto()],
3036            ..Default::default()
3037        };
3038
3039        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3040            session.peer.send(connection_id, update.clone())?;
3041        }
3042    };
3043
3044    send_notifications(&connection_pool, &session.peer, notifications);
3045
3046    response.send(proto::Ack {})?;
3047
3048    Ok(())
3049}
3050
3051/// Join the channels' room
3052async fn join_channel(
3053    request: proto::JoinChannel,
3054    response: Response<proto::JoinChannel>,
3055    session: Session,
3056) -> Result<()> {
3057    let channel_id = ChannelId::from_proto(request.channel_id);
3058    join_channel_internal(channel_id, Box::new(response), session).await
3059}
3060
/// Response handle accepted by `join_channel_internal`. It lets the same join
/// logic reply with a `proto::JoinRoomResponse` through either a
/// `Response<proto::JoinChannel>` or a `Response<proto::JoinRoom>`.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls delegate through a type-qualified path (`Response::<..>::send`).
// NOTE(review): this presumably resolves to `Response`'s own inherent `send`
// rather than recursing into this trait method. Keep the qualified form.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3074
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The database handle is dropped while doing so and
///    reacquired afterwards.
/// 2. Join the channel's room in the database, which also yields the user's
///    role and any membership change.
/// 3. If a LiveKit client is configured, create connection info. Guests get a
///    `guest_token` with `can_publish: false`; every other role gets a
///    `room_token` with `can_publish: true`. Token errors go through
///    `trace_err()` and result in no connection info rather than a failed
///    join.
/// 4. Reply with the room, channel id and optional LiveKit info, propagate
///    any membership change, and broadcast the updated room.
/// 5. Once the scoped database handle and connection-pool guard are
///    released, broadcast the channel update and refresh the user's
///    contacts.
///
/// # Errors
/// Returns an error if a database call, the stale-room cleanup, sending the
/// response, or the contacts update fails. It also errors if the joined room
/// carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so the database handle and connection-pool guard taken inside
    // are dropped before the channel/contacts broadcasts below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room.
            // NOTE(review): presumably because `leave_room_for_session` takes
            // the database itself; confirm before reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may not publish; all other roles may.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to report its channel; treat absence as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3170
3171/// Start editing the channel notes
3172async fn join_channel_buffer(
3173    request: proto::JoinChannelBuffer,
3174    response: Response<proto::JoinChannelBuffer>,
3175    session: Session,
3176) -> Result<()> {
3177    let db = session.db().await;
3178    let channel_id = ChannelId::from_proto(request.channel_id);
3179
3180    let open_response = db
3181        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3182        .await?;
3183
3184    let collaborators = open_response.collaborators.clone();
3185    response.send(open_response)?;
3186
3187    let update = UpdateChannelBufferCollaborators {
3188        channel_id: channel_id.to_proto(),
3189        collaborators: collaborators.clone(),
3190    };
3191    channel_buffer_updated(
3192        session.connection_id,
3193        collaborators
3194            .iter()
3195            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3196        &update,
3197        &session.peer,
3198    );
3199
3200    Ok(())
3201}
3202
3203/// Edit the channel notes
3204async fn update_channel_buffer(
3205    request: proto::UpdateChannelBuffer,
3206    session: Session,
3207) -> Result<()> {
3208    let db = session.db().await;
3209    let channel_id = ChannelId::from_proto(request.channel_id);
3210
3211    let (collaborators, epoch, version) = db
3212        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3213        .await?;
3214
3215    channel_buffer_updated(
3216        session.connection_id,
3217        collaborators.clone(),
3218        &proto::UpdateChannelBuffer {
3219            channel_id: channel_id.to_proto(),
3220            operations: request.operations,
3221        },
3222        &session.peer,
3223    );
3224
3225    let pool = &*session.connection_pool().await;
3226
3227    let non_collaborators =
3228        pool.channel_connection_ids(channel_id)
3229            .filter_map(|(connection_id, _)| {
3230                if collaborators.contains(&connection_id) {
3231                    None
3232                } else {
3233                    Some(connection_id)
3234                }
3235            });
3236
3237    broadcast(None, non_collaborators, |peer_id| {
3238        session.peer.send(
3239            peer_id,
3240            proto::UpdateChannels {
3241                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3242                    channel_id: channel_id.to_proto(),
3243                    epoch: epoch as u64,
3244                    version: version.clone(),
3245                }],
3246                ..Default::default()
3247            },
3248        )
3249    });
3250
3251    Ok(())
3252}
3253
3254/// Rejoin the channel notes after a connection blip
3255async fn rejoin_channel_buffers(
3256    request: proto::RejoinChannelBuffers,
3257    response: Response<proto::RejoinChannelBuffers>,
3258    session: Session,
3259) -> Result<()> {
3260    let db = session.db().await;
3261    let buffers = db
3262        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3263        .await?;
3264
3265    for rejoined_buffer in &buffers {
3266        let collaborators_to_notify = rejoined_buffer
3267            .buffer
3268            .collaborators
3269            .iter()
3270            .filter_map(|c| Some(c.peer_id?.into()));
3271        channel_buffer_updated(
3272            session.connection_id,
3273            collaborators_to_notify,
3274            &proto::UpdateChannelBufferCollaborators {
3275                channel_id: rejoined_buffer.buffer.channel_id,
3276                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3277            },
3278            &session.peer,
3279        );
3280    }
3281
3282    response.send(proto::RejoinChannelBuffersResponse {
3283        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3284    })?;
3285
3286    Ok(())
3287}
3288
3289/// Stop editing the channel notes
3290async fn leave_channel_buffer(
3291    request: proto::LeaveChannelBuffer,
3292    response: Response<proto::LeaveChannelBuffer>,
3293    session: Session,
3294) -> Result<()> {
3295    let db = session.db().await;
3296    let channel_id = ChannelId::from_proto(request.channel_id);
3297
3298    let left_buffer = db
3299        .leave_channel_buffer(channel_id, session.connection_id)
3300        .await?;
3301
3302    response.send(Ack {})?;
3303
3304    channel_buffer_updated(
3305        session.connection_id,
3306        left_buffer.connections,
3307        &proto::UpdateChannelBufferCollaborators {
3308            channel_id: channel_id.to_proto(),
3309            collaborators: left_buffer.collaborators,
3310        },
3311        &session.peer,
3312    );
3313
3314    Ok(())
3315}
3316
3317fn channel_buffer_updated<T: EnvelopedMessage>(
3318    sender_id: ConnectionId,
3319    collaborators: impl IntoIterator<Item = ConnectionId>,
3320    message: &T,
3321    peer: &Peer,
3322) {
3323    broadcast(Some(sender_id), collaborators, |peer_id| {
3324        peer.send(peer_id, message.clone())
3325    });
3326}
3327
3328fn send_notifications(
3329    connection_pool: &ConnectionPool,
3330    peer: &Peer,
3331    notifications: db::NotificationBatch,
3332) {
3333    for (user_id, notification) in notifications {
3334        for connection_id in connection_pool.user_connection_ids(user_id) {
3335            if let Err(error) = peer.send(
3336                connection_id,
3337                proto::AddNotification {
3338                    notification: Some(notification.clone()),
3339                },
3340            ) {
3341                tracing::error!(
3342                    "failed to send notification to {:?} {}",
3343                    connection_id,
3344                    error
3345                );
3346            }
3347        }
3348    }
3349}
3350
/// Send a message to the channel.
///
/// Validates the trimmed body and the nonce, persists the message, and then:
/// 1. broadcasts `ChannelMessageSent` to the chat participant connections
///    returned by the database, excluding the caller's own connection (the
///    caller gets the message back in the response instead);
/// 2. replies to the caller with the stored message;
/// 3. sends an `UpdateChannels` carrying the new latest message id to every
///    connection subscribed to the channel that is not a chat participant;
/// 4. delivers the notifications returned by the database (presumably mention
///    notifications — confirm in `create_channel_message`).
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // Length is checked in bytes, on the trimmed body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and forwarded unchanged below;
    // if mention ranges are offsets into the untrimmed body, trimming leading
    // whitespace shifts them out of alignment.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Fan the new message out to other chat participants (sender excluded).
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections that can see the channel but are not in the chat only learn
    // the new latest message id, not the message itself.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3445
/// Delete a channel message.
///
/// After the database removes the message, every connection it returns
/// (except the caller's own) is sent the original `RemoveChannelMessage`
/// request. Each of those connections then receives a `DeleteNotification` for
/// every notification id the database reported. Presumably these are
/// notifications that referenced the deleted message; confirm in
/// `db::remove_channel_message`. The caller is acked last.
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let (connection_ids, existing_notification_ids) = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id())
        .await?;

    broadcast(
        Some(session.connection_id),
        connection_ids,
        move |connection| {
            // The first failed send (`?`) stops the remaining sends to this
            // connection; how `broadcast` reports that error is defined elsewhere.
            session.peer.send(connection, request.clone())?;

            for notification_id in &existing_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            Ok(())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
3481
3482async fn update_channel_message(
3483    request: proto::UpdateChannelMessage,
3484    response: Response<proto::UpdateChannelMessage>,
3485    session: Session,
3486) -> Result<()> {
3487    let channel_id = ChannelId::from_proto(request.channel_id);
3488    let message_id = MessageId::from_proto(request.message_id);
3489    let updated_at = OffsetDateTime::now_utc();
3490    let UpdatedChannelMessage {
3491        message_id,
3492        participant_connection_ids,
3493        notifications,
3494        reply_to_message_id,
3495        timestamp,
3496        deleted_mention_notification_ids,
3497        updated_mention_notifications,
3498    } = session
3499        .db()
3500        .await
3501        .update_channel_message(
3502            channel_id,
3503            message_id,
3504            session.user_id(),
3505            request.body.as_str(),
3506            &request.mentions,
3507            updated_at,
3508        )
3509        .await?;
3510
3511    let nonce = request
3512        .nonce
3513        .clone()
3514        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3515
3516    let message = proto::ChannelMessage {
3517        sender_id: session.user_id().to_proto(),
3518        id: message_id.to_proto(),
3519        body: request.body.clone(),
3520        mentions: request.mentions.clone(),
3521        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3522        nonce: Some(nonce),
3523        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3524        edited_at: Some(updated_at.unix_timestamp() as u64),
3525    };
3526
3527    response.send(proto::Ack {})?;
3528
3529    let pool = &*session.connection_pool().await;
3530    broadcast(
3531        Some(session.connection_id),
3532        participant_connection_ids,
3533        |connection| {
3534            session.peer.send(
3535                connection,
3536                proto::ChannelMessageUpdate {
3537                    channel_id: channel_id.to_proto(),
3538                    message: Some(message.clone()),
3539                },
3540            )?;
3541
3542            for notification_id in &deleted_mention_notification_ids {
3543                session.peer.send(
3544                    connection,
3545                    proto::DeleteNotification {
3546                        notification_id: (*notification_id).to_proto(),
3547                    },
3548                )?;
3549            }
3550
3551            for notification in &updated_mention_notifications {
3552                session.peer.send(
3553                    connection,
3554                    proto::UpdateNotification {
3555                        notification: Some(notification.clone()),
3556                    },
3557                )?;
3558            }
3559
3560            Ok(())
3561        },
3562    );
3563
3564    send_notifications(pool, &session.peer, notifications);
3565
3566    Ok(())
3567}
3568
3569/// Mark a channel message as read
3570async fn acknowledge_channel_message(
3571    request: proto::AckChannelMessage,
3572    session: Session,
3573) -> Result<()> {
3574    let channel_id = ChannelId::from_proto(request.channel_id);
3575    let message_id = MessageId::from_proto(request.message_id);
3576    let notifications = session
3577        .db()
3578        .await
3579        .observe_channel_message(channel_id, session.user_id(), message_id)
3580        .await?;
3581    send_notifications(
3582        &*session.connection_pool().await,
3583        &session.peer,
3584        notifications,
3585    );
3586    Ok(())
3587}
3588
3589/// Mark a buffer version as synced
3590async fn acknowledge_buffer_version(
3591    request: proto::AckBufferOperation,
3592    session: Session,
3593) -> Result<()> {
3594    let buffer_id = BufferId::from_proto(request.buffer_id);
3595    session
3596        .db()
3597        .await
3598        .observe_buffer_version(
3599            buffer_id,
3600            session.user_id(),
3601            request.epoch as i32,
3602            &request.version,
3603        )
3604        .await?;
3605    Ok(())
3606}
3607
3608async fn count_language_model_tokens(
3609    request: proto::CountLanguageModelTokens,
3610    response: Response<proto::CountLanguageModelTokens>,
3611    session: Session,
3612    config: &Config,
3613) -> Result<()> {
3614    authorize_access_to_legacy_llm_endpoints(&session).await?;
3615
3616    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3617        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3618        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3619    };
3620
3621    session
3622        .app_state
3623        .rate_limiter
3624        .check(&*rate_limit, session.user_id())
3625        .await?;
3626
3627    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3628        Some(proto::LanguageModelProvider::Google) => {
3629            let api_key = config
3630                .google_ai_api_key
3631                .as_ref()
3632                .context("no Google AI API key configured on the server")?;
3633            google_ai::count_tokens(
3634                session.http_client.as_ref(),
3635                google_ai::API_URL,
3636                api_key,
3637                serde_json::from_str(&request.request)?,
3638            )
3639            .await?
3640        }
3641        _ => return Err(anyhow!("unsupported provider"))?,
3642    };
3643
3644    response.send(proto::CountLanguageModelTokensResponse {
3645        token_count: result.total_tokens as u32,
3646    })?;
3647
3648    Ok(())
3649}
3650
3651struct ZedProCountLanguageModelTokensRateLimit;
3652
3653impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3654    fn capacity(&self) -> usize {
3655        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3656            .ok()
3657            .and_then(|v| v.parse().ok())
3658            .unwrap_or(600) // Picked arbitrarily
3659    }
3660
3661    fn refill_duration(&self) -> chrono::Duration {
3662        chrono::Duration::hours(1)
3663    }
3664
3665    fn db_name(&self) -> &'static str {
3666        "zed-pro:count-language-model-tokens"
3667    }
3668}
3669
3670struct FreeCountLanguageModelTokensRateLimit;
3671
3672impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3673    fn capacity(&self) -> usize {
3674        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3675            .ok()
3676            .and_then(|v| v.parse().ok())
3677            .unwrap_or(600 / 10) // Picked arbitrarily
3678    }
3679
3680    fn refill_duration(&self) -> chrono::Duration {
3681        chrono::Duration::hours(1)
3682    }
3683
3684    fn db_name(&self) -> &'static str {
3685        "free:count-language-model-tokens"
3686    }
3687}
3688
3689struct ZedProComputeEmbeddingsRateLimit;
3690
3691impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3692    fn capacity(&self) -> usize {
3693        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3694            .ok()
3695            .and_then(|v| v.parse().ok())
3696            .unwrap_or(5000) // Picked arbitrarily
3697    }
3698
3699    fn refill_duration(&self) -> chrono::Duration {
3700        chrono::Duration::hours(1)
3701    }
3702
3703    fn db_name(&self) -> &'static str {
3704        "zed-pro:compute-embeddings"
3705    }
3706}
3707
3708struct FreeComputeEmbeddingsRateLimit;
3709
3710impl RateLimit for FreeComputeEmbeddingsRateLimit {
3711    fn capacity(&self) -> usize {
3712        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3713            .ok()
3714            .and_then(|v| v.parse().ok())
3715            .unwrap_or(5000 / 10) // Picked arbitrarily
3716    }
3717
3718    fn refill_duration(&self) -> chrono::Duration {
3719        chrono::Duration::hours(1)
3720    }
3721
3722    fn db_name(&self) -> &'static str {
3723        "free:compute-embeddings"
3724    }
3725}
3726
3727async fn compute_embeddings(
3728    request: proto::ComputeEmbeddings,
3729    response: Response<proto::ComputeEmbeddings>,
3730    session: Session,
3731    api_key: Option<Arc<str>>,
3732) -> Result<()> {
3733    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3734    authorize_access_to_legacy_llm_endpoints(&session).await?;
3735
3736    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3737        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3738        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3739    };
3740
3741    session
3742        .app_state
3743        .rate_limiter
3744        .check(&*rate_limit, session.user_id())
3745        .await?;
3746
3747    let embeddings = match request.model.as_str() {
3748        "openai/text-embedding-3-small" => {
3749            open_ai::embed(
3750                session.http_client.as_ref(),
3751                OPEN_AI_API_URL,
3752                &api_key,
3753                OpenAiEmbeddingModel::TextEmbedding3Small,
3754                request.texts.iter().map(|text| text.as_str()),
3755            )
3756            .await?
3757        }
3758        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3759    };
3760
3761    let embeddings = request
3762        .texts
3763        .iter()
3764        .map(|text| {
3765            let mut hasher = sha2::Sha256::new();
3766            hasher.update(text.as_bytes());
3767            let result = hasher.finalize();
3768            result.to_vec()
3769        })
3770        .zip(
3771            embeddings
3772                .data
3773                .into_iter()
3774                .map(|embedding| embedding.embedding),
3775        )
3776        .collect::<HashMap<_, _>>();
3777
3778    let db = session.db().await;
3779    db.save_embeddings(&request.model, &embeddings)
3780        .await
3781        .context("failed to save embeddings")
3782        .trace_err();
3783
3784    response.send(proto::ComputeEmbeddingsResponse {
3785        embeddings: embeddings
3786            .into_iter()
3787            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3788            .collect(),
3789    })?;
3790    Ok(())
3791}
3792
3793async fn get_cached_embeddings(
3794    request: proto::GetCachedEmbeddings,
3795    response: Response<proto::GetCachedEmbeddings>,
3796    session: Session,
3797) -> Result<()> {
3798    authorize_access_to_legacy_llm_endpoints(&session).await?;
3799
3800    let db = session.db().await;
3801    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3802
3803    response.send(proto::GetCachedEmbeddingsResponse {
3804        embeddings: embeddings
3805            .into_iter()
3806            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3807            .collect(),
3808    })?;
3809    Ok(())
3810}
3811
3812/// This is leftover from before the LLM service.
3813///
3814/// The endpoints protected by this check will be moved there eventually.
3815async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3816    if session.is_staff() {
3817        Ok(())
3818    } else {
3819        Err(anyhow!("permission denied"))?
3820    }
3821}
3822
3823/// Get a Supermaven API key for the user
3824async fn get_supermaven_api_key(
3825    _request: proto::GetSupermavenApiKey,
3826    response: Response<proto::GetSupermavenApiKey>,
3827    session: Session,
3828) -> Result<()> {
3829    let user_id: String = session.user_id().to_string();
3830    if !session.is_staff() {
3831        return Err(anyhow!("supermaven not enabled for this account"))?;
3832    }
3833
3834    let email = session
3835        .email()
3836        .ok_or_else(|| anyhow!("user must have an email"))?;
3837
3838    let supermaven_admin_api = session
3839        .supermaven_client
3840        .as_ref()
3841        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3842
3843    let result = supermaven_admin_api
3844        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3845        .await?;
3846
3847    response.send(proto::GetSupermavenApiKeyResponse {
3848        api_key: result.api_key,
3849    })?;
3850
3851    Ok(())
3852}
3853
/// Start receiving chat updates for a channel.
///
/// Registers this connection as a chat participant for the channel, then
/// responds with one page of messages fetched with no `before` cursor
/// (presumably the most recent page). `done` is set when fewer than
/// `MESSAGE_COUNT_PER_PAGE` messages came back.
async fn join_channel_chat(
    request: proto::JoinChannelChat,
    response: Response<proto::JoinChannelChat>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Both calls run under the same DB guard: join first, then fetch history.
    let db = session.db().await;
    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
        .await?;
    let messages = db
        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
        .await?;
    response.send(proto::JoinChannelChatResponse {
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}
3874
3875/// Stop receiving chat updates for a channel
3876async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3877    let channel_id = ChannelId::from_proto(request.channel_id);
3878    session
3879        .db()
3880        .await
3881        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3882        .await?;
3883    Ok(())
3884}
3885
3886/// Retrieve the chat history for a channel
3887async fn get_channel_messages(
3888    request: proto::GetChannelMessages,
3889    response: Response<proto::GetChannelMessages>,
3890    session: Session,
3891) -> Result<()> {
3892    let channel_id = ChannelId::from_proto(request.channel_id);
3893    let messages = session
3894        .db()
3895        .await
3896        .get_channel_messages(
3897            channel_id,
3898            session.user_id(),
3899            MESSAGE_COUNT_PER_PAGE,
3900            Some(MessageId::from_proto(request.before_message_id)),
3901        )
3902        .await?;
3903    response.send(proto::GetChannelMessagesResponse {
3904        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3905        messages,
3906    })?;
3907    Ok(())
3908}
3909
3910/// Retrieve specific chat messages
3911async fn get_channel_messages_by_id(
3912    request: proto::GetChannelMessagesById,
3913    response: Response<proto::GetChannelMessagesById>,
3914    session: Session,
3915) -> Result<()> {
3916    let message_ids = request
3917        .message_ids
3918        .iter()
3919        .map(|id| MessageId::from_proto(*id))
3920        .collect::<Vec<_>>();
3921    let messages = session
3922        .db()
3923        .await
3924        .get_channel_messages_by_id(session.user_id(), &message_ids)
3925        .await?;
3926    response.send(proto::GetChannelMessagesResponse {
3927        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3928        messages,
3929    })?;
3930    Ok(())
3931}
3932
3933/// Retrieve the current users notifications
3934async fn get_notifications(
3935    request: proto::GetNotifications,
3936    response: Response<proto::GetNotifications>,
3937    session: Session,
3938) -> Result<()> {
3939    let notifications = session
3940        .db()
3941        .await
3942        .get_notifications(
3943            session.user_id(),
3944            NOTIFICATION_COUNT_PER_PAGE,
3945            request.before_id.map(db::NotificationId::from_proto),
3946        )
3947        .await?;
3948    response.send(proto::GetNotificationsResponse {
3949        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3950        notifications,
3951    })?;
3952    Ok(())
3953}
3954
3955/// Mark notifications as read
3956async fn mark_notification_as_read(
3957    request: proto::MarkNotificationRead,
3958    response: Response<proto::MarkNotificationRead>,
3959    session: Session,
3960) -> Result<()> {
3961    let database = &session.db().await;
3962    let notifications = database
3963        .mark_notification_as_read_by_id(
3964            session.user_id(),
3965            NotificationId::from_proto(request.notification_id),
3966        )
3967        .await?;
3968    send_notifications(
3969        &*session.connection_pool().await,
3970        &session.peer,
3971        notifications,
3972    );
3973    response.send(proto::Ack {})?;
3974    Ok(())
3975}
3976
3977/// Get the current users information
3978async fn get_private_user_info(
3979    _request: proto::GetPrivateUserInfo,
3980    response: Response<proto::GetPrivateUserInfo>,
3981    session: Session,
3982) -> Result<()> {
3983    let db = session.db().await;
3984
3985    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3986    let user = db
3987        .get_user_by_id(session.user_id())
3988        .await?
3989        .ok_or_else(|| anyhow!("user not found"))?;
3990    let flags = db.get_user_flags(session.user_id()).await?;
3991
3992    response.send(proto::GetPrivateUserInfoResponse {
3993        metrics_id,
3994        staff: user.admin,
3995        flags,
3996        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3997    })?;
3998    Ok(())
3999}
4000
4001/// Accept the terms of service (tos) on behalf of the current user
4002async fn accept_terms_of_service(
4003    _request: proto::AcceptTermsOfService,
4004    response: Response<proto::AcceptTermsOfService>,
4005    session: Session,
4006) -> Result<()> {
4007    let db = session.db().await;
4008
4009    let accepted_tos_at = Utc::now();
4010    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4011        .await?;
4012
4013    response.send(proto::AcceptTermsOfServiceResponse {
4014        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4015    })?;
4016    Ok(())
4017}
4018
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` checks it against the earlier of the Zed account's and
/// the GitHub account's creation dates. The check is skipped when the user has
/// an LLM subscription or the `bypass-account-age-check` flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4021
/// Issue an LLM service token for the current user.
///
/// The request is rejected unless all of the following hold:
/// - the user is staff or has the `language-models` feature flag;
/// - the user has accepted the terms of service;
/// - the account is at least `MIN_ACCOUNT_AGE_FOR_LLM_USE` old, measured from
///   the earlier of the Zed and GitHub creation dates. This check is skipped
///   when the user has an LLM subscription or the `bypass-account-age-check`
///   flag.
///
/// The token's claims are built by `LlmTokenClaims::create` from these inputs:
/// - the user and staff status;
/// - billing preferences;
/// - the relevant feature flags;
/// - subscription status and current plan;
/// - the connection's system id.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");

    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

    // Subscribers and explicitly flagged users skip the account-age check.
    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        // Use the older of the Zed account and the GitHub account.
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_predict_edits_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4078
4079fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4080    let message = match message {
4081        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4082        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4083        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4084        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4085        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4086            code: frame.code.into(),
4087            reason: frame.reason,
4088        })),
4089        // We should never receive a frame while reading the message, according
4090        // to the `tungstenite` maintainers:
4091        //
4092        // > It cannot occur when you read messages from the WebSocket, but it
4093        // > can be used when you want to send the raw frames (e.g. you want to
4094        // > send the frames to the WebSocket without composing the full message first).
4095        // >
4096        // > — https://github.com/snapview/tungstenite-rs/issues/268
4097        TungsteniteMessage::Frame(_) => {
4098            bail!("received an unexpected frame while reading the message")
4099        }
4100    };
4101
4102    Ok(message)
4103}
4104
4105fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4106    match message {
4107        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4108        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4109        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4110        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4111        AxumMessage::Close(frame) => {
4112            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4113                code: frame.code.into(),
4114                reason: frame.reason,
4115            }))
4116        }
4117    }
4118}
4119
/// Apply a channel-membership change for `user_id` to the connection pool and
/// push the result to all of that user's connections.
///
/// First it subscribes the user to each channel in `result.new_channels` and
/// unsubscribes them from `result.removed_channels`. Each of the user's
/// connections is then sent two messages, in order:
/// 1. an `UpdateUserChannels` with the new memberships and roles;
/// 2. an `UpdateChannels` carrying the new channel data and the removed
///    channel ids, which also clears the invitation for `result.channel_id`.
///
/// Send errors are passed to `trace_err` and otherwise ignored (best-effort).
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the pool's channel subscriptions in sync before notifying clients.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // The membership change resolves the pending invitation for this channel.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
4160
4161fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4162    proto::UpdateUserChannels {
4163        channel_memberships: channels
4164            .channel_memberships
4165            .iter()
4166            .map(|m| proto::ChannelMembership {
4167                channel_id: m.channel_id.to_proto(),
4168                role: m.role.into(),
4169            })
4170            .collect(),
4171        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4172        observed_channel_message_id: channels.observed_channel_messages.clone(),
4173    }
4174}
4175
4176fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4177    let mut update = proto::UpdateChannels::default();
4178
4179    for channel in channels.channels {
4180        update.channels.push(channel.to_proto());
4181    }
4182
4183    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4184    update.latest_channel_message_ids = channels.latest_channel_messages;
4185
4186    for (channel_id, participants) in channels.channel_participants {
4187        update
4188            .channel_participants
4189            .push(proto::ChannelParticipants {
4190                channel_id: channel_id.to_proto(),
4191                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4192            });
4193    }
4194
4195    for channel in channels.invited_channels {
4196        update.channel_invitations.push(channel.to_proto());
4197    }
4198
4199    update
4200}
4201
4202fn build_initial_contacts_update(
4203    contacts: Vec<db::Contact>,
4204    pool: &ConnectionPool,
4205) -> proto::UpdateContacts {
4206    let mut update = proto::UpdateContacts::default();
4207
4208    for contact in contacts {
4209        match contact {
4210            db::Contact::Accepted { user_id, busy } => {
4211                update.contacts.push(contact_for_user(user_id, busy, pool));
4212            }
4213            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4214            db::Contact::Incoming { user_id } => {
4215                update
4216                    .incoming_requests
4217                    .push(proto::IncomingContactRequest {
4218                        requester_id: user_id.to_proto(),
4219                    })
4220            }
4221        }
4222    }
4223
4224    update
4225}
4226
4227fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4228    proto::Contact {
4229        user_id: user_id.to_proto(),
4230        online: pool.is_user_online(user_id),
4231        busy,
4232    }
4233}
4234
4235fn room_updated(room: &proto::Room, peer: &Peer) {
4236    broadcast(
4237        None,
4238        room.participants
4239            .iter()
4240            .filter_map(|participant| Some(participant.peer_id?.into())),
4241        |peer_id| {
4242            peer.send(
4243                peer_id,
4244                proto::RoomUpdated {
4245                    room: Some(room.clone()),
4246                },
4247            )
4248        },
4249    );
4250}
4251
4252fn channel_updated(
4253    channel: &db::channel::Model,
4254    room: &proto::Room,
4255    peer: &Peer,
4256    pool: &ConnectionPool,
4257) {
4258    let participants = room
4259        .participants
4260        .iter()
4261        .map(|p| p.user_id)
4262        .collect::<Vec<_>>();
4263
4264    broadcast(
4265        None,
4266        pool.channel_connection_ids(channel.root_id())
4267            .filter_map(|(channel_id, role)| {
4268                role.can_see_channel(channel.visibility)
4269                    .then_some(channel_id)
4270            }),
4271        |peer_id| {
4272            peer.send(
4273                peer_id,
4274                proto::UpdateChannels {
4275                    channel_participants: vec![proto::ChannelParticipants {
4276                        channel_id: channel.id.to_proto(),
4277                        participant_user_ids: participants.clone(),
4278                    }],
4279                    ..Default::default()
4280                },
4281            )
4282        },
4283    );
4284}
4285
4286async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4287    let db = session.db().await;
4288
4289    let contacts = db.get_contacts(user_id).await?;
4290    let busy = db.is_user_busy(user_id).await?;
4291
4292    let pool = session.connection_pool().await;
4293    let updated_contact = contact_for_user(user_id, busy, &pool);
4294    for contact in contacts {
4295        if let db::Contact::Accepted {
4296            user_id: contact_user_id,
4297            ..
4298        } = contact
4299        {
4300            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4301                session
4302                    .peer
4303                    .send(
4304                        contact_conn_id,
4305                        proto::UpdateContacts {
4306                            contacts: vec![updated_contact.clone()],
4307                            remove_contacts: Default::default(),
4308                            incoming_requests: Default::default(),
4309                            remove_incoming_requests: Default::default(),
4310                            outgoing_requests: Default::default(),
4311                            remove_outgoing_requests: Default::default(),
4312                        },
4313                    )
4314                    .trace_err();
4315            }
4316        }
4317    }
4318    Ok(())
4319}
4320
4321async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4322    let mut contacts_to_update = HashSet::default();
4323
4324    let room_id;
4325    let canceled_calls_to_user_ids;
4326    let livekit_room;
4327    let delete_livekit_room;
4328    let room;
4329    let channel;
4330
4331    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4332        contacts_to_update.insert(session.user_id());
4333
4334        for project in left_room.left_projects.values() {
4335            project_left(project, session);
4336        }
4337
4338        room_id = RoomId::from_proto(left_room.room.id);
4339        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4340        livekit_room = mem::take(&mut left_room.room.livekit_room);
4341        delete_livekit_room = left_room.deleted;
4342        room = mem::take(&mut left_room.room);
4343        channel = mem::take(&mut left_room.channel);
4344
4345        room_updated(&room, &session.peer);
4346    } else {
4347        return Ok(());
4348    }
4349
4350    if let Some(channel) = channel {
4351        channel_updated(
4352            &channel,
4353            &room,
4354            &session.peer,
4355            &*session.connection_pool().await,
4356        );
4357    }
4358
4359    {
4360        let pool = session.connection_pool().await;
4361        for canceled_user_id in canceled_calls_to_user_ids {
4362            for connection_id in pool.user_connection_ids(canceled_user_id) {
4363                session
4364                    .peer
4365                    .send(
4366                        connection_id,
4367                        proto::CallCanceled {
4368                            room_id: room_id.to_proto(),
4369                        },
4370                    )
4371                    .trace_err();
4372            }
4373            contacts_to_update.insert(canceled_user_id);
4374        }
4375    }
4376
4377    for contact_user_id in contacts_to_update {
4378        update_user_contacts(contact_user_id, session).await?;
4379    }
4380
4381    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4382        live_kit
4383            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4384            .await
4385            .trace_err();
4386
4387        if delete_livekit_room {
4388            live_kit.delete_room(livekit_room).await.trace_err();
4389        }
4390    }
4391
4392    Ok(())
4393}
4394
4395async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4396    let left_channel_buffers = session
4397        .db()
4398        .await
4399        .leave_channel_buffers(session.connection_id)
4400        .await?;
4401
4402    for left_buffer in left_channel_buffers {
4403        channel_buffer_updated(
4404            session.connection_id,
4405            left_buffer.connections,
4406            &proto::UpdateChannelBufferCollaborators {
4407                channel_id: left_buffer.channel_id.to_proto(),
4408                collaborators: left_buffer.collaborators,
4409            },
4410            &session.peer,
4411        );
4412    }
4413
4414    Ok(())
4415}
4416
4417fn project_left(project: &db::LeftProject, session: &Session) {
4418    for connection_id in &project.connection_ids {
4419        if project.should_unshare {
4420            session
4421                .peer
4422                .send(
4423                    *connection_id,
4424                    proto::UnshareProject {
4425                        project_id: project.id.to_proto(),
4426                    },
4427                )
4428                .trace_err();
4429        } else {
4430            session
4431                .peer
4432                .send(
4433                    *connection_id,
4434                    proto::RemoveProjectCollaborator {
4435                        project_id: project.id.to_proto(),
4436                        peer_id: Some(session.connection_id.into()),
4437                    },
4438                )
4439                .trace_err();
4440        }
4441    }
4442}
4443
4444pub trait ResultExt {
4445    type Ok;
4446
4447    fn trace_err(self) -> Option<Self::Ok>;
4448}
4449
4450impl<T, E> ResultExt for Result<T, E>
4451where
4452    E: std::fmt::Debug,
4453{
4454    type Ok = T;
4455
4456    #[track_caller]
4457    fn trace_err(self) -> Option<T> {
4458        match self {
4459            Ok(value) => Some(value),
4460            Err(error) => {
4461                tracing::error!("{:?}", error);
4462                None
4463            }
4464        }
4465    }
4466}