rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// NOTE(review): presumably how long a disconnected client may take to reconnect
/// before its state is released; the usage is outside this view.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources left by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting 15s gives
/// those pods time to be gone before we clean up their old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Presumably the page size for channel-message history queries (used outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Presumably the maximum accepted channel-message length (used outside this view).
const MAX_MESSAGE_LEN: usize = 1024;
// Presumably the page size for notification queries (used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
/// Type-erased handler stored in `Server::handlers`, one per message `TypeId`.
///
/// It receives the boxed envelope and the connection's `Session`, and returns a
/// boxed future that performs the handling (see `Server::add_handler`).
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
/// Handle a request handler uses to reply to one incoming request.
///
/// `responded` is shared with the dispatcher built in `Server::add_request_handler`.
/// After the handler returns, the dispatcher checks this flag to detect handlers
/// that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Set to `true` by `send`; read by `add_request_handler` once the handler finishes.
    responded: Arc<AtomicBool>,
}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`. Such sessions are treated as staff
    /// (`Session::is_staff`), and the admin's login is recorded as `impersonator`.
    Impersonated { user: User, admin: User },
}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// State for one client connection (it carries that connection's `connection_id`).
/// It is passed by value to every `MessageHandler`.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Accessed through `Session::db()`, which locks this async mutex.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Accessed through `Session::connection_pool()`. NOTE(review): presumably the
    // same pool as `Server::connection_pool` — confirm in `handle_connection`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` because nothing in this view reads the field;
    // confirm whether it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
/// Newtype around the shared `Database`; it dereferences to `Database`.
struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The collaboration RPC server. It owns the peer, the connection pool, and the
/// table of message handlers keyed by message type.
pub struct Server {
    // Kept behind a lock; the test-only `reset` overwrites it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message `TypeId`, populated by `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Teardown flag. `handle_connection` subscribes to it and refuses to serve a
    // new connection while it reads `true`.
    teardown: watch::Sender<bool>,
}
 244
/// Guard over the locked connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` and `!Sync`. A future
/// that holds it across an `.await` is therefore not `Send`, which keeps the
/// synchronous lock from being carried across suspension points in spawned tasks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server's peer and connection pool.
///
/// `connection_pool` is a lock guard, so the pool stays locked for as long as the
/// snapshot lives. It is serialized through `serialize_deref`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
    /// Builds a server with the given `id` and registers every RPC handler it serves.
    ///
    /// Request handlers (`add_request_handler`) receive a `Response` and must reply
    /// through it. Message handlers (`add_message_handler`) receive only the payload.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates or reinterprets if ServerId's inner
            // type is wider than or signed differently from u32 — confirm the range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped right away, and
            // connections obtain receivers later via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        // Each call registers under the message type's `TypeId`. Generic forwarders
        // such as `forward_read_only_project_request::<T>` therefore need one
        // registration per concrete message type.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The two closure handlers below capture a clone of `app_state` so they can
            // pass configuration values (`app_state.config`) to the underlying function.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 424
    /// Schedules a one-off background cleanup and returns immediately with `Ok(())`.
    ///
    /// The cleanup targets resources that the database reports as stale for this
    /// environment and server id. After `CLEANUP_TIMEOUT` elapses, the detached task:
    /// 1. fetches stale room and channel-buffer ids (`stale_server_resource_ids`);
    /// 2. clears stale collaborators from each channel buffer and sends the
    ///    refreshed collaborator list to that buffer's remaining connections;
    /// 3. for each room, clears stale participants, broadcasts room/channel updates,
    ///    sends `CallCanceled` for canceled calls, pushes refreshed contact entries
    ///    to affected users' accepted contacts, and deletes the LiveKit room if no
    ///    participants remain;
    /// 4. calls `delete_stale_servers`.
    ///
    /// Each fallible step goes through `trace_err()` and is skipped on failure
    /// instead of aborting the whole cleanup.
    pub async fn start(&self) -> Result<()> {
        // Copy what the detached task needs so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here but only awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Refresh each channel buffer's collaborators and send the new
                    // list to every connection the database still reports for it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Per-room results; they keep these defaults if clearing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            // Moved out (after `room_updated` used the room) so it can be
                            // passed by value to `delete_room` below.
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the
                        // `.await`s in the contacts loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query the database before locking the pool, so the lock is
                            // never held across an `.await`.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Only accepted contacts receive the updated entry.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when a LiveKit client is configured
                        // and the refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 570
 571    pub fn teardown(&self) {
 572        self.peer.teardown();
 573        self.connection_pool.lock().reset();
 574        let _ = self.teardown.send(true);
 575    }
 576
 577    #[cfg(test)]
 578    pub fn reset(&self, id: ServerId) {
 579        self.teardown();
 580        *self.id.lock() = id;
 581        self.peer.reset(id.0 as u32);
 582        let _ = self.teardown.send(false);
 583    }
 584
 585    #[cfg(test)]
 586    pub fn id(&self) -> ServerId {
 587        *self.id.lock()
 588    }
 589
 590    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 591    where
 592        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 593        Fut: 'static + Send + Future<Output = Result<()>>,
 594        M: EnvelopedMessage,
 595    {
 596        let prev_handler = self.handlers.insert(
 597            TypeId::of::<M>(),
 598            Box::new(move |envelope, session| {
 599                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 600                let received_at = envelope.received_at;
 601                tracing::info!("message received");
 602                let start_time = Instant::now();
 603                let future = (handler)(*envelope, session);
 604                async move {
 605                    let result = future.await;
 606                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 607                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 608                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 609                    let payload_type = M::NAME;
 610
 611                    match result {
 612                        Err(error) => {
 613                            tracing::error!(
 614                                ?error,
 615                                total_duration_ms,
 616                                processing_duration_ms,
 617                                queue_duration_ms,
 618                                payload_type,
 619                                "error handling message"
 620                            )
 621                        }
 622                        Ok(()) => tracing::info!(
 623                            total_duration_ms,
 624                            processing_duration_ms,
 625                            queue_duration_ms,
 626                            "finished handling message"
 627                        ),
 628                    }
 629                }
 630                .boxed()
 631            }),
 632        );
 633        if prev_handler.is_some() {
 634            panic!("registered a handler for the same message twice");
 635        }
 636        self
 637    }
 638
 639    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 640    where
 641        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 642        Fut: 'static + Send + Future<Output = Result<()>>,
 643        M: EnvelopedMessage,
 644    {
 645        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 646        self
 647    }
 648
 649    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 650    where
 651        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 652        Fut: Send + Future<Output = Result<()>>,
 653        M: RequestMessage,
 654    {
 655        let handler = Arc::new(handler);
 656        self.add_handler(move |envelope, session| {
 657            let receipt = envelope.receipt();
 658            let handler = handler.clone();
 659            async move {
 660                let peer = session.peer.clone();
 661                let responded = Arc::new(AtomicBool::default());
 662                let response = Response {
 663                    peer: peer.clone(),
 664                    responded: responded.clone(),
 665                    receipt,
 666                };
 667                match (handler)(envelope.payload, response, session).await {
 668                    Ok(()) => {
 669                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 670                            Ok(())
 671                        } else {
 672                            Err(anyhow!("handler did not send a response"))?
 673                        }
 674                    }
 675                    Err(error) => {
 676                        let proto_err = match &error {
 677                            Error::Internal(err) => err.to_proto(),
 678                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 679                        };
 680                        peer.respond_with_error(receipt, proto_err)?;
 681                        Err(error)
 682                    }
 683                }
 684            }
 685        })
 686    }
 687
 688    #[allow(clippy::too_many_arguments)]
 689    pub fn handle_connection(
 690        self: &Arc<Self>,
 691        connection: Connection,
 692        address: String,
 693        principal: Principal,
 694        zed_version: ZedVersion,
 695        geoip_country_code: Option<String>,
 696        system_id: Option<String>,
 697        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 698        executor: Executor,
 699    ) -> impl Future<Output = ()> {
 700        let this = self.clone();
 701        let span = info_span!("handle connection", %address,
 702            connection_id=field::Empty,
 703            user_id=field::Empty,
 704            login=field::Empty,
 705            impersonator=field::Empty,
 706            geoip_country_code=field::Empty
 707        );
 708        principal.update_span(&span);
 709        if let Some(country_code) = geoip_country_code.as_ref() {
 710            span.record("geoip_country_code", country_code);
 711        }
 712
 713        let mut teardown = self.teardown.subscribe();
 714        async move {
 715            if *teardown.borrow() {
 716                tracing::error!("server is tearing down");
 717                return
 718            }
 719            let (connection_id, handle_io, mut incoming_rx) = this
 720                .peer
 721                .add_connection(connection, {
 722                    let executor = executor.clone();
 723                    move |duration| executor.sleep(duration)
 724                });
 725            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 726
 727            tracing::info!("connection opened");
 728
 729            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 730            let http_client = match ReqwestClient::user_agent(&user_agent) {
 731                Ok(http_client) => Arc::new(http_client),
 732                Err(error) => {
 733                    tracing::error!(?error, "failed to create HTTP client");
 734                    return;
 735                }
 736            };
 737
 738            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 739                    supermaven_admin_api_key.to_string(),
 740                    http_client.clone(),
 741                )));
 742
 743            let session = Session {
 744                principal: principal.clone(),
 745                connection_id,
 746                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 747                peer: this.peer.clone(),
 748                connection_pool: this.connection_pool.clone(),
 749                app_state: this.app_state.clone(),
 750                http_client,
 751                geoip_country_code,
 752                system_id,
 753                _executor: executor.clone(),
 754                supermaven_client,
 755            };
 756
 757            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 758                tracing::error!(?error, "failed to send initial client update");
 759                return;
 760            }
 761
 762            let handle_io = handle_io.fuse();
 763            futures::pin_mut!(handle_io);
 764
 765            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 766            // This prevents deadlocks when e.g., client A performs a request to client B and
 767            // client B performs a request to client A. If both clients stop processing further
 768            // messages until their respective request completes, they won't have a chance to
 769            // respond to the other client's request and cause a deadlock.
 770            //
 771            // This arrangement ensures we will attempt to process earlier messages first, but fall
 772            // back to processing messages arrived later in the spirit of making progress.
 773            let mut foreground_message_handlers = FuturesUnordered::new();
 774            let concurrent_handlers = Arc::new(Semaphore::new(256));
 775            loop {
 776                let next_message = async {
 777                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 778                    let message = incoming_rx.next().await;
 779                    (permit, message)
 780                }.fuse();
 781                futures::pin_mut!(next_message);
 782                futures::select_biased! {
 783                    _ = teardown.changed().fuse() => return,
 784                    result = handle_io => {
 785                        if let Err(error) = result {
 786                            tracing::error!(?error, "error handling I/O");
 787                        }
 788                        break;
 789                    }
 790                    _ = foreground_message_handlers.next() => {}
 791                    next_message = next_message => {
 792                        let (permit, message) = next_message;
 793                        if let Some(message) = message {
 794                            let type_name = message.payload_type_name();
 795                            // note: we copy all the fields from the parent span so we can query them in the logs.
 796                            // (https://github.com/tokio-rs/tracing/issues/2670).
 797                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 798                                user_id=field::Empty,
 799                                login=field::Empty,
 800                                impersonator=field::Empty,
 801                            );
 802                            principal.update_span(&span);
 803                            let span_enter = span.enter();
 804                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 805                                let is_background = message.is_background();
 806                                let handle_message = (handler)(message, session.clone());
 807                                drop(span_enter);
 808
 809                                let handle_message = async move {
 810                                    handle_message.await;
 811                                    drop(permit);
 812                                }.instrument(span);
 813                                if is_background {
 814                                    executor.spawn_detached(handle_message);
 815                                } else {
 816                                    foreground_message_handlers.push(handle_message);
 817                                }
 818                            } else {
 819                                tracing::error!("no message handler");
 820                            }
 821                        } else {
 822                            tracing::info!("connection closed");
 823                            break;
 824                        }
 825                    }
 826                }
 827            }
 828
 829            drop(foreground_message_handlers);
 830            tracing::info!("signing out");
 831            if let Err(error) = connection_lost(session, teardown, executor).await {
 832                tracing::error!(?error, "error signing out");
 833            }
 834
 835        }.instrument(span)
 836    }
 837
 838    async fn send_initial_client_update(
 839        &self,
 840        connection_id: ConnectionId,
 841        principal: &Principal,
 842        zed_version: ZedVersion,
 843        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 844        session: &Session,
 845    ) -> Result<()> {
 846        self.peer.send(
 847            connection_id,
 848            proto::Hello {
 849                peer_id: Some(connection_id.into()),
 850            },
 851        )?;
 852        tracing::info!("sent hello message");
 853        if let Some(send_connection_id) = send_connection_id.take() {
 854            let _ = send_connection_id.send(connection_id);
 855        }
 856
 857        match principal {
 858            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 859                if !user.connected_once {
 860                    self.peer.send(connection_id, proto::ShowContacts {})?;
 861                    self.app_state
 862                        .db
 863                        .set_user_connected_once(user.id, true)
 864                        .await?;
 865                }
 866
 867                update_user_plan(user.id, session).await?;
 868
 869                let contacts = self.app_state.db.get_contacts(user.id).await?;
 870
 871                {
 872                    let mut pool = self.connection_pool.lock();
 873                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 874                    self.peer.send(
 875                        connection_id,
 876                        build_initial_contacts_update(contacts, &pool),
 877                    )?;
 878                }
 879
 880                if should_auto_subscribe_to_channels(zed_version) {
 881                    subscribe_user_to_channels(user.id, session).await?;
 882                }
 883
 884                if let Some(incoming_call) =
 885                    self.app_state.db.incoming_call_for_user(user.id).await?
 886                {
 887                    self.peer.send(connection_id, incoming_call)?;
 888                }
 889
 890                update_user_contacts(user.id, session).await?;
 891            }
 892        }
 893
 894        Ok(())
 895    }
 896
 897    pub async fn invite_code_redeemed(
 898        self: &Arc<Self>,
 899        inviter_id: UserId,
 900        invitee_id: UserId,
 901    ) -> Result<()> {
 902        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 903            if let Some(code) = &user.invite_code {
 904                let pool = self.connection_pool.lock();
 905                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 906                for connection_id in pool.user_connection_ids(inviter_id) {
 907                    self.peer.send(
 908                        connection_id,
 909                        proto::UpdateContacts {
 910                            contacts: vec![invitee_contact.clone()],
 911                            ..Default::default()
 912                        },
 913                    )?;
 914                    self.peer.send(
 915                        connection_id,
 916                        proto::UpdateInviteInfo {
 917                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 918                            count: user.invite_count as u32,
 919                        },
 920                    )?;
 921                }
 922            }
 923        }
 924        Ok(())
 925    }
 926
 927    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 928        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 929            if let Some(invite_code) = &user.invite_code {
 930                let pool = self.connection_pool.lock();
 931                for connection_id in pool.user_connection_ids(user_id) {
 932                    self.peer.send(
 933                        connection_id,
 934                        proto::UpdateInviteInfo {
 935                            url: format!(
 936                                "{}{}",
 937                                self.app_state.config.invite_link_prefix, invite_code
 938                            ),
 939                            count: user.invite_count as u32,
 940                        },
 941                    )?;
 942                }
 943            }
 944        }
 945        Ok(())
 946    }
 947
 948    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 949        let pool = self.connection_pool.lock();
 950        for connection_id in pool.user_connection_ids(user_id) {
 951            self.peer
 952                .send(connection_id, proto::RefreshLlmToken {})
 953                .trace_err();
 954        }
 955    }
 956
 957    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 958        ServerSnapshot {
 959            connection_pool: ConnectionPoolGuard {
 960                guard: self.connection_pool.lock(),
 961                _not_send: PhantomData,
 962            },
 963            peer: &self.peer,
 964        }
 965    }
 966}
 967
 968impl<'a> Deref for ConnectionPoolGuard<'a> {
 969    type Target = ConnectionPool;
 970
 971    fn deref(&self) -> &Self::Target {
 972        &self.guard
 973    }
 974}
 975
 976impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 977    fn deref_mut(&mut self) -> &mut Self::Target {
 978        &mut self.guard
 979    }
 980}
 981
 982impl<'a> Drop for ConnectionPoolGuard<'a> {
 983    fn drop(&mut self) {
 984        #[cfg(test)]
 985        self.check_invariants();
 986    }
 987}
 988
 989fn broadcast<F>(
 990    sender_id: Option<ConnectionId>,
 991    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 992    mut f: F,
 993) where
 994    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 995{
 996    for receiver_id in receiver_ids {
 997        if Some(receiver_id) != sender_id {
 998            if let Err(error) = f(receiver_id) {
 999                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1000            }
1001        }
1002    }
1003}
1004
1005pub struct ProtocolVersion(u32);
1006
1007impl Header for ProtocolVersion {
1008    fn name() -> &'static HeaderName {
1009        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1010        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1011    }
1012
1013    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1014    where
1015        Self: Sized,
1016        I: Iterator<Item = &'i axum::http::HeaderValue>,
1017    {
1018        let version = values
1019            .next()
1020            .ok_or_else(axum::headers::Error::invalid)?
1021            .to_str()
1022            .map_err(|_| axum::headers::Error::invalid())?
1023            .parse()
1024            .map_err(|_| axum::headers::Error::invalid())?;
1025        Ok(Self(version))
1026    }
1027
1028    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1029        values.extend([self.0.to_string().parse().unwrap()]);
1030    }
1031}
1032
1033pub struct AppVersionHeader(SemanticVersion);
1034impl Header for AppVersionHeader {
1035    fn name() -> &'static HeaderName {
1036        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1037        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1038    }
1039
1040    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1041    where
1042        Self: Sized,
1043        I: Iterator<Item = &'i axum::http::HeaderValue>,
1044    {
1045        let version = values
1046            .next()
1047            .ok_or_else(axum::headers::Error::invalid)?
1048            .to_str()
1049            .map_err(|_| axum::headers::Error::invalid())?
1050            .parse()
1051            .map_err(|_| axum::headers::Error::invalid())?;
1052        Ok(Self(version))
1053    }
1054
1055    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1056        values.extend([self.0.to_string().parse().unwrap()]);
1057    }
1058}
1059
1060pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1061    Router::new()
1062        .route("/rpc", get(handle_websocket_request))
1063        .layer(
1064            ServiceBuilder::new()
1065                .layer(Extension(server.app_state.clone()))
1066                .layer(middleware::from_fn(auth::validate_header)),
1067        )
1068        .route("/metrics", get(handle_metrics))
1069        .layer(Extension(server))
1070}
1071
1072#[allow(clippy::too_many_arguments)]
1073pub async fn handle_websocket_request(
1074    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1075    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1076    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1077    Extension(server): Extension<Arc<Server>>,
1078    Extension(principal): Extension<Principal>,
1079    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1080    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1081    ws: WebSocketUpgrade,
1082) -> axum::response::Response {
1083    if protocol_version != rpc::PROTOCOL_VERSION {
1084        return (
1085            StatusCode::UPGRADE_REQUIRED,
1086            "client must be upgraded".to_string(),
1087        )
1088            .into_response();
1089    }
1090
1091    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1092        return (
1093            StatusCode::UPGRADE_REQUIRED,
1094            "no version header found".to_string(),
1095        )
1096            .into_response();
1097    };
1098
1099    if !version.can_collaborate() {
1100        return (
1101            StatusCode::UPGRADE_REQUIRED,
1102            "client must be upgraded".to_string(),
1103        )
1104            .into_response();
1105    }
1106
1107    let socket_address = socket_address.to_string();
1108    ws.on_upgrade(move |socket| {
1109        let socket = socket
1110            .map_ok(to_tungstenite_message)
1111            .err_into()
1112            .with(|message| async move { to_axum_message(message) });
1113        let connection = Connection::new(Box::pin(socket));
1114        async move {
1115            server
1116                .handle_connection(
1117                    connection,
1118                    socket_address,
1119                    principal,
1120                    version,
1121                    country_code_header.map(|header| header.to_string()),
1122                    system_id_header.map(|header| header.to_string()),
1123                    None,
1124                    Executor::Production,
1125                )
1126                .await;
1127        }
1128    })
1129}
1130
1131pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1132    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1133    let connections_metric = CONNECTIONS_METRIC
1134        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1135
1136    let connections = server
1137        .connection_pool
1138        .lock()
1139        .connections()
1140        .filter(|connection| !connection.admin)
1141        .count();
1142    connections_metric.set(connections as _);
1143
1144    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1146        register_int_gauge!(
1147            "shared_projects",
1148            "number of open projects with one or more guests"
1149        )
1150        .unwrap()
1151    });
1152
1153    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1154    shared_projects_metric.set(shared_projects as _);
1155
1156    let encoder = prometheus::TextEncoder::new();
1157    let metric_families = prometheus::gather();
1158    let encoded_metrics = encoder
1159        .encode_to_string(&metric_families)
1160        .map_err(|err| anyhow!("{}", err))?;
1161    Ok(encoded_metrics)
1162}
1163
1164#[instrument(err, skip(executor))]
1165async fn connection_lost(
1166    session: Session,
1167    mut teardown: watch::Receiver<bool>,
1168    executor: Executor,
1169) -> Result<()> {
1170    session.peer.disconnect(session.connection_id);
1171    session
1172        .connection_pool()
1173        .await
1174        .remove_connection(session.connection_id)?;
1175
1176    session
1177        .db()
1178        .await
1179        .connection_lost(session.connection_id)
1180        .await
1181        .trace_err();
1182
1183    futures::select_biased! {
1184        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1185
1186            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1187            leave_room_for_session(&session, session.connection_id).await.trace_err();
1188            leave_channel_buffers_for_session(&session)
1189                .await
1190                .trace_err();
1191
1192            if !session
1193                .connection_pool()
1194                .await
1195                .is_user_online(session.user_id())
1196            {
1197                let db = session.db().await;
1198                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1199                    room_updated(&room, &session.peer);
1200                }
1201            }
1202
1203            update_user_contacts(session.user_id(), &session).await?;
1204        },
1205        _ = teardown.changed().fuse() => {}
1206    }
1207
1208    Ok(())
1209}
1210
1211/// Acknowledges a ping from a client, used to keep the connection alive.
1212async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1213    response.send(proto::Ack {})?;
1214    Ok(())
1215}
1216
1217/// Creates a new room for calling (outside of channels)
1218async fn create_room(
1219    _request: proto::CreateRoom,
1220    response: Response<proto::CreateRoom>,
1221    session: Session,
1222) -> Result<()> {
1223    let livekit_room = nanoid::nanoid!(30);
1224
1225    let live_kit_connection_info = util::maybe!(async {
1226        let live_kit = session.app_state.livekit_client.as_ref();
1227        let live_kit = live_kit?;
1228        let user_id = session.user_id().to_string();
1229
1230        let token = live_kit
1231            .room_token(&livekit_room, &user_id.to_string())
1232            .trace_err()?;
1233
1234        Some(proto::LiveKitConnectionInfo {
1235            server_url: live_kit.url().into(),
1236            token,
1237            can_publish: true,
1238        })
1239    })
1240    .await;
1241
1242    let room = session
1243        .db()
1244        .await
1245        .create_room(session.user_id(), session.connection_id, &livekit_room)
1246        .await?;
1247
1248    response.send(proto::CreateRoomResponse {
1249        room: Some(room.clone()),
1250        live_kit_connection_info,
1251    })?;
1252
1253    update_user_contacts(session.user_id(), &session).await?;
1254    Ok(())
1255}
1256
1257/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1258async fn join_room(
1259    request: proto::JoinRoom,
1260    response: Response<proto::JoinRoom>,
1261    session: Session,
1262) -> Result<()> {
1263    let room_id = RoomId::from_proto(request.id);
1264
1265    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1266
1267    if let Some(channel_id) = channel_id {
1268        return join_channel_internal(channel_id, Box::new(response), session).await;
1269    }
1270
1271    let joined_room = {
1272        let room = session
1273            .db()
1274            .await
1275            .join_room(room_id, session.user_id(), session.connection_id)
1276            .await?;
1277        room_updated(&room.room, &session.peer);
1278        room.into_inner()
1279    };
1280
1281    for connection_id in session
1282        .connection_pool()
1283        .await
1284        .user_connection_ids(session.user_id())
1285    {
1286        session
1287            .peer
1288            .send(
1289                connection_id,
1290                proto::CallCanceled {
1291                    room_id: room_id.to_proto(),
1292                },
1293            )
1294            .trace_err();
1295    }
1296
1297    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1298    {
1299        live_kit
1300            .room_token(
1301                &joined_room.room.livekit_room,
1302                &session.user_id().to_string(),
1303            )
1304            .trace_err()
1305            .map(|token| proto::LiveKitConnectionInfo {
1306                server_url: live_kit.url().into(),
1307                token,
1308                can_publish: true,
1309            })
1310    } else {
1311        None
1312    };
1313
1314    response.send(proto::JoinRoomResponse {
1315        room: Some(joined_room.room),
1316        channel_id: None,
1317        live_kit_connection_info,
1318    })?;
1319
1320    update_user_contacts(session.user_id(), &session).await?;
1321    Ok(())
1322}
1323
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoins this connection to the room and reports two sets of projects:
/// `reshared_projects` (presumably projects this user was hosting) and `rejoined_projects`.
/// The steps that follow:
/// 1. The client receives a `RejoinRoomResponse` and the updated room is broadcast.
/// 2. Collaborators on each reshared project are told the user's new peer id and are forwarded
///    the project's worktrees.
/// 3. Rejoined projects are re-synchronized via `notify_rejoined_projects`.
/// 4. The room's channel (if any) and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Only `room` and `channel` outlive the block below; the database result is consumed
    // (via `into_inner`) before the channel and contacts updates run.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: tell each collaborator that the old connection id was replaced.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list, as if sent by this connection, to every
            // collaborator except this connection itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1415
1416fn notify_rejoined_projects(
1417    rejoined_projects: &mut Vec<RejoinedProject>,
1418    session: &Session,
1419) -> Result<()> {
1420    for project in rejoined_projects.iter() {
1421        for collaborator in &project.collaborators {
1422            session
1423                .peer
1424                .send(
1425                    collaborator.connection_id,
1426                    proto::UpdateProjectCollaborator {
1427                        project_id: project.id.to_proto(),
1428                        old_peer_id: Some(project.old_connection_id.into()),
1429                        new_peer_id: Some(session.connection_id.into()),
1430                    },
1431                )
1432                .trace_err();
1433        }
1434    }
1435
1436    for project in rejoined_projects {
1437        for worktree in mem::take(&mut project.worktrees) {
1438            // Stream this worktree's entries.
1439            let message = proto::UpdateWorktree {
1440                project_id: project.id.to_proto(),
1441                worktree_id: worktree.id,
1442                abs_path: worktree.abs_path.clone(),
1443                root_name: worktree.root_name,
1444                updated_entries: worktree.updated_entries,
1445                removed_entries: worktree.removed_entries,
1446                scan_id: worktree.scan_id,
1447                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1448                updated_repositories: worktree.updated_repositories,
1449                removed_repositories: worktree.removed_repositories,
1450            };
1451            for update in proto::split_worktree_update(message) {
1452                session.peer.send(session.connection_id, update.clone())?;
1453            }
1454
1455            // Stream this worktree's diagnostics.
1456            for summary in worktree.diagnostic_summaries {
1457                session.peer.send(
1458                    session.connection_id,
1459                    proto::UpdateDiagnosticSummary {
1460                        project_id: project.id.to_proto(),
1461                        worktree_id: worktree.id,
1462                        summary: Some(summary),
1463                    },
1464                )?;
1465            }
1466
1467            for settings_file in worktree.settings_files {
1468                session.peer.send(
1469                    session.connection_id,
1470                    proto::UpdateWorktreeSettings {
1471                        project_id: project.id.to_proto(),
1472                        worktree_id: worktree.id,
1473                        path: settings_file.path,
1474                        content: Some(settings_file.content),
1475                        kind: Some(settings_file.kind.to_proto().into()),
1476                    },
1477                )?;
1478            }
1479        }
1480
1481        for language_server in &project.language_servers {
1482            session.peer.send(
1483                session.connection_id,
1484                proto::UpdateLanguageServer {
1485                    project_id: project.id.to_proto(),
1486                    language_server_id: language_server.id,
1487                    variant: Some(
1488                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1489                            proto::LspDiskBasedDiagnosticsUpdated {},
1490                        ),
1491                    ),
1492                },
1493            )?;
1494        }
1495    }
1496    Ok(())
1497}
1498
1499/// leave room disconnects from the room.
1500async fn leave_room(
1501    _: proto::LeaveRoom,
1502    response: Response<proto::LeaveRoom>,
1503    session: Session,
1504) -> Result<()> {
1505    leave_room_for_session(&session, session.connection_id).await?;
1506    response.send(proto::Ack {})?;
1507    Ok(())
1508}
1509
1510/// Updates the permissions of someone else in the room.
1511async fn set_room_participant_role(
1512    request: proto::SetRoomParticipantRole,
1513    response: Response<proto::SetRoomParticipantRole>,
1514    session: Session,
1515) -> Result<()> {
1516    let user_id = UserId::from_proto(request.user_id);
1517    let role = ChannelRole::from(request.role());
1518
1519    let (livekit_room, can_publish) = {
1520        let room = session
1521            .db()
1522            .await
1523            .set_room_participant_role(
1524                session.user_id(),
1525                RoomId::from_proto(request.room_id),
1526                user_id,
1527                role,
1528            )
1529            .await?;
1530
1531        let livekit_room = room.livekit_room.clone();
1532        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1533        room_updated(&room, &session.peer);
1534        (livekit_room, can_publish)
1535    };
1536
1537    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1538        live_kit
1539            .update_participant(
1540                livekit_room.clone(),
1541                request.user_id.to_string(),
1542                livekit_server::proto::ParticipantPermission {
1543                    can_subscribe: true,
1544                    can_publish,
1545                    can_publish_data: can_publish,
1546                    hidden: false,
1547                    recorder: false,
1548                },
1549            )
1550            .await
1551            .trace_err();
1552    }
1553
1554    response.send(proto::Ack {})?;
1555    Ok(())
1556}
1557
1558/// Call someone else into the current room
1559async fn call(
1560    request: proto::Call,
1561    response: Response<proto::Call>,
1562    session: Session,
1563) -> Result<()> {
1564    let room_id = RoomId::from_proto(request.room_id);
1565    let calling_user_id = session.user_id();
1566    let calling_connection_id = session.connection_id;
1567    let called_user_id = UserId::from_proto(request.called_user_id);
1568    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1569    if !session
1570        .db()
1571        .await
1572        .has_contact(calling_user_id, called_user_id)
1573        .await?
1574    {
1575        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1576    }
1577
1578    let incoming_call = {
1579        let (room, incoming_call) = &mut *session
1580            .db()
1581            .await
1582            .call(
1583                room_id,
1584                calling_user_id,
1585                calling_connection_id,
1586                called_user_id,
1587                initial_project_id,
1588            )
1589            .await?;
1590        room_updated(room, &session.peer);
1591        mem::take(incoming_call)
1592    };
1593    update_user_contacts(called_user_id, &session).await?;
1594
1595    let mut calls = session
1596        .connection_pool()
1597        .await
1598        .user_connection_ids(called_user_id)
1599        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1600        .collect::<FuturesUnordered<_>>();
1601
1602    while let Some(call_response) = calls.next().await {
1603        match call_response.as_ref() {
1604            Ok(_) => {
1605                response.send(proto::Ack {})?;
1606                return Ok(());
1607            }
1608            Err(_) => {
1609                call_response.trace_err();
1610            }
1611        }
1612    }
1613
1614    {
1615        let room = session
1616            .db()
1617            .await
1618            .call_failed(room_id, called_user_id)
1619            .await?;
1620        room_updated(&room, &session.peer);
1621    }
1622    update_user_contacts(called_user_id, &session).await?;
1623
1624    Err(anyhow!("failed to ring user"))?
1625}
1626
1627/// Cancel an outgoing call.
1628async fn cancel_call(
1629    request: proto::CancelCall,
1630    response: Response<proto::CancelCall>,
1631    session: Session,
1632) -> Result<()> {
1633    let called_user_id = UserId::from_proto(request.called_user_id);
1634    let room_id = RoomId::from_proto(request.room_id);
1635    {
1636        let room = session
1637            .db()
1638            .await
1639            .cancel_call(room_id, session.connection_id, called_user_id)
1640            .await?;
1641        room_updated(&room, &session.peer);
1642    }
1643
1644    for connection_id in session
1645        .connection_pool()
1646        .await
1647        .user_connection_ids(called_user_id)
1648    {
1649        session
1650            .peer
1651            .send(
1652                connection_id,
1653                proto::CallCanceled {
1654                    room_id: room_id.to_proto(),
1655                },
1656            )
1657            .trace_err();
1658    }
1659    response.send(proto::Ack {})?;
1660
1661    update_user_contacts(called_user_id, &session).await?;
1662    Ok(())
1663}
1664
1665/// Decline an incoming call.
1666async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1667    let room_id = RoomId::from_proto(message.room_id);
1668    {
1669        let room = session
1670            .db()
1671            .await
1672            .decline_call(Some(room_id), session.user_id())
1673            .await?
1674            .ok_or_else(|| anyhow!("failed to decline call"))?;
1675        room_updated(&room, &session.peer);
1676    }
1677
1678    for connection_id in session
1679        .connection_pool()
1680        .await
1681        .user_connection_ids(session.user_id())
1682    {
1683        session
1684            .peer
1685            .send(
1686                connection_id,
1687                proto::CallCanceled {
1688                    room_id: room_id.to_proto(),
1689                },
1690            )
1691            .trace_err();
1692    }
1693    update_user_contacts(session.user_id(), &session).await?;
1694    Ok(())
1695}
1696
1697/// Updates other participants in the room with your current location.
1698async fn update_participant_location(
1699    request: proto::UpdateParticipantLocation,
1700    response: Response<proto::UpdateParticipantLocation>,
1701    session: Session,
1702) -> Result<()> {
1703    let room_id = RoomId::from_proto(request.room_id);
1704    let location = request
1705        .location
1706        .ok_or_else(|| anyhow!("invalid location"))?;
1707
1708    let db = session.db().await;
1709    let room = db
1710        .update_room_participant_location(room_id, session.connection_id, location)
1711        .await?;
1712
1713    room_updated(&room, &session.peer);
1714    response.send(proto::Ack {})?;
1715    Ok(())
1716}
1717
1718/// Share a project into the room.
1719async fn share_project(
1720    request: proto::ShareProject,
1721    response: Response<proto::ShareProject>,
1722    session: Session,
1723) -> Result<()> {
1724    let (project_id, room) = &*session
1725        .db()
1726        .await
1727        .share_project(
1728            RoomId::from_proto(request.room_id),
1729            session.connection_id,
1730            &request.worktrees,
1731            request.is_ssh_project,
1732        )
1733        .await?;
1734    response.send(proto::ShareProjectResponse {
1735        project_id: project_id.to_proto(),
1736    })?;
1737    room_updated(room, &session.peer);
1738
1739    Ok(())
1740}
1741
1742/// Unshare a project from the room.
1743async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1744    let project_id = ProjectId::from_proto(message.project_id);
1745    unshare_project_internal(project_id, session.connection_id, &session).await
1746}
1747
1748async fn unshare_project_internal(
1749    project_id: ProjectId,
1750    connection_id: ConnectionId,
1751    session: &Session,
1752) -> Result<()> {
1753    let delete = {
1754        let room_guard = session
1755            .db()
1756            .await
1757            .unshare_project(project_id, connection_id)
1758            .await?;
1759
1760        let (delete, room, guest_connection_ids) = &*room_guard;
1761
1762        let message = proto::UnshareProject {
1763            project_id: project_id.to_proto(),
1764        };
1765
1766        broadcast(
1767            Some(connection_id),
1768            guest_connection_ids.iter().copied(),
1769            |conn_id| session.peer.send(conn_id, message.clone()),
1770        );
1771        if let Some(room) = room {
1772            room_updated(room, &session.peer);
1773        }
1774
1775        *delete
1776    };
1777
1778    if delete {
1779        let db = session.db().await;
1780        db.delete_project(project_id).await?;
1781    }
1782
1783    Ok(())
1784}
1785
1786/// Join someone elses shared project.
1787async fn join_project(
1788    request: proto::JoinProject,
1789    response: Response<proto::JoinProject>,
1790    session: Session,
1791) -> Result<()> {
1792    let project_id = ProjectId::from_proto(request.project_id);
1793
1794    tracing::info!(%project_id, "join project");
1795
1796    let db = session.db().await;
1797    let (project, replica_id) = &mut *db
1798        .join_project(project_id, session.connection_id, session.user_id())
1799        .await?;
1800    drop(db);
1801    tracing::info!(%project_id, "join remote project");
1802    join_project_internal(response, session, project, replica_id)
1803}
1804
/// Response sink used by [`join_project_internal`] to deliver the
/// `JoinProjectResponse`.
///
/// In this file section only `Response<proto::JoinProject>` implements it.
/// NOTE(review): presumably exists so other join paths can reuse
/// `join_project_internal` with a different response type; confirm.
trait JoinProjectInternalResponse {
    /// Sends `result` to the joining client, consuming the sink.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delegates to `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1813
1814fn join_project_internal(
1815    response: impl JoinProjectInternalResponse,
1816    session: Session,
1817    project: &mut Project,
1818    replica_id: &ReplicaId,
1819) -> Result<()> {
1820    let collaborators = project
1821        .collaborators
1822        .iter()
1823        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1824        .map(|collaborator| collaborator.to_proto())
1825        .collect::<Vec<_>>();
1826    let project_id = project.id;
1827    let guest_user_id = session.user_id();
1828
1829    let worktrees = project
1830        .worktrees
1831        .iter()
1832        .map(|(id, worktree)| proto::WorktreeMetadata {
1833            id: *id,
1834            root_name: worktree.root_name.clone(),
1835            visible: worktree.visible,
1836            abs_path: worktree.abs_path.clone(),
1837        })
1838        .collect::<Vec<_>>();
1839
1840    let add_project_collaborator = proto::AddProjectCollaborator {
1841        project_id: project_id.to_proto(),
1842        collaborator: Some(proto::Collaborator {
1843            peer_id: Some(session.connection_id.into()),
1844            replica_id: replica_id.0 as u32,
1845            user_id: guest_user_id.to_proto(),
1846            is_host: false,
1847        }),
1848    };
1849
1850    for collaborator in &collaborators {
1851        session
1852            .peer
1853            .send(
1854                collaborator.peer_id.unwrap().into(),
1855                add_project_collaborator.clone(),
1856            )
1857            .trace_err();
1858    }
1859
1860    // First, we send the metadata associated with each worktree.
1861    response.send(proto::JoinProjectResponse {
1862        project_id: project.id.0 as u64,
1863        worktrees: worktrees.clone(),
1864        replica_id: replica_id.0 as u32,
1865        collaborators: collaborators.clone(),
1866        language_servers: project.language_servers.clone(),
1867        role: project.role.into(),
1868    })?;
1869
1870    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1871        // Stream this worktree's entries.
1872        let message = proto::UpdateWorktree {
1873            project_id: project_id.to_proto(),
1874            worktree_id,
1875            abs_path: worktree.abs_path.clone(),
1876            root_name: worktree.root_name,
1877            updated_entries: worktree.entries,
1878            removed_entries: Default::default(),
1879            scan_id: worktree.scan_id,
1880            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1881            updated_repositories: worktree.repository_entries.into_values().collect(),
1882            removed_repositories: Default::default(),
1883        };
1884        for update in proto::split_worktree_update(message) {
1885            session.peer.send(session.connection_id, update.clone())?;
1886        }
1887
1888        // Stream this worktree's diagnostics.
1889        for summary in worktree.diagnostic_summaries {
1890            session.peer.send(
1891                session.connection_id,
1892                proto::UpdateDiagnosticSummary {
1893                    project_id: project_id.to_proto(),
1894                    worktree_id: worktree.id,
1895                    summary: Some(summary),
1896                },
1897            )?;
1898        }
1899
1900        for settings_file in worktree.settings_files {
1901            session.peer.send(
1902                session.connection_id,
1903                proto::UpdateWorktreeSettings {
1904                    project_id: project_id.to_proto(),
1905                    worktree_id: worktree.id,
1906                    path: settings_file.path,
1907                    content: Some(settings_file.content),
1908                    kind: Some(settings_file.kind.to_proto() as i32),
1909                },
1910            )?;
1911        }
1912    }
1913
1914    for language_server in &project.language_servers {
1915        session.peer.send(
1916            session.connection_id,
1917            proto::UpdateLanguageServer {
1918                project_id: project_id.to_proto(),
1919                language_server_id: language_server.id,
1920                variant: Some(
1921                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1922                        proto::LspDiskBasedDiagnosticsUpdated {},
1923                    ),
1924                ),
1925            },
1926        )?;
1927    }
1928
1929    Ok(())
1930}
1931
1932/// Leave someone elses shared project.
1933async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1934    let sender_id = session.connection_id;
1935    let project_id = ProjectId::from_proto(request.project_id);
1936    let db = session.db().await;
1937
1938    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1939    tracing::info!(
1940        %project_id,
1941        "leave project"
1942    );
1943
1944    project_left(project, &session);
1945    if let Some(room) = room {
1946        room_updated(room, &session.peer);
1947    }
1948
1949    Ok(())
1950}
1951
1952/// Updates other participants with changes to the project
1953async fn update_project(
1954    request: proto::UpdateProject,
1955    response: Response<proto::UpdateProject>,
1956    session: Session,
1957) -> Result<()> {
1958    let project_id = ProjectId::from_proto(request.project_id);
1959    let (room, guest_connection_ids) = &*session
1960        .db()
1961        .await
1962        .update_project(project_id, session.connection_id, &request.worktrees)
1963        .await?;
1964    broadcast(
1965        Some(session.connection_id),
1966        guest_connection_ids.iter().copied(),
1967        |connection_id| {
1968            session
1969                .peer
1970                .forward_send(session.connection_id, connection_id, request.clone())
1971        },
1972    );
1973    if let Some(room) = room {
1974        room_updated(room, &session.peer);
1975    }
1976    response.send(proto::Ack {})?;
1977
1978    Ok(())
1979}
1980
1981/// Updates other participants with changes to the worktree
1982async fn update_worktree(
1983    request: proto::UpdateWorktree,
1984    response: Response<proto::UpdateWorktree>,
1985    session: Session,
1986) -> Result<()> {
1987    let guest_connection_ids = session
1988        .db()
1989        .await
1990        .update_worktree(&request, session.connection_id)
1991        .await?;
1992
1993    broadcast(
1994        Some(session.connection_id),
1995        guest_connection_ids.iter().copied(),
1996        |connection_id| {
1997            session
1998                .peer
1999                .forward_send(session.connection_id, connection_id, request.clone())
2000        },
2001    );
2002    response.send(proto::Ack {})?;
2003    Ok(())
2004}
2005
2006/// Updates other participants with changes to the diagnostics
2007async fn update_diagnostic_summary(
2008    message: proto::UpdateDiagnosticSummary,
2009    session: Session,
2010) -> Result<()> {
2011    let guest_connection_ids = session
2012        .db()
2013        .await
2014        .update_diagnostic_summary(&message, session.connection_id)
2015        .await?;
2016
2017    broadcast(
2018        Some(session.connection_id),
2019        guest_connection_ids.iter().copied(),
2020        |connection_id| {
2021            session
2022                .peer
2023                .forward_send(session.connection_id, connection_id, message.clone())
2024        },
2025    );
2026
2027    Ok(())
2028}
2029
2030/// Updates other participants with changes to the worktree settings
2031async fn update_worktree_settings(
2032    message: proto::UpdateWorktreeSettings,
2033    session: Session,
2034) -> Result<()> {
2035    let guest_connection_ids = session
2036        .db()
2037        .await
2038        .update_worktree_settings(&message, session.connection_id)
2039        .await?;
2040
2041    broadcast(
2042        Some(session.connection_id),
2043        guest_connection_ids.iter().copied(),
2044        |connection_id| {
2045            session
2046                .peer
2047                .forward_send(session.connection_id, connection_id, message.clone())
2048        },
2049    );
2050
2051    Ok(())
2052}
2053
2054/// Notify other participants that a  language server has started.
2055async fn start_language_server(
2056    request: proto::StartLanguageServer,
2057    session: Session,
2058) -> Result<()> {
2059    let guest_connection_ids = session
2060        .db()
2061        .await
2062        .start_language_server(&request, session.connection_id)
2063        .await?;
2064
2065    broadcast(
2066        Some(session.connection_id),
2067        guest_connection_ids.iter().copied(),
2068        |connection_id| {
2069            session
2070                .peer
2071                .forward_send(session.connection_id, connection_id, request.clone())
2072        },
2073    );
2074    Ok(())
2075}
2076
2077/// Notify other participants that a language server has changed.
2078async fn update_language_server(
2079    request: proto::UpdateLanguageServer,
2080    session: Session,
2081) -> Result<()> {
2082    let project_id = ProjectId::from_proto(request.project_id);
2083    let project_connection_ids = session
2084        .db()
2085        .await
2086        .project_connection_ids(project_id, session.connection_id, true)
2087        .await?;
2088    broadcast(
2089        Some(session.connection_id),
2090        project_connection_ids.iter().copied(),
2091        |connection_id| {
2092            session
2093                .peer
2094                .forward_send(session.connection_id, connection_id, request.clone())
2095        },
2096    );
2097    Ok(())
2098}
2099
2100/// forward a project request to the host. These requests should be read only
2101/// as guests are allowed to send them.
2102async fn forward_read_only_project_request<T>(
2103    request: T,
2104    response: Response<T>,
2105    session: Session,
2106) -> Result<()>
2107where
2108    T: EntityMessage + RequestMessage,
2109{
2110    let project_id = ProjectId::from_proto(request.remote_entity_id());
2111    let host_connection_id = session
2112        .db()
2113        .await
2114        .host_for_read_only_project_request(project_id, session.connection_id)
2115        .await?;
2116    let payload = session
2117        .peer
2118        .forward_request(session.connection_id, host_connection_id, request)
2119        .await?;
2120    response.send(payload)?;
2121    Ok(())
2122}
2123
2124async fn forward_find_search_candidates_request(
2125    request: proto::FindSearchCandidates,
2126    response: Response<proto::FindSearchCandidates>,
2127    session: Session,
2128) -> Result<()> {
2129    let project_id = ProjectId::from_proto(request.remote_entity_id());
2130    let host_connection_id = session
2131        .db()
2132        .await
2133        .host_for_read_only_project_request(project_id, session.connection_id)
2134        .await?;
2135    let payload = session
2136        .peer
2137        .forward_request(session.connection_id, host_connection_id, request)
2138        .await?;
2139    response.send(payload)?;
2140    Ok(())
2141}
2142
2143/// forward a project request to the host. These requests are disallowed
2144/// for guests.
2145async fn forward_mutating_project_request<T>(
2146    request: T,
2147    response: Response<T>,
2148    session: Session,
2149) -> Result<()>
2150where
2151    T: EntityMessage + RequestMessage,
2152{
2153    let project_id = ProjectId::from_proto(request.remote_entity_id());
2154
2155    let host_connection_id = session
2156        .db()
2157        .await
2158        .host_for_mutating_project_request(project_id, session.connection_id)
2159        .await?;
2160    let payload = session
2161        .peer
2162        .forward_request(session.connection_id, host_connection_id, request)
2163        .await?;
2164    response.send(payload)?;
2165    Ok(())
2166}
2167
2168/// Notify other participants that a new buffer has been created
2169async fn create_buffer_for_peer(
2170    request: proto::CreateBufferForPeer,
2171    session: Session,
2172) -> Result<()> {
2173    session
2174        .db()
2175        .await
2176        .check_user_is_project_host(
2177            ProjectId::from_proto(request.project_id),
2178            session.connection_id,
2179        )
2180        .await?;
2181    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2182    session
2183        .peer
2184        .forward_send(session.connection_id, peer_id.into(), request)?;
2185    Ok(())
2186}
2187
2188/// Notify other participants that a buffer has been updated. This is
2189/// allowed for guests as long as the update is limited to selections.
2190async fn update_buffer(
2191    request: proto::UpdateBuffer,
2192    response: Response<proto::UpdateBuffer>,
2193    session: Session,
2194) -> Result<()> {
2195    let project_id = ProjectId::from_proto(request.project_id);
2196    let mut capability = Capability::ReadOnly;
2197
2198    for op in request.operations.iter() {
2199        match op.variant {
2200            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2201            Some(_) => capability = Capability::ReadWrite,
2202        }
2203    }
2204
2205    let host = {
2206        let guard = session
2207            .db()
2208            .await
2209            .connections_for_buffer_update(project_id, session.connection_id, capability)
2210            .await?;
2211
2212        let (host, guests) = &*guard;
2213
2214        broadcast(
2215            Some(session.connection_id),
2216            guests.clone(),
2217            |connection_id| {
2218                session
2219                    .peer
2220                    .forward_send(session.connection_id, connection_id, request.clone())
2221            },
2222        );
2223
2224        *host
2225    };
2226
2227    if host != session.connection_id {
2228        session
2229            .peer
2230            .forward_request(session.connection_id, host, request.clone())
2231            .await?;
2232    }
2233
2234    response.send(proto::Ack {})?;
2235    Ok(())
2236}
2237
2238async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2239    let project_id = ProjectId::from_proto(message.project_id);
2240
2241    let operation = message.operation.as_ref().context("invalid operation")?;
2242    let capability = match operation.variant.as_ref() {
2243        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2244            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2245                match buffer_op.variant {
2246                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2247                        Capability::ReadOnly
2248                    }
2249                    _ => Capability::ReadWrite,
2250                }
2251            } else {
2252                Capability::ReadWrite
2253            }
2254        }
2255        Some(_) => Capability::ReadWrite,
2256        None => Capability::ReadOnly,
2257    };
2258
2259    let guard = session
2260        .db()
2261        .await
2262        .connections_for_buffer_update(project_id, session.connection_id, capability)
2263        .await?;
2264
2265    let (host, guests) = &*guard;
2266
2267    broadcast(
2268        Some(session.connection_id),
2269        guests.iter().chain([host]).copied(),
2270        |connection_id| {
2271            session
2272                .peer
2273                .forward_send(session.connection_id, connection_id, message.clone())
2274        },
2275    );
2276
2277    Ok(())
2278}
2279
2280/// Notify other participants that a project has been updated.
2281async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2282    request: T,
2283    session: Session,
2284) -> Result<()> {
2285    let project_id = ProjectId::from_proto(request.remote_entity_id());
2286    let project_connection_ids = session
2287        .db()
2288        .await
2289        .project_connection_ids(project_id, session.connection_id, false)
2290        .await?;
2291
2292    broadcast(
2293        Some(session.connection_id),
2294        project_connection_ids.iter().copied(),
2295        |connection_id| {
2296            session
2297                .peer
2298                .forward_send(session.connection_id, connection_id, request.clone())
2299        },
2300    );
2301    Ok(())
2302}
2303
2304/// Start following another user in a call.
2305async fn follow(
2306    request: proto::Follow,
2307    response: Response<proto::Follow>,
2308    session: Session,
2309) -> Result<()> {
2310    let room_id = RoomId::from_proto(request.room_id);
2311    let project_id = request.project_id.map(ProjectId::from_proto);
2312    let leader_id = request
2313        .leader_id
2314        .ok_or_else(|| anyhow!("invalid leader id"))?
2315        .into();
2316    let follower_id = session.connection_id;
2317
2318    session
2319        .db()
2320        .await
2321        .check_room_participants(room_id, leader_id, session.connection_id)
2322        .await?;
2323
2324    let response_payload = session
2325        .peer
2326        .forward_request(session.connection_id, leader_id, request)
2327        .await?;
2328    response.send(response_payload)?;
2329
2330    if let Some(project_id) = project_id {
2331        let room = session
2332            .db()
2333            .await
2334            .follow(room_id, project_id, leader_id, follower_id)
2335            .await?;
2336        room_updated(&room, &session.peer);
2337    }
2338
2339    Ok(())
2340}
2341
2342/// Stop following another user in a call.
2343async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2344    let room_id = RoomId::from_proto(request.room_id);
2345    let project_id = request.project_id.map(ProjectId::from_proto);
2346    let leader_id = request
2347        .leader_id
2348        .ok_or_else(|| anyhow!("invalid leader id"))?
2349        .into();
2350    let follower_id = session.connection_id;
2351
2352    session
2353        .db()
2354        .await
2355        .check_room_participants(room_id, leader_id, session.connection_id)
2356        .await?;
2357
2358    session
2359        .peer
2360        .forward_send(session.connection_id, leader_id, request)?;
2361
2362    if let Some(project_id) = project_id {
2363        let room = session
2364            .db()
2365            .await
2366            .unfollow(room_id, project_id, leader_id, follower_id)
2367            .await?;
2368        room_updated(&room, &session.peer);
2369    }
2370
2371    Ok(())
2372}
2373
2374/// Notify everyone following you of your current location.
2375async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2376    let room_id = RoomId::from_proto(request.room_id);
2377    let database = session.db.lock().await;
2378
2379    let connection_ids = if let Some(project_id) = request.project_id {
2380        let project_id = ProjectId::from_proto(project_id);
2381        database
2382            .project_connection_ids(project_id, session.connection_id, true)
2383            .await?
2384    } else {
2385        database
2386            .room_connection_ids(room_id, session.connection_id)
2387            .await?
2388    };
2389
2390    // For now, don't send view update messages back to that view's current leader.
2391    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2392        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2393        _ => None,
2394    });
2395
2396    for connection_id in connection_ids.iter().cloned() {
2397        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2398            session
2399                .peer
2400                .forward_send(session.connection_id, connection_id, request.clone())?;
2401        }
2402    }
2403    Ok(())
2404}
2405
2406/// Get public data about users.
2407async fn get_users(
2408    request: proto::GetUsers,
2409    response: Response<proto::GetUsers>,
2410    session: Session,
2411) -> Result<()> {
2412    let user_ids = request
2413        .user_ids
2414        .into_iter()
2415        .map(UserId::from_proto)
2416        .collect();
2417    let users = session
2418        .db()
2419        .await
2420        .get_users_by_ids(user_ids)
2421        .await?
2422        .into_iter()
2423        .map(|user| proto::User {
2424            id: user.id.to_proto(),
2425            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2426            github_login: user.github_login,
2427            email: user.email_address,
2428            name: user.name,
2429        })
2430        .collect();
2431    response.send(proto::UsersResponse { users })?;
2432    Ok(())
2433}
2434
2435/// Search for users (to invite) buy Github login
2436async fn fuzzy_search_users(
2437    request: proto::FuzzySearchUsers,
2438    response: Response<proto::FuzzySearchUsers>,
2439    session: Session,
2440) -> Result<()> {
2441    let query = request.query;
2442    let users = match query.len() {
2443        0 => vec![],
2444        1 | 2 => session
2445            .db()
2446            .await
2447            .get_user_by_github_login(&query)
2448            .await?
2449            .into_iter()
2450            .collect(),
2451        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2452    };
2453    let users = users
2454        .into_iter()
2455        .filter(|user| user.id != session.user_id())
2456        .map(|user| proto::User {
2457            id: user.id.to_proto(),
2458            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2459            github_login: user.github_login,
2460            name: user.name,
2461            email: user.email_address,
2462        })
2463        .collect();
2464    response.send(proto::UsersResponse { users })?;
2465    Ok(())
2466}
2467
2468/// Send a contact request to another user.
2469async fn request_contact(
2470    request: proto::RequestContact,
2471    response: Response<proto::RequestContact>,
2472    session: Session,
2473) -> Result<()> {
2474    let requester_id = session.user_id();
2475    let responder_id = UserId::from_proto(request.responder_id);
2476    if requester_id == responder_id {
2477        return Err(anyhow!("cannot add yourself as a contact"))?;
2478    }
2479
2480    let notifications = session
2481        .db()
2482        .await
2483        .send_contact_request(requester_id, responder_id)
2484        .await?;
2485
2486    // Update outgoing contact requests of requester
2487    let mut update = proto::UpdateContacts::default();
2488    update.outgoing_requests.push(responder_id.to_proto());
2489    for connection_id in session
2490        .connection_pool()
2491        .await
2492        .user_connection_ids(requester_id)
2493    {
2494        session.peer.send(connection_id, update.clone())?;
2495    }
2496
2497    // Update incoming contact requests of responder
2498    let mut update = proto::UpdateContacts::default();
2499    update
2500        .incoming_requests
2501        .push(proto::IncomingContactRequest {
2502            requester_id: requester_id.to_proto(),
2503        });
2504    let connection_pool = session.connection_pool().await;
2505    for connection_id in connection_pool.user_connection_ids(responder_id) {
2506        session.peer.send(connection_id, update.clone())?;
2507    }
2508
2509    send_notifications(&connection_pool, &session.peer, notifications);
2510
2511    response.send(proto::Ack {})?;
2512    Ok(())
2513}
2514
2515/// Accept or decline a contact request
2516async fn respond_to_contact_request(
2517    request: proto::RespondToContactRequest,
2518    response: Response<proto::RespondToContactRequest>,
2519    session: Session,
2520) -> Result<()> {
2521    let responder_id = session.user_id();
2522    let requester_id = UserId::from_proto(request.requester_id);
2523    let db = session.db().await;
2524    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2525        db.dismiss_contact_notification(responder_id, requester_id)
2526            .await?;
2527    } else {
2528        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2529
2530        let notifications = db
2531            .respond_to_contact_request(responder_id, requester_id, accept)
2532            .await?;
2533        let requester_busy = db.is_user_busy(requester_id).await?;
2534        let responder_busy = db.is_user_busy(responder_id).await?;
2535
2536        let pool = session.connection_pool().await;
2537        // Update responder with new contact
2538        let mut update = proto::UpdateContacts::default();
2539        if accept {
2540            update
2541                .contacts
2542                .push(contact_for_user(requester_id, requester_busy, &pool));
2543        }
2544        update
2545            .remove_incoming_requests
2546            .push(requester_id.to_proto());
2547        for connection_id in pool.user_connection_ids(responder_id) {
2548            session.peer.send(connection_id, update.clone())?;
2549        }
2550
2551        // Update requester with new contact
2552        let mut update = proto::UpdateContacts::default();
2553        if accept {
2554            update
2555                .contacts
2556                .push(contact_for_user(responder_id, responder_busy, &pool));
2557        }
2558        update
2559            .remove_outgoing_requests
2560            .push(responder_id.to_proto());
2561
2562        for connection_id in pool.user_connection_ids(requester_id) {
2563            session.peer.send(connection_id, update.clone())?;
2564        }
2565
2566        send_notifications(&pool, &session.peer, notifications);
2567    }
2568
2569    response.send(proto::Ack {})?;
2570    Ok(())
2571}
2572
2573/// Remove a contact.
2574async fn remove_contact(
2575    request: proto::RemoveContact,
2576    response: Response<proto::RemoveContact>,
2577    session: Session,
2578) -> Result<()> {
2579    let requester_id = session.user_id();
2580    let responder_id = UserId::from_proto(request.user_id);
2581    let db = session.db().await;
2582    let (contact_accepted, deleted_notification_id) =
2583        db.remove_contact(requester_id, responder_id).await?;
2584
2585    let pool = session.connection_pool().await;
2586    // Update outgoing contact requests of requester
2587    let mut update = proto::UpdateContacts::default();
2588    if contact_accepted {
2589        update.remove_contacts.push(responder_id.to_proto());
2590    } else {
2591        update
2592            .remove_outgoing_requests
2593            .push(responder_id.to_proto());
2594    }
2595    for connection_id in pool.user_connection_ids(requester_id) {
2596        session.peer.send(connection_id, update.clone())?;
2597    }
2598
2599    // Update incoming contact requests of responder
2600    let mut update = proto::UpdateContacts::default();
2601    if contact_accepted {
2602        update.remove_contacts.push(requester_id.to_proto());
2603    } else {
2604        update
2605            .remove_incoming_requests
2606            .push(requester_id.to_proto());
2607    }
2608    for connection_id in pool.user_connection_ids(responder_id) {
2609        session.peer.send(connection_id, update.clone())?;
2610        if let Some(notification_id) = deleted_notification_id {
2611            session.peer.send(
2612                connection_id,
2613                proto::DeleteNotification {
2614                    notification_id: notification_id.to_proto(),
2615                },
2616            )?;
2617        }
2618    }
2619
2620    response.send(proto::Ack {})?;
2621    Ok(())
2622}
2623
2624fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2625    version.0.minor() < 139
2626}
2627
2628async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2629    let plan = session.current_plan(&session.db().await).await?;
2630
2631    session
2632        .peer
2633        .send(
2634            session.connection_id,
2635            proto::UpdateUserPlan { plan: plan.into() },
2636        )
2637        .trace_err();
2638
2639    Ok(())
2640}
2641
2642async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2643    subscribe_user_to_channels(session.user_id(), &session).await?;
2644    Ok(())
2645}
2646
2647async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2648    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2649    let mut pool = session.connection_pool().await;
2650    for membership in &channels_for_user.channel_memberships {
2651        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2652    }
2653    session.peer.send(
2654        session.connection_id,
2655        build_update_user_channels(&channels_for_user),
2656    )?;
2657    session.peer.send(
2658        session.connection_id,
2659        build_channels_update(channels_for_user),
2660    )?;
2661    Ok(())
2662}
2663
2664/// Creates a new channel.
2665async fn create_channel(
2666    request: proto::CreateChannel,
2667    response: Response<proto::CreateChannel>,
2668    session: Session,
2669) -> Result<()> {
2670    let db = session.db().await;
2671
2672    let parent_id = request.parent_id.map(ChannelId::from_proto);
2673    let (channel, membership) = db
2674        .create_channel(&request.name, parent_id, session.user_id())
2675        .await?;
2676
2677    let root_id = channel.root_id();
2678    let channel = Channel::from_model(channel);
2679
2680    response.send(proto::CreateChannelResponse {
2681        channel: Some(channel.to_proto()),
2682        parent_id: request.parent_id,
2683    })?;
2684
2685    let mut connection_pool = session.connection_pool().await;
2686    if let Some(membership) = membership {
2687        connection_pool.subscribe_to_channel(
2688            membership.user_id,
2689            membership.channel_id,
2690            membership.role,
2691        );
2692        let update = proto::UpdateUserChannels {
2693            channel_memberships: vec![proto::ChannelMembership {
2694                channel_id: membership.channel_id.to_proto(),
2695                role: membership.role.into(),
2696            }],
2697            ..Default::default()
2698        };
2699        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2700            session.peer.send(connection_id, update.clone())?;
2701        }
2702    }
2703
2704    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2705        if !role.can_see_channel(channel.visibility) {
2706            continue;
2707        }
2708
2709        let update = proto::UpdateChannels {
2710            channels: vec![channel.to_proto()],
2711            ..Default::default()
2712        };
2713        session.peer.send(connection_id, update.clone())?;
2714    }
2715
2716    Ok(())
2717}
2718
2719/// Delete a channel
2720async fn delete_channel(
2721    request: proto::DeleteChannel,
2722    response: Response<proto::DeleteChannel>,
2723    session: Session,
2724) -> Result<()> {
2725    let db = session.db().await;
2726
2727    let channel_id = request.channel_id;
2728    let (root_channel, removed_channels) = db
2729        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2730        .await?;
2731    response.send(proto::Ack {})?;
2732
2733    // Notify members of removed channels
2734    let mut update = proto::UpdateChannels::default();
2735    update
2736        .delete_channels
2737        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2738
2739    let connection_pool = session.connection_pool().await;
2740    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2741        session.peer.send(connection_id, update.clone())?;
2742    }
2743
2744    Ok(())
2745}
2746
2747/// Invite someone to join a channel.
2748async fn invite_channel_member(
2749    request: proto::InviteChannelMember,
2750    response: Response<proto::InviteChannelMember>,
2751    session: Session,
2752) -> Result<()> {
2753    let db = session.db().await;
2754    let channel_id = ChannelId::from_proto(request.channel_id);
2755    let invitee_id = UserId::from_proto(request.user_id);
2756    let InviteMemberResult {
2757        channel,
2758        notifications,
2759    } = db
2760        .invite_channel_member(
2761            channel_id,
2762            invitee_id,
2763            session.user_id(),
2764            request.role().into(),
2765        )
2766        .await?;
2767
2768    let update = proto::UpdateChannels {
2769        channel_invitations: vec![channel.to_proto()],
2770        ..Default::default()
2771    };
2772
2773    let connection_pool = session.connection_pool().await;
2774    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2775        session.peer.send(connection_id, update.clone())?;
2776    }
2777
2778    send_notifications(&connection_pool, &session.peer, notifications);
2779
2780    response.send(proto::Ack {})?;
2781    Ok(())
2782}
2783
2784/// remove someone from a channel
2785async fn remove_channel_member(
2786    request: proto::RemoveChannelMember,
2787    response: Response<proto::RemoveChannelMember>,
2788    session: Session,
2789) -> Result<()> {
2790    let db = session.db().await;
2791    let channel_id = ChannelId::from_proto(request.channel_id);
2792    let member_id = UserId::from_proto(request.user_id);
2793
2794    let RemoveChannelMemberResult {
2795        membership_update,
2796        notification_id,
2797    } = db
2798        .remove_channel_member(channel_id, member_id, session.user_id())
2799        .await?;
2800
2801    let mut connection_pool = session.connection_pool().await;
2802    notify_membership_updated(
2803        &mut connection_pool,
2804        membership_update,
2805        member_id,
2806        &session.peer,
2807    );
2808    for connection_id in connection_pool.user_connection_ids(member_id) {
2809        if let Some(notification_id) = notification_id {
2810            session
2811                .peer
2812                .send(
2813                    connection_id,
2814                    proto::DeleteNotification {
2815                        notification_id: notification_id.to_proto(),
2816                    },
2817                )
2818                .trace_err();
2819        }
2820    }
2821
2822    response.send(proto::Ack {})?;
2823    Ok(())
2824}
2825
2826/// Toggle the channel between public and private.
2827/// Care is taken to maintain the invariant that public channels only descend from public channels,
2828/// (though members-only channels can appear at any point in the hierarchy).
2829async fn set_channel_visibility(
2830    request: proto::SetChannelVisibility,
2831    response: Response<proto::SetChannelVisibility>,
2832    session: Session,
2833) -> Result<()> {
2834    let db = session.db().await;
2835    let channel_id = ChannelId::from_proto(request.channel_id);
2836    let visibility = request.visibility().into();
2837
2838    let channel_model = db
2839        .set_channel_visibility(channel_id, visibility, session.user_id())
2840        .await?;
2841    let root_id = channel_model.root_id();
2842    let channel = Channel::from_model(channel_model);
2843
2844    let mut connection_pool = session.connection_pool().await;
2845    for (user_id, role) in connection_pool
2846        .channel_user_ids(root_id)
2847        .collect::<Vec<_>>()
2848        .into_iter()
2849    {
2850        let update = if role.can_see_channel(channel.visibility) {
2851            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2852            proto::UpdateChannels {
2853                channels: vec![channel.to_proto()],
2854                ..Default::default()
2855            }
2856        } else {
2857            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2858            proto::UpdateChannels {
2859                delete_channels: vec![channel.id.to_proto()],
2860                ..Default::default()
2861            }
2862        };
2863
2864        for connection_id in connection_pool.user_connection_ids(user_id) {
2865            session.peer.send(connection_id, update.clone())?;
2866        }
2867    }
2868
2869    response.send(proto::Ack {})?;
2870    Ok(())
2871}
2872
2873/// Alter the role for a user in the channel.
2874async fn set_channel_member_role(
2875    request: proto::SetChannelMemberRole,
2876    response: Response<proto::SetChannelMemberRole>,
2877    session: Session,
2878) -> Result<()> {
2879    let db = session.db().await;
2880    let channel_id = ChannelId::from_proto(request.channel_id);
2881    let member_id = UserId::from_proto(request.user_id);
2882    let result = db
2883        .set_channel_member_role(
2884            channel_id,
2885            session.user_id(),
2886            member_id,
2887            request.role().into(),
2888        )
2889        .await?;
2890
2891    match result {
2892        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2893            let mut connection_pool = session.connection_pool().await;
2894            notify_membership_updated(
2895                &mut connection_pool,
2896                membership_update,
2897                member_id,
2898                &session.peer,
2899            )
2900        }
2901        db::SetMemberRoleResult::InviteUpdated(channel) => {
2902            let update = proto::UpdateChannels {
2903                channel_invitations: vec![channel.to_proto()],
2904                ..Default::default()
2905            };
2906
2907            for connection_id in session
2908                .connection_pool()
2909                .await
2910                .user_connection_ids(member_id)
2911            {
2912                session.peer.send(connection_id, update.clone())?;
2913            }
2914        }
2915    }
2916
2917    response.send(proto::Ack {})?;
2918    Ok(())
2919}
2920
2921/// Change the name of a channel
2922async fn rename_channel(
2923    request: proto::RenameChannel,
2924    response: Response<proto::RenameChannel>,
2925    session: Session,
2926) -> Result<()> {
2927    let db = session.db().await;
2928    let channel_id = ChannelId::from_proto(request.channel_id);
2929    let channel_model = db
2930        .rename_channel(channel_id, session.user_id(), &request.name)
2931        .await?;
2932    let root_id = channel_model.root_id();
2933    let channel = Channel::from_model(channel_model);
2934
2935    response.send(proto::RenameChannelResponse {
2936        channel: Some(channel.to_proto()),
2937    })?;
2938
2939    let connection_pool = session.connection_pool().await;
2940    let update = proto::UpdateChannels {
2941        channels: vec![channel.to_proto()],
2942        ..Default::default()
2943    };
2944    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2945        if role.can_see_channel(channel.visibility) {
2946            session.peer.send(connection_id, update.clone())?;
2947        }
2948    }
2949
2950    Ok(())
2951}
2952
2953/// Move a channel to a new parent.
2954async fn move_channel(
2955    request: proto::MoveChannel,
2956    response: Response<proto::MoveChannel>,
2957    session: Session,
2958) -> Result<()> {
2959    let channel_id = ChannelId::from_proto(request.channel_id);
2960    let to = ChannelId::from_proto(request.to);
2961
2962    let (root_id, channels) = session
2963        .db()
2964        .await
2965        .move_channel(channel_id, to, session.user_id())
2966        .await?;
2967
2968    let connection_pool = session.connection_pool().await;
2969    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2970        let channels = channels
2971            .iter()
2972            .filter_map(|channel| {
2973                if role.can_see_channel(channel.visibility) {
2974                    Some(channel.to_proto())
2975                } else {
2976                    None
2977                }
2978            })
2979            .collect::<Vec<_>>();
2980        if channels.is_empty() {
2981            continue;
2982        }
2983
2984        let update = proto::UpdateChannels {
2985            channels,
2986            ..Default::default()
2987        };
2988
2989        session.peer.send(connection_id, update.clone())?;
2990    }
2991
2992    response.send(Ack {})?;
2993    Ok(())
2994}
2995
2996/// Get the list of channel members
2997async fn get_channel_members(
2998    request: proto::GetChannelMembers,
2999    response: Response<proto::GetChannelMembers>,
3000    session: Session,
3001) -> Result<()> {
3002    let db = session.db().await;
3003    let channel_id = ChannelId::from_proto(request.channel_id);
3004    let limit = if request.limit == 0 {
3005        u16::MAX as u64
3006    } else {
3007        request.limit
3008    };
3009    let (members, users) = db
3010        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3011        .await?;
3012    response.send(proto::GetChannelMembersResponse { members, users })?;
3013    Ok(())
3014}
3015
3016/// Accept or decline a channel invitation.
3017async fn respond_to_channel_invite(
3018    request: proto::RespondToChannelInvite,
3019    response: Response<proto::RespondToChannelInvite>,
3020    session: Session,
3021) -> Result<()> {
3022    let db = session.db().await;
3023    let channel_id = ChannelId::from_proto(request.channel_id);
3024    let RespondToChannelInvite {
3025        membership_update,
3026        notifications,
3027    } = db
3028        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3029        .await?;
3030
3031    let mut connection_pool = session.connection_pool().await;
3032    if let Some(membership_update) = membership_update {
3033        notify_membership_updated(
3034            &mut connection_pool,
3035            membership_update,
3036            session.user_id(),
3037            &session.peer,
3038        );
3039    } else {
3040        let update = proto::UpdateChannels {
3041            remove_channel_invitations: vec![channel_id.to_proto()],
3042            ..Default::default()
3043        };
3044
3045        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3046            session.peer.send(connection_id, update.clone())?;
3047        }
3048    };
3049
3050    send_notifications(&connection_pool, &session.peer, notifications);
3051
3052    response.send(proto::Ack {})?;
3053
3054    Ok(())
3055}
3056
3057/// Join the channels' room
3058async fn join_channel(
3059    request: proto::JoinChannel,
3060    response: Response<proto::JoinChannel>,
3061    session: Session,
3062) -> Result<()> {
3063    let channel_id = ChannelId::from_proto(request.channel_id);
3064    join_channel_internal(channel_id, Box::new(response), session).await
3065}
3066
3067trait JoinChannelInternalResponse {
3068    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3069}
3070impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3071    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3072        Response::<proto::JoinChannel>::send(self, result)
3073    }
3074}
3075impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3076    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3077        Response::<proto::JoinRoom>::send(self, result)
3078    }
3079}
3080
3081async fn join_channel_internal(
3082    channel_id: ChannelId,
3083    response: Box<impl JoinChannelInternalResponse>,
3084    session: Session,
3085) -> Result<()> {
3086    let joined_room = {
3087        let mut db = session.db().await;
3088        // If zed quits without leaving the room, and the user re-opens zed before the
3089        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3090        // room they were in.
3091        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3092            tracing::info!(
3093                stale_connection_id = %connection,
3094                "cleaning up stale connection",
3095            );
3096            drop(db);
3097            leave_room_for_session(&session, connection).await?;
3098            db = session.db().await;
3099        }
3100
3101        let (joined_room, membership_updated, role) = db
3102            .join_channel(channel_id, session.user_id(), session.connection_id)
3103            .await?;
3104
3105        let live_kit_connection_info =
3106            session
3107                .app_state
3108                .livekit_client
3109                .as_ref()
3110                .and_then(|live_kit| {
3111                    let (can_publish, token) = if role == ChannelRole::Guest {
3112                        (
3113                            false,
3114                            live_kit
3115                                .guest_token(
3116                                    &joined_room.room.livekit_room,
3117                                    &session.user_id().to_string(),
3118                                )
3119                                .trace_err()?,
3120                        )
3121                    } else {
3122                        (
3123                            true,
3124                            live_kit
3125                                .room_token(
3126                                    &joined_room.room.livekit_room,
3127                                    &session.user_id().to_string(),
3128                                )
3129                                .trace_err()?,
3130                        )
3131                    };
3132
3133                    Some(LiveKitConnectionInfo {
3134                        server_url: live_kit.url().into(),
3135                        token,
3136                        can_publish,
3137                    })
3138                });
3139
3140        response.send(proto::JoinRoomResponse {
3141            room: Some(joined_room.room.clone()),
3142            channel_id: joined_room
3143                .channel
3144                .as_ref()
3145                .map(|channel| channel.id.to_proto()),
3146            live_kit_connection_info,
3147        })?;
3148
3149        let mut connection_pool = session.connection_pool().await;
3150        if let Some(membership_updated) = membership_updated {
3151            notify_membership_updated(
3152                &mut connection_pool,
3153                membership_updated,
3154                session.user_id(),
3155                &session.peer,
3156            );
3157        }
3158
3159        room_updated(&joined_room.room, &session.peer);
3160
3161        joined_room
3162    };
3163
3164    channel_updated(
3165        &joined_room
3166            .channel
3167            .ok_or_else(|| anyhow!("channel not returned"))?,
3168        &joined_room.room,
3169        &session.peer,
3170        &*session.connection_pool().await,
3171    );
3172
3173    update_user_contacts(session.user_id(), &session).await?;
3174    Ok(())
3175}
3176
3177/// Start editing the channel notes
3178async fn join_channel_buffer(
3179    request: proto::JoinChannelBuffer,
3180    response: Response<proto::JoinChannelBuffer>,
3181    session: Session,
3182) -> Result<()> {
3183    let db = session.db().await;
3184    let channel_id = ChannelId::from_proto(request.channel_id);
3185
3186    let open_response = db
3187        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3188        .await?;
3189
3190    let collaborators = open_response.collaborators.clone();
3191    response.send(open_response)?;
3192
3193    let update = UpdateChannelBufferCollaborators {
3194        channel_id: channel_id.to_proto(),
3195        collaborators: collaborators.clone(),
3196    };
3197    channel_buffer_updated(
3198        session.connection_id,
3199        collaborators
3200            .iter()
3201            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3202        &update,
3203        &session.peer,
3204    );
3205
3206    Ok(())
3207}
3208
3209/// Edit the channel notes
3210async fn update_channel_buffer(
3211    request: proto::UpdateChannelBuffer,
3212    session: Session,
3213) -> Result<()> {
3214    let db = session.db().await;
3215    let channel_id = ChannelId::from_proto(request.channel_id);
3216
3217    let (collaborators, epoch, version) = db
3218        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3219        .await?;
3220
3221    channel_buffer_updated(
3222        session.connection_id,
3223        collaborators.clone(),
3224        &proto::UpdateChannelBuffer {
3225            channel_id: channel_id.to_proto(),
3226            operations: request.operations,
3227        },
3228        &session.peer,
3229    );
3230
3231    let pool = &*session.connection_pool().await;
3232
3233    let non_collaborators =
3234        pool.channel_connection_ids(channel_id)
3235            .filter_map(|(connection_id, _)| {
3236                if collaborators.contains(&connection_id) {
3237                    None
3238                } else {
3239                    Some(connection_id)
3240                }
3241            });
3242
3243    broadcast(None, non_collaborators, |peer_id| {
3244        session.peer.send(
3245            peer_id,
3246            proto::UpdateChannels {
3247                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3248                    channel_id: channel_id.to_proto(),
3249                    epoch: epoch as u64,
3250                    version: version.clone(),
3251                }],
3252                ..Default::default()
3253            },
3254        )
3255    });
3256
3257    Ok(())
3258}
3259
3260/// Rejoin the channel notes after a connection blip
3261async fn rejoin_channel_buffers(
3262    request: proto::RejoinChannelBuffers,
3263    response: Response<proto::RejoinChannelBuffers>,
3264    session: Session,
3265) -> Result<()> {
3266    let db = session.db().await;
3267    let buffers = db
3268        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3269        .await?;
3270
3271    for rejoined_buffer in &buffers {
3272        let collaborators_to_notify = rejoined_buffer
3273            .buffer
3274            .collaborators
3275            .iter()
3276            .filter_map(|c| Some(c.peer_id?.into()));
3277        channel_buffer_updated(
3278            session.connection_id,
3279            collaborators_to_notify,
3280            &proto::UpdateChannelBufferCollaborators {
3281                channel_id: rejoined_buffer.buffer.channel_id,
3282                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3283            },
3284            &session.peer,
3285        );
3286    }
3287
3288    response.send(proto::RejoinChannelBuffersResponse {
3289        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3290    })?;
3291
3292    Ok(())
3293}
3294
3295/// Stop editing the channel notes
3296async fn leave_channel_buffer(
3297    request: proto::LeaveChannelBuffer,
3298    response: Response<proto::LeaveChannelBuffer>,
3299    session: Session,
3300) -> Result<()> {
3301    let db = session.db().await;
3302    let channel_id = ChannelId::from_proto(request.channel_id);
3303
3304    let left_buffer = db
3305        .leave_channel_buffer(channel_id, session.connection_id)
3306        .await?;
3307
3308    response.send(Ack {})?;
3309
3310    channel_buffer_updated(
3311        session.connection_id,
3312        left_buffer.connections,
3313        &proto::UpdateChannelBufferCollaborators {
3314            channel_id: channel_id.to_proto(),
3315            collaborators: left_buffer.collaborators,
3316        },
3317        &session.peer,
3318    );
3319
3320    Ok(())
3321}
3322
3323fn channel_buffer_updated<T: EnvelopedMessage>(
3324    sender_id: ConnectionId,
3325    collaborators: impl IntoIterator<Item = ConnectionId>,
3326    message: &T,
3327    peer: &Peer,
3328) {
3329    broadcast(Some(sender_id), collaborators, |peer_id| {
3330        peer.send(peer_id, message.clone())
3331    });
3332}
3333
3334fn send_notifications(
3335    connection_pool: &ConnectionPool,
3336    peer: &Peer,
3337    notifications: db::NotificationBatch,
3338) {
3339    for (user_id, notification) in notifications {
3340        for connection_id in connection_pool.user_connection_ids(user_id) {
3341            if let Err(error) = peer.send(
3342                connection_id,
3343                proto::AddNotification {
3344                    notification: Some(notification.clone()),
3345                },
3346            ) {
3347                tracing::error!(
3348                    "failed to send notification to {:?} {}",
3349                    connection_id,
3350                    error
3351                );
3352            }
3353        }
3354    }
3355}
3356
3357/// Send a message to the channel
3358async fn send_channel_message(
3359    request: proto::SendChannelMessage,
3360    response: Response<proto::SendChannelMessage>,
3361    session: Session,
3362) -> Result<()> {
3363    // Validate the message body.
3364    let body = request.body.trim().to_string();
3365    if body.len() > MAX_MESSAGE_LEN {
3366        return Err(anyhow!("message is too long"))?;
3367    }
3368    if body.is_empty() {
3369        return Err(anyhow!("message can't be blank"))?;
3370    }
3371
3372    // TODO: adjust mentions if body is trimmed
3373
3374    let timestamp = OffsetDateTime::now_utc();
3375    let nonce = request
3376        .nonce
3377        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3378
3379    let channel_id = ChannelId::from_proto(request.channel_id);
3380    let CreatedChannelMessage {
3381        message_id,
3382        participant_connection_ids,
3383        notifications,
3384    } = session
3385        .db()
3386        .await
3387        .create_channel_message(
3388            channel_id,
3389            session.user_id(),
3390            &body,
3391            &request.mentions,
3392            timestamp,
3393            nonce.clone().into(),
3394            request.reply_to_message_id.map(MessageId::from_proto),
3395        )
3396        .await?;
3397
3398    let message = proto::ChannelMessage {
3399        sender_id: session.user_id().to_proto(),
3400        id: message_id.to_proto(),
3401        body,
3402        mentions: request.mentions,
3403        timestamp: timestamp.unix_timestamp() as u64,
3404        nonce: Some(nonce),
3405        reply_to_message_id: request.reply_to_message_id,
3406        edited_at: None,
3407    };
3408    broadcast(
3409        Some(session.connection_id),
3410        participant_connection_ids.clone(),
3411        |connection| {
3412            session.peer.send(
3413                connection,
3414                proto::ChannelMessageSent {
3415                    channel_id: channel_id.to_proto(),
3416                    message: Some(message.clone()),
3417                },
3418            )
3419        },
3420    );
3421    response.send(proto::SendChannelMessageResponse {
3422        message: Some(message),
3423    })?;
3424
3425    let pool = &*session.connection_pool().await;
3426    let non_participants =
3427        pool.channel_connection_ids(channel_id)
3428            .filter_map(|(connection_id, _)| {
3429                if participant_connection_ids.contains(&connection_id) {
3430                    None
3431                } else {
3432                    Some(connection_id)
3433                }
3434            });
3435    broadcast(None, non_participants, |peer_id| {
3436        session.peer.send(
3437            peer_id,
3438            proto::UpdateChannels {
3439                latest_channel_message_ids: vec![proto::ChannelMessageId {
3440                    channel_id: channel_id.to_proto(),
3441                    message_id: message_id.to_proto(),
3442                }],
3443                ..Default::default()
3444            },
3445        )
3446    });
3447    send_notifications(pool, &session.peer, notifications);
3448
3449    Ok(())
3450}
3451
3452/// Delete a channel message
3453async fn remove_channel_message(
3454    request: proto::RemoveChannelMessage,
3455    response: Response<proto::RemoveChannelMessage>,
3456    session: Session,
3457) -> Result<()> {
3458    let channel_id = ChannelId::from_proto(request.channel_id);
3459    let message_id = MessageId::from_proto(request.message_id);
3460    let (connection_ids, existing_notification_ids) = session
3461        .db()
3462        .await
3463        .remove_channel_message(channel_id, message_id, session.user_id())
3464        .await?;
3465
3466    broadcast(
3467        Some(session.connection_id),
3468        connection_ids,
3469        move |connection| {
3470            session.peer.send(connection, request.clone())?;
3471
3472            for notification_id in &existing_notification_ids {
3473                session.peer.send(
3474                    connection,
3475                    proto::DeleteNotification {
3476                        notification_id: (*notification_id).to_proto(),
3477                    },
3478                )?;
3479            }
3480
3481            Ok(())
3482        },
3483    );
3484    response.send(proto::Ack {})?;
3485    Ok(())
3486}
3487
3488async fn update_channel_message(
3489    request: proto::UpdateChannelMessage,
3490    response: Response<proto::UpdateChannelMessage>,
3491    session: Session,
3492) -> Result<()> {
3493    let channel_id = ChannelId::from_proto(request.channel_id);
3494    let message_id = MessageId::from_proto(request.message_id);
3495    let updated_at = OffsetDateTime::now_utc();
3496    let UpdatedChannelMessage {
3497        message_id,
3498        participant_connection_ids,
3499        notifications,
3500        reply_to_message_id,
3501        timestamp,
3502        deleted_mention_notification_ids,
3503        updated_mention_notifications,
3504    } = session
3505        .db()
3506        .await
3507        .update_channel_message(
3508            channel_id,
3509            message_id,
3510            session.user_id(),
3511            request.body.as_str(),
3512            &request.mentions,
3513            updated_at,
3514        )
3515        .await?;
3516
3517    let nonce = request
3518        .nonce
3519        .clone()
3520        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3521
3522    let message = proto::ChannelMessage {
3523        sender_id: session.user_id().to_proto(),
3524        id: message_id.to_proto(),
3525        body: request.body.clone(),
3526        mentions: request.mentions.clone(),
3527        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3528        nonce: Some(nonce),
3529        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3530        edited_at: Some(updated_at.unix_timestamp() as u64),
3531    };
3532
3533    response.send(proto::Ack {})?;
3534
3535    let pool = &*session.connection_pool().await;
3536    broadcast(
3537        Some(session.connection_id),
3538        participant_connection_ids,
3539        |connection| {
3540            session.peer.send(
3541                connection,
3542                proto::ChannelMessageUpdate {
3543                    channel_id: channel_id.to_proto(),
3544                    message: Some(message.clone()),
3545                },
3546            )?;
3547
3548            for notification_id in &deleted_mention_notification_ids {
3549                session.peer.send(
3550                    connection,
3551                    proto::DeleteNotification {
3552                        notification_id: (*notification_id).to_proto(),
3553                    },
3554                )?;
3555            }
3556
3557            for notification in &updated_mention_notifications {
3558                session.peer.send(
3559                    connection,
3560                    proto::UpdateNotification {
3561                        notification: Some(notification.clone()),
3562                    },
3563                )?;
3564            }
3565
3566            Ok(())
3567        },
3568    );
3569
3570    send_notifications(pool, &session.peer, notifications);
3571
3572    Ok(())
3573}
3574
3575/// Mark a channel message as read
3576async fn acknowledge_channel_message(
3577    request: proto::AckChannelMessage,
3578    session: Session,
3579) -> Result<()> {
3580    let channel_id = ChannelId::from_proto(request.channel_id);
3581    let message_id = MessageId::from_proto(request.message_id);
3582    let notifications = session
3583        .db()
3584        .await
3585        .observe_channel_message(channel_id, session.user_id(), message_id)
3586        .await?;
3587    send_notifications(
3588        &*session.connection_pool().await,
3589        &session.peer,
3590        notifications,
3591    );
3592    Ok(())
3593}
3594
3595/// Mark a buffer version as synced
3596async fn acknowledge_buffer_version(
3597    request: proto::AckBufferOperation,
3598    session: Session,
3599) -> Result<()> {
3600    let buffer_id = BufferId::from_proto(request.buffer_id);
3601    session
3602        .db()
3603        .await
3604        .observe_buffer_version(
3605            buffer_id,
3606            session.user_id(),
3607            request.epoch as i32,
3608            &request.version,
3609        )
3610        .await?;
3611    Ok(())
3612}
3613
3614async fn count_language_model_tokens(
3615    request: proto::CountLanguageModelTokens,
3616    response: Response<proto::CountLanguageModelTokens>,
3617    session: Session,
3618    config: &Config,
3619) -> Result<()> {
3620    authorize_access_to_legacy_llm_endpoints(&session).await?;
3621
3622    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3623        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3624        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3625    };
3626
3627    session
3628        .app_state
3629        .rate_limiter
3630        .check(&*rate_limit, session.user_id())
3631        .await?;
3632
3633    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3634        Some(proto::LanguageModelProvider::Google) => {
3635            let api_key = config
3636                .google_ai_api_key
3637                .as_ref()
3638                .context("no Google AI API key configured on the server")?;
3639            google_ai::count_tokens(
3640                session.http_client.as_ref(),
3641                google_ai::API_URL,
3642                api_key,
3643                serde_json::from_str(&request.request)?,
3644            )
3645            .await?
3646        }
3647        _ => return Err(anyhow!("unsupported provider"))?,
3648    };
3649
3650    response.send(proto::CountLanguageModelTokensResponse {
3651        token_count: result.total_tokens as u32,
3652    })?;
3653
3654    Ok(())
3655}
3656
3657struct ZedProCountLanguageModelTokensRateLimit;
3658
3659impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3660    fn capacity(&self) -> usize {
3661        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3662            .ok()
3663            .and_then(|v| v.parse().ok())
3664            .unwrap_or(600) // Picked arbitrarily
3665    }
3666
3667    fn refill_duration(&self) -> chrono::Duration {
3668        chrono::Duration::hours(1)
3669    }
3670
3671    fn db_name(&self) -> &'static str {
3672        "zed-pro:count-language-model-tokens"
3673    }
3674}
3675
3676struct FreeCountLanguageModelTokensRateLimit;
3677
3678impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3679    fn capacity(&self) -> usize {
3680        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3681            .ok()
3682            .and_then(|v| v.parse().ok())
3683            .unwrap_or(600 / 10) // Picked arbitrarily
3684    }
3685
3686    fn refill_duration(&self) -> chrono::Duration {
3687        chrono::Duration::hours(1)
3688    }
3689
3690    fn db_name(&self) -> &'static str {
3691        "free:count-language-model-tokens"
3692    }
3693}
3694
3695struct ZedProComputeEmbeddingsRateLimit;
3696
3697impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3698    fn capacity(&self) -> usize {
3699        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3700            .ok()
3701            .and_then(|v| v.parse().ok())
3702            .unwrap_or(5000) // Picked arbitrarily
3703    }
3704
3705    fn refill_duration(&self) -> chrono::Duration {
3706        chrono::Duration::hours(1)
3707    }
3708
3709    fn db_name(&self) -> &'static str {
3710        "zed-pro:compute-embeddings"
3711    }
3712}
3713
3714struct FreeComputeEmbeddingsRateLimit;
3715
3716impl RateLimit for FreeComputeEmbeddingsRateLimit {
3717    fn capacity(&self) -> usize {
3718        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3719            .ok()
3720            .and_then(|v| v.parse().ok())
3721            .unwrap_or(5000 / 10) // Picked arbitrarily
3722    }
3723
3724    fn refill_duration(&self) -> chrono::Duration {
3725        chrono::Duration::hours(1)
3726    }
3727
3728    fn db_name(&self) -> &'static str {
3729        "free:compute-embeddings"
3730    }
3731}
3732
3733async fn compute_embeddings(
3734    request: proto::ComputeEmbeddings,
3735    response: Response<proto::ComputeEmbeddings>,
3736    session: Session,
3737    api_key: Option<Arc<str>>,
3738) -> Result<()> {
3739    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3740    authorize_access_to_legacy_llm_endpoints(&session).await?;
3741
3742    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3743        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3744        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3745    };
3746
3747    session
3748        .app_state
3749        .rate_limiter
3750        .check(&*rate_limit, session.user_id())
3751        .await?;
3752
3753    let embeddings = match request.model.as_str() {
3754        "openai/text-embedding-3-small" => {
3755            open_ai::embed(
3756                session.http_client.as_ref(),
3757                OPEN_AI_API_URL,
3758                &api_key,
3759                OpenAiEmbeddingModel::TextEmbedding3Small,
3760                request.texts.iter().map(|text| text.as_str()),
3761            )
3762            .await?
3763        }
3764        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3765    };
3766
3767    let embeddings = request
3768        .texts
3769        .iter()
3770        .map(|text| {
3771            let mut hasher = sha2::Sha256::new();
3772            hasher.update(text.as_bytes());
3773            let result = hasher.finalize();
3774            result.to_vec()
3775        })
3776        .zip(
3777            embeddings
3778                .data
3779                .into_iter()
3780                .map(|embedding| embedding.embedding),
3781        )
3782        .collect::<HashMap<_, _>>();
3783
3784    let db = session.db().await;
3785    db.save_embeddings(&request.model, &embeddings)
3786        .await
3787        .context("failed to save embeddings")
3788        .trace_err();
3789
3790    response.send(proto::ComputeEmbeddingsResponse {
3791        embeddings: embeddings
3792            .into_iter()
3793            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3794            .collect(),
3795    })?;
3796    Ok(())
3797}
3798
3799async fn get_cached_embeddings(
3800    request: proto::GetCachedEmbeddings,
3801    response: Response<proto::GetCachedEmbeddings>,
3802    session: Session,
3803) -> Result<()> {
3804    authorize_access_to_legacy_llm_endpoints(&session).await?;
3805
3806    let db = session.db().await;
3807    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3808
3809    response.send(proto::GetCachedEmbeddingsResponse {
3810        embeddings: embeddings
3811            .into_iter()
3812            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3813            .collect(),
3814    })?;
3815    Ok(())
3816}
3817
3818/// This is leftover from before the LLM service.
3819///
3820/// The endpoints protected by this check will be moved there eventually.
3821async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3822    if session.is_staff() {
3823        Ok(())
3824    } else {
3825        Err(anyhow!("permission denied"))?
3826    }
3827}
3828
3829/// Get a Supermaven API key for the user
3830async fn get_supermaven_api_key(
3831    _request: proto::GetSupermavenApiKey,
3832    response: Response<proto::GetSupermavenApiKey>,
3833    session: Session,
3834) -> Result<()> {
3835    let user_id: String = session.user_id().to_string();
3836    if !session.is_staff() {
3837        return Err(anyhow!("supermaven not enabled for this account"))?;
3838    }
3839
3840    let email = session
3841        .email()
3842        .ok_or_else(|| anyhow!("user must have an email"))?;
3843
3844    let supermaven_admin_api = session
3845        .supermaven_client
3846        .as_ref()
3847        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3848
3849    let result = supermaven_admin_api
3850        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3851        .await?;
3852
3853    response.send(proto::GetSupermavenApiKeyResponse {
3854        api_key: result.api_key,
3855    })?;
3856
3857    Ok(())
3858}
3859
3860/// Start receiving chat updates for a channel
3861async fn join_channel_chat(
3862    request: proto::JoinChannelChat,
3863    response: Response<proto::JoinChannelChat>,
3864    session: Session,
3865) -> Result<()> {
3866    let channel_id = ChannelId::from_proto(request.channel_id);
3867
3868    let db = session.db().await;
3869    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3870        .await?;
3871    let messages = db
3872        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3873        .await?;
3874    response.send(proto::JoinChannelChatResponse {
3875        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3876        messages,
3877    })?;
3878    Ok(())
3879}
3880
3881/// Stop receiving chat updates for a channel
3882async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3883    let channel_id = ChannelId::from_proto(request.channel_id);
3884    session
3885        .db()
3886        .await
3887        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3888        .await?;
3889    Ok(())
3890}
3891
3892/// Retrieve the chat history for a channel
3893async fn get_channel_messages(
3894    request: proto::GetChannelMessages,
3895    response: Response<proto::GetChannelMessages>,
3896    session: Session,
3897) -> Result<()> {
3898    let channel_id = ChannelId::from_proto(request.channel_id);
3899    let messages = session
3900        .db()
3901        .await
3902        .get_channel_messages(
3903            channel_id,
3904            session.user_id(),
3905            MESSAGE_COUNT_PER_PAGE,
3906            Some(MessageId::from_proto(request.before_message_id)),
3907        )
3908        .await?;
3909    response.send(proto::GetChannelMessagesResponse {
3910        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3911        messages,
3912    })?;
3913    Ok(())
3914}
3915
3916/// Retrieve specific chat messages
3917async fn get_channel_messages_by_id(
3918    request: proto::GetChannelMessagesById,
3919    response: Response<proto::GetChannelMessagesById>,
3920    session: Session,
3921) -> Result<()> {
3922    let message_ids = request
3923        .message_ids
3924        .iter()
3925        .map(|id| MessageId::from_proto(*id))
3926        .collect::<Vec<_>>();
3927    let messages = session
3928        .db()
3929        .await
3930        .get_channel_messages_by_id(session.user_id(), &message_ids)
3931        .await?;
3932    response.send(proto::GetChannelMessagesResponse {
3933        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3934        messages,
3935    })?;
3936    Ok(())
3937}
3938
3939/// Retrieve the current users notifications
3940async fn get_notifications(
3941    request: proto::GetNotifications,
3942    response: Response<proto::GetNotifications>,
3943    session: Session,
3944) -> Result<()> {
3945    let notifications = session
3946        .db()
3947        .await
3948        .get_notifications(
3949            session.user_id(),
3950            NOTIFICATION_COUNT_PER_PAGE,
3951            request.before_id.map(db::NotificationId::from_proto),
3952        )
3953        .await?;
3954    response.send(proto::GetNotificationsResponse {
3955        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3956        notifications,
3957    })?;
3958    Ok(())
3959}
3960
3961/// Mark notifications as read
3962async fn mark_notification_as_read(
3963    request: proto::MarkNotificationRead,
3964    response: Response<proto::MarkNotificationRead>,
3965    session: Session,
3966) -> Result<()> {
3967    let database = &session.db().await;
3968    let notifications = database
3969        .mark_notification_as_read_by_id(
3970            session.user_id(),
3971            NotificationId::from_proto(request.notification_id),
3972        )
3973        .await?;
3974    send_notifications(
3975        &*session.connection_pool().await,
3976        &session.peer,
3977        notifications,
3978    );
3979    response.send(proto::Ack {})?;
3980    Ok(())
3981}
3982
3983/// Get the current users information
3984async fn get_private_user_info(
3985    _request: proto::GetPrivateUserInfo,
3986    response: Response<proto::GetPrivateUserInfo>,
3987    session: Session,
3988) -> Result<()> {
3989    let db = session.db().await;
3990
3991    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3992    let user = db
3993        .get_user_by_id(session.user_id())
3994        .await?
3995        .ok_or_else(|| anyhow!("user not found"))?;
3996    let flags = db.get_user_flags(session.user_id()).await?;
3997
3998    response.send(proto::GetPrivateUserInfoResponse {
3999        metrics_id,
4000        staff: user.admin,
4001        flags,
4002        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4003    })?;
4004    Ok(())
4005}
4006
4007/// Accept the terms of service (tos) on behalf of the current user
4008async fn accept_terms_of_service(
4009    _request: proto::AcceptTermsOfService,
4010    response: Response<proto::AcceptTermsOfService>,
4011    session: Session,
4012) -> Result<()> {
4013    let db = session.db().await;
4014
4015    let accepted_tos_at = Utc::now();
4016    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4017        .await?;
4018
4019    response.send(proto::AcceptTermsOfServiceResponse {
4020        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4021    })?;
4022    Ok(())
4023}
4024
4025/// The minimum account age an account must have in order to use the LLM service.
4026const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4027
4028async fn get_llm_api_token(
4029    _request: proto::GetLlmToken,
4030    response: Response<proto::GetLlmToken>,
4031    session: Session,
4032) -> Result<()> {
4033    let db = session.db().await;
4034
4035    let flags = db.get_user_flags(session.user_id()).await?;
4036    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4037    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4038    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4039
4040    if !session.is_staff() && !has_language_models_feature_flag {
4041        Err(anyhow!("permission denied"))?
4042    }
4043
4044    let user_id = session.user_id();
4045    let user = db
4046        .get_user_by_id(user_id)
4047        .await?
4048        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4049
4050    if user.accepted_tos_at.is_none() {
4051        Err(anyhow!("terms of service not accepted"))?
4052    }
4053
4054    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4055
4056    let bypass_account_age_check =
4057        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4058    if !bypass_account_age_check {
4059        let mut account_created_at = user.created_at;
4060        if let Some(github_created_at) = user.github_user_created_at {
4061            account_created_at = account_created_at.min(github_created_at);
4062        }
4063        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4064            Err(anyhow!("account too young"))?
4065        }
4066    }
4067
4068    let billing_preferences = db.get_billing_preferences(user.id).await?;
4069
4070    let token = LlmTokenClaims::create(
4071        &user,
4072        session.is_staff(),
4073        billing_preferences,
4074        has_llm_closed_beta_feature_flag,
4075        has_predict_edits_feature_flag,
4076        has_llm_subscription,
4077        session.current_plan(&db).await?,
4078        session.system_id.clone(),
4079        &session.app_state.config,
4080    )?;
4081    response.send(proto::GetLlmTokenResponse { token })?;
4082    Ok(())
4083}
4084
4085fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4086    let message = match message {
4087        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4088        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4089        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4090        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4091        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4092            code: frame.code.into(),
4093            reason: frame.reason,
4094        })),
4095        // We should never receive a frame while reading the message, according
4096        // to the `tungstenite` maintainers:
4097        //
4098        // > It cannot occur when you read messages from the WebSocket, but it
4099        // > can be used when you want to send the raw frames (e.g. you want to
4100        // > send the frames to the WebSocket without composing the full message first).
4101        // >
4102        // > — https://github.com/snapview/tungstenite-rs/issues/268
4103        TungsteniteMessage::Frame(_) => {
4104            bail!("received an unexpected frame while reading the message")
4105        }
4106    };
4107
4108    Ok(message)
4109}
4110
4111fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4112    match message {
4113        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4114        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4115        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4116        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4117        AxumMessage::Close(frame) => {
4118            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4119                code: frame.code.into(),
4120                reason: frame.reason,
4121            }))
4122        }
4123    }
4124}
4125
4126fn notify_membership_updated(
4127    connection_pool: &mut ConnectionPool,
4128    result: MembershipUpdated,
4129    user_id: UserId,
4130    peer: &Peer,
4131) {
4132    for membership in &result.new_channels.channel_memberships {
4133        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4134    }
4135    for channel_id in &result.removed_channels {
4136        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4137    }
4138
4139    let user_channels_update = proto::UpdateUserChannels {
4140        channel_memberships: result
4141            .new_channels
4142            .channel_memberships
4143            .iter()
4144            .map(|cm| proto::ChannelMembership {
4145                channel_id: cm.channel_id.to_proto(),
4146                role: cm.role.into(),
4147            })
4148            .collect(),
4149        ..Default::default()
4150    };
4151
4152    let mut update = build_channels_update(result.new_channels);
4153    update.delete_channels = result
4154        .removed_channels
4155        .into_iter()
4156        .map(|id| id.to_proto())
4157        .collect();
4158    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4159
4160    for connection_id in connection_pool.user_connection_ids(user_id) {
4161        peer.send(connection_id, user_channels_update.clone())
4162            .trace_err();
4163        peer.send(connection_id, update.clone()).trace_err();
4164    }
4165}
4166
4167fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4168    proto::UpdateUserChannels {
4169        channel_memberships: channels
4170            .channel_memberships
4171            .iter()
4172            .map(|m| proto::ChannelMembership {
4173                channel_id: m.channel_id.to_proto(),
4174                role: m.role.into(),
4175            })
4176            .collect(),
4177        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4178        observed_channel_message_id: channels.observed_channel_messages.clone(),
4179    }
4180}
4181
4182fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4183    let mut update = proto::UpdateChannels::default();
4184
4185    for channel in channels.channels {
4186        update.channels.push(channel.to_proto());
4187    }
4188
4189    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4190    update.latest_channel_message_ids = channels.latest_channel_messages;
4191
4192    for (channel_id, participants) in channels.channel_participants {
4193        update
4194            .channel_participants
4195            .push(proto::ChannelParticipants {
4196                channel_id: channel_id.to_proto(),
4197                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4198            });
4199    }
4200
4201    for channel in channels.invited_channels {
4202        update.channel_invitations.push(channel.to_proto());
4203    }
4204
4205    update
4206}
4207
4208fn build_initial_contacts_update(
4209    contacts: Vec<db::Contact>,
4210    pool: &ConnectionPool,
4211) -> proto::UpdateContacts {
4212    let mut update = proto::UpdateContacts::default();
4213
4214    for contact in contacts {
4215        match contact {
4216            db::Contact::Accepted { user_id, busy } => {
4217                update.contacts.push(contact_for_user(user_id, busy, pool));
4218            }
4219            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4220            db::Contact::Incoming { user_id } => {
4221                update
4222                    .incoming_requests
4223                    .push(proto::IncomingContactRequest {
4224                        requester_id: user_id.to_proto(),
4225                    })
4226            }
4227        }
4228    }
4229
4230    update
4231}
4232
4233fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4234    proto::Contact {
4235        user_id: user_id.to_proto(),
4236        online: pool.is_user_online(user_id),
4237        busy,
4238    }
4239}
4240
4241fn room_updated(room: &proto::Room, peer: &Peer) {
4242    broadcast(
4243        None,
4244        room.participants
4245            .iter()
4246            .filter_map(|participant| Some(participant.peer_id?.into())),
4247        |peer_id| {
4248            peer.send(
4249                peer_id,
4250                proto::RoomUpdated {
4251                    room: Some(room.clone()),
4252                },
4253            )
4254        },
4255    );
4256}
4257
4258fn channel_updated(
4259    channel: &db::channel::Model,
4260    room: &proto::Room,
4261    peer: &Peer,
4262    pool: &ConnectionPool,
4263) {
4264    let participants = room
4265        .participants
4266        .iter()
4267        .map(|p| p.user_id)
4268        .collect::<Vec<_>>();
4269
4270    broadcast(
4271        None,
4272        pool.channel_connection_ids(channel.root_id())
4273            .filter_map(|(channel_id, role)| {
4274                role.can_see_channel(channel.visibility)
4275                    .then_some(channel_id)
4276            }),
4277        |peer_id| {
4278            peer.send(
4279                peer_id,
4280                proto::UpdateChannels {
4281                    channel_participants: vec![proto::ChannelParticipants {
4282                        channel_id: channel.id.to_proto(),
4283                        participant_user_ids: participants.clone(),
4284                    }],
4285                    ..Default::default()
4286                },
4287            )
4288        },
4289    );
4290}
4291
4292async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4293    let db = session.db().await;
4294
4295    let contacts = db.get_contacts(user_id).await?;
4296    let busy = db.is_user_busy(user_id).await?;
4297
4298    let pool = session.connection_pool().await;
4299    let updated_contact = contact_for_user(user_id, busy, &pool);
4300    for contact in contacts {
4301        if let db::Contact::Accepted {
4302            user_id: contact_user_id,
4303            ..
4304        } = contact
4305        {
4306            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4307                session
4308                    .peer
4309                    .send(
4310                        contact_conn_id,
4311                        proto::UpdateContacts {
4312                            contacts: vec![updated_contact.clone()],
4313                            remove_contacts: Default::default(),
4314                            incoming_requests: Default::default(),
4315                            remove_incoming_requests: Default::default(),
4316                            outgoing_requests: Default::default(),
4317                            remove_outgoing_requests: Default::default(),
4318                        },
4319                    )
4320                    .trace_err();
4321            }
4322        }
4323    }
4324    Ok(())
4325}
4326
/// Removes `connection_id` from the room it is in, if any, and notifies everyone affected.
///
/// When `leave_room` reports that a room was left, this:
/// - notifies collaborators of every project the session left (`project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - broadcasts updated channel participants if the room belongs to a channel,
/// - sends `CallCanceled` to users whose pending calls were canceled,
/// - refreshes contact status for this user and for those canceled callees,
/// - removes the user from the LiveKit room, deleting the room if `left_room.deleted` is set.
///
/// Returns `Ok(())` without doing anything when `leave_room` returns `None`.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries must be re-sent: this user plus any canceled callees.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are all assigned in the `Some` branch below; the `else`
    // branch returns early, so they are definitely initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary of this `if let`
    // scrutinee and therefore stays alive for the whole `Some` branch. If it is a lock guard,
    // that lock is held while projects and the room are notified below.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` *before* the room itself is taken,
        // so the `room` broadcast by `room_updated` below carries a defaulted `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits that follow.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` here returns early on the first failure, which also skips the
    // LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err` and not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4400
4401async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4402    let left_channel_buffers = session
4403        .db()
4404        .await
4405        .leave_channel_buffers(session.connection_id)
4406        .await?;
4407
4408    for left_buffer in left_channel_buffers {
4409        channel_buffer_updated(
4410            session.connection_id,
4411            left_buffer.connections,
4412            &proto::UpdateChannelBufferCollaborators {
4413                channel_id: left_buffer.channel_id.to_proto(),
4414                collaborators: left_buffer.collaborators,
4415            },
4416            &session.peer,
4417        );
4418    }
4419
4420    Ok(())
4421}
4422
4423fn project_left(project: &db::LeftProject, session: &Session) {
4424    for connection_id in &project.connection_ids {
4425        if project.should_unshare {
4426            session
4427                .peer
4428                .send(
4429                    *connection_id,
4430                    proto::UnshareProject {
4431                        project_id: project.id.to_proto(),
4432                    },
4433                )
4434                .trace_err();
4435        } else {
4436            session
4437                .peer
4438                .send(
4439                    *connection_id,
4440                    proto::RemoveProjectCollaborator {
4441                        project_id: project.id.to_proto(),
4442                        peer_id: Some(session.connection_id.into()),
4443                    },
4444                )
4445                .trace_err();
4446        }
4447    }
4448}
4449
4450pub trait ResultExt {
4451    type Ok;
4452
4453    fn trace_err(self) -> Option<Self::Ok>;
4454}
4455
4456impl<T, E> ResultExt for Result<T, E>
4457where
4458    E: std::fmt::Debug,
4459{
4460    type Ok = T;
4461
4462    #[track_caller]
4463    fn trace_err(self) -> Option<T> {
4464        match self {
4465            Ok(value) => Some(value),
4466            Err(error) => {
4467                tracing::error!("{:?}", error);
4468                None
4469            }
4470        }
4471    }
4472}