rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    http_client: Arc<dyn HttpClient>,
 137    /// The GeoIP country code for the user.
 138    #[allow(unused)]
 139    geoip_country_code: Option<String>,
 140    system_id: Option<String>,
 141    _executor: Executor,
 142}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
 236pub struct Server {
 237    id: parking_lot::Mutex<ServerId>,
 238    peer: Arc<Peer>,
 239    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 240    app_state: Arc<AppState>,
 241    handlers: HashMap<TypeId, MessageHandler>,
 242    teardown: watch::Sender<bool>,
 243}
 244
 245pub(crate) struct ConnectionPoolGuard<'a> {
 246    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 247    _not_send: PhantomData<Rc<()>>,
 248}
 249
 250#[derive(Serialize)]
 251pub struct ServerSnapshot<'a> {
 252    peer: &'a Peer,
 253    #[serde(serialize_with = "serialize_deref")]
 254    connection_pool: ConnectionPoolGuard<'a>,
 255}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
    /// Creates the server and registers a handler for every RPC message type it
    /// serves.
    ///
    /// Handlers registered with `add_request_handler` must reply through their
    /// `Response`; those registered with `add_message_handler` get no responder.
    /// Registering two handlers for the same message type panics (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections create receivers via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to the project host. The read-only vs
            // mutating split presumably reflects the guest permission required
            // (helpers defined elsewhere — confirm).
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            // Messages from the project host relayed to the project's other collaborators
            // (presumably; `broadcast_project_message_from_host` is defined elsewhere).
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The remaining handlers also need server configuration, so they are
            // registered as closures that capture a clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 425
    /// Spawns a detached background task that cleans up after stale servers, then
    /// returns `Ok(())` immediately.
    ///
    /// After [`CLEANUP_TIMEOUT`] the task asks the database for the room and
    /// channel ids reported by `stale_server_resource_ids` for this environment
    /// and server id. It then:
    /// - clears stale channel-buffer collaborators and pushes the refreshed
    ///   collaborator list to the remaining connections;
    /// - clears stale room participants, broadcasts the room (and channel)
    ///   update, sends `CallCanceled` for canceled calls, and pushes updated
    ///   contact info to affected users' accepted contacts;
    /// - deletes the LiveKit room when no participants remain;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Every database or send error is logged via `trace_err` and skipped, so
    /// one failure does not abort the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool guard lives only in this block, which contains no `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to the
                        // connections of their accepted contacts. Both queries are
                        // awaited before the pool is locked.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once no participants remain.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 571
 572    pub fn teardown(&self) {
 573        self.peer.teardown();
 574        self.connection_pool.lock().reset();
 575        let _ = self.teardown.send(true);
 576    }
 577
 578    #[cfg(test)]
 579    pub fn reset(&self, id: ServerId) {
 580        self.teardown();
 581        *self.id.lock() = id;
 582        self.peer.reset(id.0 as u32);
 583        let _ = self.teardown.send(false);
 584    }
 585
 586    #[cfg(test)]
 587    pub fn id(&self) -> ServerId {
 588        *self.id.lock()
 589    }
 590
 591    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 592    where
 593        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 594        Fut: 'static + Send + Future<Output = Result<()>>,
 595        M: EnvelopedMessage,
 596    {
 597        let prev_handler = self.handlers.insert(
 598            TypeId::of::<M>(),
 599            Box::new(move |envelope, session| {
 600                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 601                let received_at = envelope.received_at;
 602                tracing::info!("message received");
 603                let start_time = Instant::now();
 604                let future = (handler)(*envelope, session);
 605                async move {
 606                    let result = future.await;
 607                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 608                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 609                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 610                    let payload_type = M::NAME;
 611
 612                    match result {
 613                        Err(error) => {
 614                            tracing::error!(
 615                                ?error,
 616                                total_duration_ms,
 617                                processing_duration_ms,
 618                                queue_duration_ms,
 619                                payload_type,
 620                                "error handling message"
 621                            )
 622                        }
 623                        Ok(()) => tracing::info!(
 624                            total_duration_ms,
 625                            processing_duration_ms,
 626                            queue_duration_ms,
 627                            "finished handling message"
 628                        ),
 629                    }
 630                }
 631                .boxed()
 632            }),
 633        );
 634        if prev_handler.is_some() {
 635            panic!("registered a handler for the same message twice");
 636        }
 637        self
 638    }
 639
 640    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 641    where
 642        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 643        Fut: 'static + Send + Future<Output = Result<()>>,
 644        M: EnvelopedMessage,
 645    {
 646        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 647        self
 648    }
 649
 650    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 651    where
 652        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 653        Fut: Send + Future<Output = Result<()>>,
 654        M: RequestMessage,
 655    {
 656        let handler = Arc::new(handler);
 657        self.add_handler(move |envelope, session| {
 658            let receipt = envelope.receipt();
 659            let handler = handler.clone();
 660            async move {
 661                let peer = session.peer.clone();
 662                let responded = Arc::new(AtomicBool::default());
 663                let response = Response {
 664                    peer: peer.clone(),
 665                    responded: responded.clone(),
 666                    receipt,
 667                };
 668                match (handler)(envelope.payload, response, session).await {
 669                    Ok(()) => {
 670                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 671                            Ok(())
 672                        } else {
 673                            Err(anyhow!("handler did not send a response"))?
 674                        }
 675                    }
 676                    Err(error) => {
 677                        let proto_err = match &error {
 678                            Error::Internal(err) => err.to_proto(),
 679                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 680                        };
 681                        peer.respond_with_error(receipt, proto_err)?;
 682                        Err(error)
 683                    }
 684                }
 685            }
 686        })
 687    }
 688
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future is instrumented with a "handle connection" span carrying the peer
    /// `address`, the principal's fields, the connection id and (when present) the GeoIP
    /// country code. When polled it:
    ///
    /// 1. Returns immediately if the server is already tearing down.
    /// 2. Registers `connection` with the peer, obtaining its `ConnectionId`, an I/O future and
    ///    the stream of incoming messages.
    /// 3. Builds the per-session HTTP client (plus a Supermaven admin client when
    ///    `supermaven_admin_api_key` is set) and the [`Session`].
    /// 4. Runs `send_initial_client_update`, which also forwards the connection id through
    ///    `send_connection_id` when one is provided.
    /// 5. Dispatches incoming messages to the registered handlers until the connection closes,
    ///    the I/O future completes, or teardown is signalled.
    /// 6. On connection close or I/O completion (but not on teardown) calls `connection_lost`.
    ///
    /// NOTE(review): the early returns after `add_connection` (HTTP client creation failure and
    /// `send_initial_client_update` failure) skip `connection_lost`, so `peer.disconnect` and
    /// `remove_connection` never run on those paths, even though `send_initial_client_update`
    /// may already have added this connection to the pool. Confirm whether cleanup is needed.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly, when `handle_connection` is called, not when the future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure is the timer the peer uses for this connection; it sleeps via `executor`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured; shares the session's HTTP client.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection at 256:
            // a permit is acquired before reading each message and released when its handler
            // finishes, so reading stalls while all permits are taken.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown first, then I/O completion, then
                // progress on foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown exits without calling `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released
                                // only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers still in flight; background handlers
            // were spawned detached and keep running.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 838
 839    async fn send_initial_client_update(
 840        &self,
 841        connection_id: ConnectionId,
 842        principal: &Principal,
 843        zed_version: ZedVersion,
 844        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 845        session: &Session,
 846    ) -> Result<()> {
 847        self.peer.send(
 848            connection_id,
 849            proto::Hello {
 850                peer_id: Some(connection_id.into()),
 851            },
 852        )?;
 853        tracing::info!("sent hello message");
 854        if let Some(send_connection_id) = send_connection_id.take() {
 855            let _ = send_connection_id.send(connection_id);
 856        }
 857
 858        match principal {
 859            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 860                if !user.connected_once {
 861                    self.peer.send(connection_id, proto::ShowContacts {})?;
 862                    self.app_state
 863                        .db
 864                        .set_user_connected_once(user.id, true)
 865                        .await?;
 866                }
 867
 868                update_user_plan(user.id, session).await?;
 869
 870                let contacts = self.app_state.db.get_contacts(user.id).await?;
 871
 872                {
 873                    let mut pool = self.connection_pool.lock();
 874                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 875                    self.peer.send(
 876                        connection_id,
 877                        build_initial_contacts_update(contacts, &pool),
 878                    )?;
 879                }
 880
 881                if should_auto_subscribe_to_channels(zed_version) {
 882                    subscribe_user_to_channels(user.id, session).await?;
 883                }
 884
 885                if let Some(incoming_call) =
 886                    self.app_state.db.incoming_call_for_user(user.id).await?
 887                {
 888                    self.peer.send(connection_id, incoming_call)?;
 889                }
 890
 891                update_user_contacts(user.id, session).await?;
 892            }
 893        }
 894
 895        Ok(())
 896    }
 897
 898    pub async fn invite_code_redeemed(
 899        self: &Arc<Self>,
 900        inviter_id: UserId,
 901        invitee_id: UserId,
 902    ) -> Result<()> {
 903        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 904            if let Some(code) = &user.invite_code {
 905                let pool = self.connection_pool.lock();
 906                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 907                for connection_id in pool.user_connection_ids(inviter_id) {
 908                    self.peer.send(
 909                        connection_id,
 910                        proto::UpdateContacts {
 911                            contacts: vec![invitee_contact.clone()],
 912                            ..Default::default()
 913                        },
 914                    )?;
 915                    self.peer.send(
 916                        connection_id,
 917                        proto::UpdateInviteInfo {
 918                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 919                            count: user.invite_count as u32,
 920                        },
 921                    )?;
 922                }
 923            }
 924        }
 925        Ok(())
 926    }
 927
 928    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 929        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 930            if let Some(invite_code) = &user.invite_code {
 931                let pool = self.connection_pool.lock();
 932                for connection_id in pool.user_connection_ids(user_id) {
 933                    self.peer.send(
 934                        connection_id,
 935                        proto::UpdateInviteInfo {
 936                            url: format!(
 937                                "{}{}",
 938                                self.app_state.config.invite_link_prefix, invite_code
 939                            ),
 940                            count: user.invite_count as u32,
 941                        },
 942                    )?;
 943                }
 944            }
 945        }
 946        Ok(())
 947    }
 948
 949    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 950        let pool = self.connection_pool.lock();
 951        for connection_id in pool.user_connection_ids(user_id) {
 952            self.peer
 953                .send(connection_id, proto::RefreshLlmToken {})
 954                .trace_err();
 955        }
 956    }
 957
 958    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 959        ServerSnapshot {
 960            connection_pool: ConnectionPoolGuard {
 961                guard: self.connection_pool.lock(),
 962                _not_send: PhantomData,
 963            },
 964            peer: &self.peer,
 965        }
 966    }
 967}
 968
 969impl<'a> Deref for ConnectionPoolGuard<'a> {
 970    type Target = ConnectionPool;
 971
 972    fn deref(&self) -> &Self::Target {
 973        &self.guard
 974    }
 975}
 976
 977impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 978    fn deref_mut(&mut self) -> &mut Self::Target {
 979        &mut self.guard
 980    }
 981}
 982
 983impl<'a> Drop for ConnectionPoolGuard<'a> {
 984    fn drop(&mut self) {
 985        #[cfg(test)]
 986        self.check_invariants();
 987    }
 988}
 989
 990fn broadcast<F>(
 991    sender_id: Option<ConnectionId>,
 992    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 993    mut f: F,
 994) where
 995    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 996{
 997    for receiver_id in receiver_ids {
 998        if Some(receiver_id) != sender_id {
 999            if let Err(error) = f(receiver_id) {
1000                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1001            }
1002        }
1003    }
1004}
1005
1006pub struct ProtocolVersion(u32);
1007
1008impl Header for ProtocolVersion {
1009    fn name() -> &'static HeaderName {
1010        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1011        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1012    }
1013
1014    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1015    where
1016        Self: Sized,
1017        I: Iterator<Item = &'i axum::http::HeaderValue>,
1018    {
1019        let version = values
1020            .next()
1021            .ok_or_else(axum::headers::Error::invalid)?
1022            .to_str()
1023            .map_err(|_| axum::headers::Error::invalid())?
1024            .parse()
1025            .map_err(|_| axum::headers::Error::invalid())?;
1026        Ok(Self(version))
1027    }
1028
1029    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1030        values.extend([self.0.to_string().parse().unwrap()]);
1031    }
1032}
1033
1034pub struct AppVersionHeader(SemanticVersion);
1035impl Header for AppVersionHeader {
1036    fn name() -> &'static HeaderName {
1037        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1038        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1039    }
1040
1041    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1042    where
1043        Self: Sized,
1044        I: Iterator<Item = &'i axum::http::HeaderValue>,
1045    {
1046        let version = values
1047            .next()
1048            .ok_or_else(axum::headers::Error::invalid)?
1049            .to_str()
1050            .map_err(|_| axum::headers::Error::invalid())?
1051            .parse()
1052            .map_err(|_| axum::headers::Error::invalid())?;
1053        Ok(Self(version))
1054    }
1055
1056    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1057        values.extend([self.0.to_string().parse().unwrap()]);
1058    }
1059}
1060
1061pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1062    Router::new()
1063        .route("/rpc", get(handle_websocket_request))
1064        .layer(
1065            ServiceBuilder::new()
1066                .layer(Extension(server.app_state.clone()))
1067                .layer(middleware::from_fn(auth::validate_header)),
1068        )
1069        .route("/metrics", get(handle_metrics))
1070        .layer(Extension(server))
1071}
1072
1073#[allow(clippy::too_many_arguments)]
1074pub async fn handle_websocket_request(
1075    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1076    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1077    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1078    Extension(server): Extension<Arc<Server>>,
1079    Extension(principal): Extension<Principal>,
1080    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1081    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1082    ws: WebSocketUpgrade,
1083) -> axum::response::Response {
1084    if protocol_version != rpc::PROTOCOL_VERSION {
1085        return (
1086            StatusCode::UPGRADE_REQUIRED,
1087            "client must be upgraded".to_string(),
1088        )
1089            .into_response();
1090    }
1091
1092    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1093        return (
1094            StatusCode::UPGRADE_REQUIRED,
1095            "no version header found".to_string(),
1096        )
1097            .into_response();
1098    };
1099
1100    if !version.can_collaborate() {
1101        return (
1102            StatusCode::UPGRADE_REQUIRED,
1103            "client must be upgraded".to_string(),
1104        )
1105            .into_response();
1106    }
1107
1108    let socket_address = socket_address.to_string();
1109    ws.on_upgrade(move |socket| {
1110        let socket = socket
1111            .map_ok(to_tungstenite_message)
1112            .err_into()
1113            .with(|message| async move { to_axum_message(message) });
1114        let connection = Connection::new(Box::pin(socket));
1115        async move {
1116            server
1117                .handle_connection(
1118                    connection,
1119                    socket_address,
1120                    principal,
1121                    version,
1122                    country_code_header.map(|header| header.to_string()),
1123                    system_id_header.map(|header| header.to_string()),
1124                    None,
1125                    Executor::Production,
1126                )
1127                .await;
1128        }
1129    })
1130}
1131
1132pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1133    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1134    let connections_metric = CONNECTIONS_METRIC
1135        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1136
1137    let connections = server
1138        .connection_pool
1139        .lock()
1140        .connections()
1141        .filter(|connection| !connection.admin)
1142        .count();
1143    connections_metric.set(connections as _);
1144
1145    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1146    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1147        register_int_gauge!(
1148            "shared_projects",
1149            "number of open projects with one or more guests"
1150        )
1151        .unwrap()
1152    });
1153
1154    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1155    shared_projects_metric.set(shared_projects as _);
1156
1157    let encoder = prometheus::TextEncoder::new();
1158    let metric_families = prometheus::gather();
1159    let encoded_metrics = encoder
1160        .encode_to_string(&metric_families)
1161        .map_err(|err| anyhow!("{}", err))?;
1162    Ok(encoded_metrics)
1163}
1164
/// Signs a connection out after its message loop ends.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (an error here aborts the whole sign-out);
/// - tells the database the connection was lost (a failure is passed to `trace_err`).
///
/// It then waits for either `RECONNECT_TIMEOUT` to elapse or the teardown signal to change,
/// preferring the timeout when both are ready.
///
/// If the timeout fires first, the remaining cleanup runs:
/// - the user leaves their room and channel buffers for this connection;
/// - any call is declined via `decline_call` if the user has no other connection online;
/// - `update_user_contacts` is called.
///
/// If teardown fires first, none of that cleanup runs. Presumably the waiting period lets the
/// client reconnect before its resources are released — confirm against `RECONNECT_TIMEOUT`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            // NOTE(review): this uses `log::info!` while the rest of this file logs via
            // `tracing`; confirm the message reaches the same sink.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1211
1212/// Acknowledges a ping from a client, used to keep the connection alive.
1213async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1214    response.send(proto::Ack {})?;
1215    Ok(())
1216}
1217
1218/// Creates a new room for calling (outside of channels)
1219async fn create_room(
1220    _request: proto::CreateRoom,
1221    response: Response<proto::CreateRoom>,
1222    session: Session,
1223) -> Result<()> {
1224    let livekit_room = nanoid::nanoid!(30);
1225
1226    let live_kit_connection_info = util::maybe!(async {
1227        let live_kit = session.app_state.livekit_client.as_ref();
1228        let live_kit = live_kit?;
1229        let user_id = session.user_id().to_string();
1230
1231        let token = live_kit
1232            .room_token(&livekit_room, &user_id.to_string())
1233            .trace_err()?;
1234
1235        Some(proto::LiveKitConnectionInfo {
1236            server_url: live_kit.url().into(),
1237            token,
1238            can_publish: true,
1239        })
1240    })
1241    .await;
1242
1243    let room = session
1244        .db()
1245        .await
1246        .create_room(session.user_id(), session.connection_id, &livekit_room)
1247        .await?;
1248
1249    response.send(proto::CreateRoomResponse {
1250        room: Some(room.clone()),
1251        live_kit_connection_info,
1252    })?;
1253
1254    update_user_contacts(session.user_id(), &session).await?;
1255    Ok(())
1256}
1257
1258/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1259async fn join_room(
1260    request: proto::JoinRoom,
1261    response: Response<proto::JoinRoom>,
1262    session: Session,
1263) -> Result<()> {
1264    let room_id = RoomId::from_proto(request.id);
1265
1266    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1267
1268    if let Some(channel_id) = channel_id {
1269        return join_channel_internal(channel_id, Box::new(response), session).await;
1270    }
1271
1272    let joined_room = {
1273        let room = session
1274            .db()
1275            .await
1276            .join_room(room_id, session.user_id(), session.connection_id)
1277            .await?;
1278        room_updated(&room.room, &session.peer);
1279        room.into_inner()
1280    };
1281
1282    for connection_id in session
1283        .connection_pool()
1284        .await
1285        .user_connection_ids(session.user_id())
1286    {
1287        session
1288            .peer
1289            .send(
1290                connection_id,
1291                proto::CallCanceled {
1292                    room_id: room_id.to_proto(),
1293                },
1294            )
1295            .trace_err();
1296    }
1297
1298    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1299    {
1300        live_kit
1301            .room_token(
1302                &joined_room.room.livekit_room,
1303                &session.user_id().to_string(),
1304            )
1305            .trace_err()
1306            .map(|token| proto::LiveKitConnectionInfo {
1307                server_url: live_kit.url().into(),
1308                token,
1309                can_publish: true,
1310            })
1311    } else {
1312        None
1313    };
1314
1315    response.send(proto::JoinRoomResponse {
1316        room: Some(joined_room.room),
1317        channel_id: None,
1318        live_kit_connection_info,
1319    })?;
1320
1321    update_user_contacts(session.user_id(), &session).await?;
1322    Ok(())
1323}
1324
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Calls `db.rejoin_room` with the caller's user id and current connection id.
/// 2. Responds with the room, the `reshared_projects` (id plus collaborators) and the
///    `rejoined_projects`.
/// 3. Broadcasts the room through `room_updated`.
/// 4. For each reshared project, tells its collaborators that the caller's peer id changed
///    from `old_connection_id` to the current connection id, then forwards the project's
///    worktree list to every collaborator except the caller.
/// 5. Streams rejoined-project state to the caller via `notify_rejoined_projects`.
/// 6. Calls `channel_updated` when the database reported a channel for this room.
/// 7. Calls `update_user_contacts`.
///
/// Presumably `reshared_projects` are projects the caller shares and `rejoined_projects` are
/// projects it had joined as a guest — confirm against the database layer.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below; by the time the block ends, the value returned by the
    // database has been consumed by `into_inner`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // The response goes out before any of the follow-up project messages below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: individual send failures are passed to `trace_err`.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees on the caller's behalf, skipping the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // A send failure to the caller aborts here with `?`.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1416
1417fn notify_rejoined_projects(
1418    rejoined_projects: &mut Vec<RejoinedProject>,
1419    session: &Session,
1420) -> Result<()> {
1421    for project in rejoined_projects.iter() {
1422        for collaborator in &project.collaborators {
1423            session
1424                .peer
1425                .send(
1426                    collaborator.connection_id,
1427                    proto::UpdateProjectCollaborator {
1428                        project_id: project.id.to_proto(),
1429                        old_peer_id: Some(project.old_connection_id.into()),
1430                        new_peer_id: Some(session.connection_id.into()),
1431                    },
1432                )
1433                .trace_err();
1434        }
1435    }
1436
1437    for project in rejoined_projects {
1438        for worktree in mem::take(&mut project.worktrees) {
1439            // Stream this worktree's entries.
1440            let message = proto::UpdateWorktree {
1441                project_id: project.id.to_proto(),
1442                worktree_id: worktree.id,
1443                abs_path: worktree.abs_path.clone(),
1444                root_name: worktree.root_name,
1445                updated_entries: worktree.updated_entries,
1446                removed_entries: worktree.removed_entries,
1447                scan_id: worktree.scan_id,
1448                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1449                updated_repositories: worktree.updated_repositories,
1450                removed_repositories: worktree.removed_repositories,
1451            };
1452            for update in proto::split_worktree_update(message) {
1453                session.peer.send(session.connection_id, update.clone())?;
1454            }
1455
1456            // Stream this worktree's diagnostics.
1457            for summary in worktree.diagnostic_summaries {
1458                session.peer.send(
1459                    session.connection_id,
1460                    proto::UpdateDiagnosticSummary {
1461                        project_id: project.id.to_proto(),
1462                        worktree_id: worktree.id,
1463                        summary: Some(summary),
1464                    },
1465                )?;
1466            }
1467
1468            for settings_file in worktree.settings_files {
1469                session.peer.send(
1470                    session.connection_id,
1471                    proto::UpdateWorktreeSettings {
1472                        project_id: project.id.to_proto(),
1473                        worktree_id: worktree.id,
1474                        path: settings_file.path,
1475                        content: Some(settings_file.content),
1476                        kind: Some(settings_file.kind.to_proto().into()),
1477                    },
1478                )?;
1479            }
1480        }
1481
1482        for language_server in &project.language_servers {
1483            session.peer.send(
1484                session.connection_id,
1485                proto::UpdateLanguageServer {
1486                    project_id: project.id.to_proto(),
1487                    language_server_id: language_server.id,
1488                    variant: Some(
1489                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1490                            proto::LspDiskBasedDiagnosticsUpdated {},
1491                        ),
1492                    ),
1493                },
1494            )?;
1495        }
1496    }
1497    Ok(())
1498}
1499
1500/// leave room disconnects from the room.
1501async fn leave_room(
1502    _: proto::LeaveRoom,
1503    response: Response<proto::LeaveRoom>,
1504    session: Session,
1505) -> Result<()> {
1506    leave_room_for_session(&session, session.connection_id).await?;
1507    response.send(proto::Ack {})?;
1508    Ok(())
1509}
1510
1511/// Updates the permissions of someone else in the room.
1512async fn set_room_participant_role(
1513    request: proto::SetRoomParticipantRole,
1514    response: Response<proto::SetRoomParticipantRole>,
1515    session: Session,
1516) -> Result<()> {
1517    let user_id = UserId::from_proto(request.user_id);
1518    let role = ChannelRole::from(request.role());
1519
1520    let (livekit_room, can_publish) = {
1521        let room = session
1522            .db()
1523            .await
1524            .set_room_participant_role(
1525                session.user_id(),
1526                RoomId::from_proto(request.room_id),
1527                user_id,
1528                role,
1529            )
1530            .await?;
1531
1532        let livekit_room = room.livekit_room.clone();
1533        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1534        room_updated(&room, &session.peer);
1535        (livekit_room, can_publish)
1536    };
1537
1538    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1539        live_kit
1540            .update_participant(
1541                livekit_room.clone(),
1542                request.user_id.to_string(),
1543                livekit_server::proto::ParticipantPermission {
1544                    can_subscribe: true,
1545                    can_publish,
1546                    can_publish_data: can_publish,
1547                    hidden: false,
1548                    recorder: false,
1549                },
1550            )
1551            .await
1552            .trace_err();
1553    }
1554
1555    response.send(proto::Ack {})?;
1556    Ok(())
1557}
1558
/// Call someone else into the current room.
///
/// Flow:
/// 1. Reject the call unless the callee is a contact of the caller.
/// 2. Record the call in the database, broadcast the updated room and
///    refresh the callee's contacts.
/// 3. Send the incoming-call payload returned by the database as a request to
///    every connection of the callee, and acknowledge the caller as soon as
///    any one of those requests succeeds.
/// 4. If none succeeds (including when the callee has no connections), record
///    the failure with `call_failed`, broadcast the room and refresh the
///    callee's contacts again, then return "failed to ring user".
///
/// # Errors
/// Fails when the callee is not a contact, when a database call fails, or
/// when no callee connection accepted the ring.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by the database `call` is dropped before
    // the awaits that follow.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the payload out of the guarded value, leaving its default behind.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response wins. Failures are logged and we keep
    // waiting on the remaining connections.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: record the failed call and publish it.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1627
1628/// Cancel an outgoing call.
1629async fn cancel_call(
1630    request: proto::CancelCall,
1631    response: Response<proto::CancelCall>,
1632    session: Session,
1633) -> Result<()> {
1634    let called_user_id = UserId::from_proto(request.called_user_id);
1635    let room_id = RoomId::from_proto(request.room_id);
1636    {
1637        let room = session
1638            .db()
1639            .await
1640            .cancel_call(room_id, session.connection_id, called_user_id)
1641            .await?;
1642        room_updated(&room, &session.peer);
1643    }
1644
1645    for connection_id in session
1646        .connection_pool()
1647        .await
1648        .user_connection_ids(called_user_id)
1649    {
1650        session
1651            .peer
1652            .send(
1653                connection_id,
1654                proto::CallCanceled {
1655                    room_id: room_id.to_proto(),
1656                },
1657            )
1658            .trace_err();
1659    }
1660    response.send(proto::Ack {})?;
1661
1662    update_user_contacts(called_user_id, &session).await?;
1663    Ok(())
1664}
1665
1666/// Decline an incoming call.
1667async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1668    let room_id = RoomId::from_proto(message.room_id);
1669    {
1670        let room = session
1671            .db()
1672            .await
1673            .decline_call(Some(room_id), session.user_id())
1674            .await?
1675            .ok_or_else(|| anyhow!("failed to decline call"))?;
1676        room_updated(&room, &session.peer);
1677    }
1678
1679    for connection_id in session
1680        .connection_pool()
1681        .await
1682        .user_connection_ids(session.user_id())
1683    {
1684        session
1685            .peer
1686            .send(
1687                connection_id,
1688                proto::CallCanceled {
1689                    room_id: room_id.to_proto(),
1690                },
1691            )
1692            .trace_err();
1693    }
1694    update_user_contacts(session.user_id(), &session).await?;
1695    Ok(())
1696}
1697
1698/// Updates other participants in the room with your current location.
1699async fn update_participant_location(
1700    request: proto::UpdateParticipantLocation,
1701    response: Response<proto::UpdateParticipantLocation>,
1702    session: Session,
1703) -> Result<()> {
1704    let room_id = RoomId::from_proto(request.room_id);
1705    let location = request
1706        .location
1707        .ok_or_else(|| anyhow!("invalid location"))?;
1708
1709    let db = session.db().await;
1710    let room = db
1711        .update_room_participant_location(room_id, session.connection_id, location)
1712        .await?;
1713
1714    room_updated(&room, &session.peer);
1715    response.send(proto::Ack {})?;
1716    Ok(())
1717}
1718
1719/// Share a project into the room.
1720async fn share_project(
1721    request: proto::ShareProject,
1722    response: Response<proto::ShareProject>,
1723    session: Session,
1724) -> Result<()> {
1725    let (project_id, room) = &*session
1726        .db()
1727        .await
1728        .share_project(
1729            RoomId::from_proto(request.room_id),
1730            session.connection_id,
1731            &request.worktrees,
1732            request.is_ssh_project,
1733        )
1734        .await?;
1735    response.send(proto::ShareProjectResponse {
1736        project_id: project_id.to_proto(),
1737    })?;
1738    room_updated(room, &session.peer);
1739
1740    Ok(())
1741}
1742
1743/// Unshare a project from the room.
1744async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1745    let project_id = ProjectId::from_proto(message.project_id);
1746    unshare_project_internal(project_id, session.connection_id, &session).await
1747}
1748
1749async fn unshare_project_internal(
1750    project_id: ProjectId,
1751    connection_id: ConnectionId,
1752    session: &Session,
1753) -> Result<()> {
1754    let delete = {
1755        let room_guard = session
1756            .db()
1757            .await
1758            .unshare_project(project_id, connection_id)
1759            .await?;
1760
1761        let (delete, room, guest_connection_ids) = &*room_guard;
1762
1763        let message = proto::UnshareProject {
1764            project_id: project_id.to_proto(),
1765        };
1766
1767        broadcast(
1768            Some(connection_id),
1769            guest_connection_ids.iter().copied(),
1770            |conn_id| session.peer.send(conn_id, message.clone()),
1771        );
1772        if let Some(room) = room {
1773            room_updated(room, &session.peer);
1774        }
1775
1776        *delete
1777    };
1778
1779    if delete {
1780        let db = session.db().await;
1781        db.delete_project(project_id).await?;
1782    }
1783
1784    Ok(())
1785}
1786
1787/// Join someone elses shared project.
1788async fn join_project(
1789    request: proto::JoinProject,
1790    response: Response<proto::JoinProject>,
1791    session: Session,
1792) -> Result<()> {
1793    let project_id = ProjectId::from_proto(request.project_id);
1794
1795    tracing::info!(%project_id, "join project");
1796
1797    let db = session.db().await;
1798    let (project, replica_id) = &mut *db
1799        .join_project(project_id, session.connection_id, session.user_id())
1800        .await?;
1801    drop(db);
1802    tracing::info!(%project_id, "join remote project");
1803    join_project_internal(response, session, project, replica_id)
1804}
1805
/// Sink for the final `JoinProjectResponse` produced by
/// `join_project_internal`, so that function is written against this trait
/// rather than a concrete response type.
trait JoinProjectInternalResponse {
    /// Sends `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The regular RPC `Response` for `JoinProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Spelled as a path call so it reads as forwarding to `Response`'s own
        // `send` rather than recursing into this trait method.
        // NOTE(review): relies on `Response` having an inherent `send`, which is
        // defined outside this chunk.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1814
1815fn join_project_internal(
1816    response: impl JoinProjectInternalResponse,
1817    session: Session,
1818    project: &mut Project,
1819    replica_id: &ReplicaId,
1820) -> Result<()> {
1821    let collaborators = project
1822        .collaborators
1823        .iter()
1824        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1825        .map(|collaborator| collaborator.to_proto())
1826        .collect::<Vec<_>>();
1827    let project_id = project.id;
1828    let guest_user_id = session.user_id();
1829
1830    let worktrees = project
1831        .worktrees
1832        .iter()
1833        .map(|(id, worktree)| proto::WorktreeMetadata {
1834            id: *id,
1835            root_name: worktree.root_name.clone(),
1836            visible: worktree.visible,
1837            abs_path: worktree.abs_path.clone(),
1838        })
1839        .collect::<Vec<_>>();
1840
1841    let add_project_collaborator = proto::AddProjectCollaborator {
1842        project_id: project_id.to_proto(),
1843        collaborator: Some(proto::Collaborator {
1844            peer_id: Some(session.connection_id.into()),
1845            replica_id: replica_id.0 as u32,
1846            user_id: guest_user_id.to_proto(),
1847            is_host: false,
1848        }),
1849    };
1850
1851    for collaborator in &collaborators {
1852        session
1853            .peer
1854            .send(
1855                collaborator.peer_id.unwrap().into(),
1856                add_project_collaborator.clone(),
1857            )
1858            .trace_err();
1859    }
1860
1861    // First, we send the metadata associated with each worktree.
1862    response.send(proto::JoinProjectResponse {
1863        project_id: project.id.0 as u64,
1864        worktrees: worktrees.clone(),
1865        replica_id: replica_id.0 as u32,
1866        collaborators: collaborators.clone(),
1867        language_servers: project.language_servers.clone(),
1868        role: project.role.into(),
1869    })?;
1870
1871    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1872        // Stream this worktree's entries.
1873        let message = proto::UpdateWorktree {
1874            project_id: project_id.to_proto(),
1875            worktree_id,
1876            abs_path: worktree.abs_path.clone(),
1877            root_name: worktree.root_name,
1878            updated_entries: worktree.entries,
1879            removed_entries: Default::default(),
1880            scan_id: worktree.scan_id,
1881            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1882            updated_repositories: worktree.repository_entries.into_values().collect(),
1883            removed_repositories: Default::default(),
1884        };
1885        for update in proto::split_worktree_update(message) {
1886            session.peer.send(session.connection_id, update.clone())?;
1887        }
1888
1889        // Stream this worktree's diagnostics.
1890        for summary in worktree.diagnostic_summaries {
1891            session.peer.send(
1892                session.connection_id,
1893                proto::UpdateDiagnosticSummary {
1894                    project_id: project_id.to_proto(),
1895                    worktree_id: worktree.id,
1896                    summary: Some(summary),
1897                },
1898            )?;
1899        }
1900
1901        for settings_file in worktree.settings_files {
1902            session.peer.send(
1903                session.connection_id,
1904                proto::UpdateWorktreeSettings {
1905                    project_id: project_id.to_proto(),
1906                    worktree_id: worktree.id,
1907                    path: settings_file.path,
1908                    content: Some(settings_file.content),
1909                    kind: Some(settings_file.kind.to_proto() as i32),
1910                },
1911            )?;
1912        }
1913    }
1914
1915    for language_server in &project.language_servers {
1916        session.peer.send(
1917            session.connection_id,
1918            proto::UpdateLanguageServer {
1919                project_id: project_id.to_proto(),
1920                language_server_id: language_server.id,
1921                variant: Some(
1922                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1923                        proto::LspDiskBasedDiagnosticsUpdated {},
1924                    ),
1925                ),
1926            },
1927        )?;
1928    }
1929
1930    Ok(())
1931}
1932
1933/// Leave someone elses shared project.
1934async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1935    let sender_id = session.connection_id;
1936    let project_id = ProjectId::from_proto(request.project_id);
1937    let db = session.db().await;
1938
1939    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1940    tracing::info!(
1941        %project_id,
1942        "leave project"
1943    );
1944
1945    project_left(project, &session);
1946    if let Some(room) = room {
1947        room_updated(room, &session.peer);
1948    }
1949
1950    Ok(())
1951}
1952
1953/// Updates other participants with changes to the project
1954async fn update_project(
1955    request: proto::UpdateProject,
1956    response: Response<proto::UpdateProject>,
1957    session: Session,
1958) -> Result<()> {
1959    let project_id = ProjectId::from_proto(request.project_id);
1960    let (room, guest_connection_ids) = &*session
1961        .db()
1962        .await
1963        .update_project(project_id, session.connection_id, &request.worktrees)
1964        .await?;
1965    broadcast(
1966        Some(session.connection_id),
1967        guest_connection_ids.iter().copied(),
1968        |connection_id| {
1969            session
1970                .peer
1971                .forward_send(session.connection_id, connection_id, request.clone())
1972        },
1973    );
1974    if let Some(room) = room {
1975        room_updated(room, &session.peer);
1976    }
1977    response.send(proto::Ack {})?;
1978
1979    Ok(())
1980}
1981
1982/// Updates other participants with changes to the worktree
1983async fn update_worktree(
1984    request: proto::UpdateWorktree,
1985    response: Response<proto::UpdateWorktree>,
1986    session: Session,
1987) -> Result<()> {
1988    let guest_connection_ids = session
1989        .db()
1990        .await
1991        .update_worktree(&request, session.connection_id)
1992        .await?;
1993
1994    broadcast(
1995        Some(session.connection_id),
1996        guest_connection_ids.iter().copied(),
1997        |connection_id| {
1998            session
1999                .peer
2000                .forward_send(session.connection_id, connection_id, request.clone())
2001        },
2002    );
2003    response.send(proto::Ack {})?;
2004    Ok(())
2005}
2006
2007/// Updates other participants with changes to the diagnostics
2008async fn update_diagnostic_summary(
2009    message: proto::UpdateDiagnosticSummary,
2010    session: Session,
2011) -> Result<()> {
2012    let guest_connection_ids = session
2013        .db()
2014        .await
2015        .update_diagnostic_summary(&message, session.connection_id)
2016        .await?;
2017
2018    broadcast(
2019        Some(session.connection_id),
2020        guest_connection_ids.iter().copied(),
2021        |connection_id| {
2022            session
2023                .peer
2024                .forward_send(session.connection_id, connection_id, message.clone())
2025        },
2026    );
2027
2028    Ok(())
2029}
2030
2031/// Updates other participants with changes to the worktree settings
2032async fn update_worktree_settings(
2033    message: proto::UpdateWorktreeSettings,
2034    session: Session,
2035) -> Result<()> {
2036    let guest_connection_ids = session
2037        .db()
2038        .await
2039        .update_worktree_settings(&message, session.connection_id)
2040        .await?;
2041
2042    broadcast(
2043        Some(session.connection_id),
2044        guest_connection_ids.iter().copied(),
2045        |connection_id| {
2046            session
2047                .peer
2048                .forward_send(session.connection_id, connection_id, message.clone())
2049        },
2050    );
2051
2052    Ok(())
2053}
2054
2055/// Notify other participants that a  language server has started.
2056async fn start_language_server(
2057    request: proto::StartLanguageServer,
2058    session: Session,
2059) -> Result<()> {
2060    let guest_connection_ids = session
2061        .db()
2062        .await
2063        .start_language_server(&request, session.connection_id)
2064        .await?;
2065
2066    broadcast(
2067        Some(session.connection_id),
2068        guest_connection_ids.iter().copied(),
2069        |connection_id| {
2070            session
2071                .peer
2072                .forward_send(session.connection_id, connection_id, request.clone())
2073        },
2074    );
2075    Ok(())
2076}
2077
2078/// Notify other participants that a language server has changed.
2079async fn update_language_server(
2080    request: proto::UpdateLanguageServer,
2081    session: Session,
2082) -> Result<()> {
2083    let project_id = ProjectId::from_proto(request.project_id);
2084    let project_connection_ids = session
2085        .db()
2086        .await
2087        .project_connection_ids(project_id, session.connection_id, true)
2088        .await?;
2089    broadcast(
2090        Some(session.connection_id),
2091        project_connection_ids.iter().copied(),
2092        |connection_id| {
2093            session
2094                .peer
2095                .forward_send(session.connection_id, connection_id, request.clone())
2096        },
2097    );
2098    Ok(())
2099}
2100
2101/// forward a project request to the host. These requests should be read only
2102/// as guests are allowed to send them.
2103async fn forward_read_only_project_request<T>(
2104    request: T,
2105    response: Response<T>,
2106    session: Session,
2107) -> Result<()>
2108where
2109    T: EntityMessage + RequestMessage,
2110{
2111    let project_id = ProjectId::from_proto(request.remote_entity_id());
2112    let host_connection_id = session
2113        .db()
2114        .await
2115        .host_for_read_only_project_request(project_id, session.connection_id)
2116        .await?;
2117    let payload = session
2118        .peer
2119        .forward_request(session.connection_id, host_connection_id, request)
2120        .await?;
2121    response.send(payload)?;
2122    Ok(())
2123}
2124
2125async fn forward_find_search_candidates_request(
2126    request: proto::FindSearchCandidates,
2127    response: Response<proto::FindSearchCandidates>,
2128    session: Session,
2129) -> Result<()> {
2130    let project_id = ProjectId::from_proto(request.remote_entity_id());
2131    let host_connection_id = session
2132        .db()
2133        .await
2134        .host_for_read_only_project_request(project_id, session.connection_id)
2135        .await?;
2136    let payload = session
2137        .peer
2138        .forward_request(session.connection_id, host_connection_id, request)
2139        .await?;
2140    response.send(payload)?;
2141    Ok(())
2142}
2143
2144/// forward a project request to the host. These requests are disallowed
2145/// for guests.
2146async fn forward_mutating_project_request<T>(
2147    request: T,
2148    response: Response<T>,
2149    session: Session,
2150) -> Result<()>
2151where
2152    T: EntityMessage + RequestMessage,
2153{
2154    let project_id = ProjectId::from_proto(request.remote_entity_id());
2155
2156    let host_connection_id = session
2157        .db()
2158        .await
2159        .host_for_mutating_project_request(project_id, session.connection_id)
2160        .await?;
2161    let payload = session
2162        .peer
2163        .forward_request(session.connection_id, host_connection_id, request)
2164        .await?;
2165    response.send(payload)?;
2166    Ok(())
2167}
2168
2169/// Notify other participants that a new buffer has been created
2170async fn create_buffer_for_peer(
2171    request: proto::CreateBufferForPeer,
2172    session: Session,
2173) -> Result<()> {
2174    session
2175        .db()
2176        .await
2177        .check_user_is_project_host(
2178            ProjectId::from_proto(request.project_id),
2179            session.connection_id,
2180        )
2181        .await?;
2182    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2183    session
2184        .peer
2185        .forward_send(session.connection_id, peer_id.into(), request)?;
2186    Ok(())
2187}
2188
2189/// Notify other participants that a buffer has been updated. This is
2190/// allowed for guests as long as the update is limited to selections.
2191async fn update_buffer(
2192    request: proto::UpdateBuffer,
2193    response: Response<proto::UpdateBuffer>,
2194    session: Session,
2195) -> Result<()> {
2196    let project_id = ProjectId::from_proto(request.project_id);
2197    let mut capability = Capability::ReadOnly;
2198
2199    for op in request.operations.iter() {
2200        match op.variant {
2201            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2202            Some(_) => capability = Capability::ReadWrite,
2203        }
2204    }
2205
2206    let host = {
2207        let guard = session
2208            .db()
2209            .await
2210            .connections_for_buffer_update(project_id, session.connection_id, capability)
2211            .await?;
2212
2213        let (host, guests) = &*guard;
2214
2215        broadcast(
2216            Some(session.connection_id),
2217            guests.clone(),
2218            |connection_id| {
2219                session
2220                    .peer
2221                    .forward_send(session.connection_id, connection_id, request.clone())
2222            },
2223        );
2224
2225        *host
2226    };
2227
2228    if host != session.connection_id {
2229        session
2230            .peer
2231            .forward_request(session.connection_id, host, request.clone())
2232            .await?;
2233    }
2234
2235    response.send(proto::Ack {})?;
2236    Ok(())
2237}
2238
2239async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2240    let project_id = ProjectId::from_proto(message.project_id);
2241
2242    let operation = message.operation.as_ref().context("invalid operation")?;
2243    let capability = match operation.variant.as_ref() {
2244        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2245            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2246                match buffer_op.variant {
2247                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2248                        Capability::ReadOnly
2249                    }
2250                    _ => Capability::ReadWrite,
2251                }
2252            } else {
2253                Capability::ReadWrite
2254            }
2255        }
2256        Some(_) => Capability::ReadWrite,
2257        None => Capability::ReadOnly,
2258    };
2259
2260    let guard = session
2261        .db()
2262        .await
2263        .connections_for_buffer_update(project_id, session.connection_id, capability)
2264        .await?;
2265
2266    let (host, guests) = &*guard;
2267
2268    broadcast(
2269        Some(session.connection_id),
2270        guests.iter().chain([host]).copied(),
2271        |connection_id| {
2272            session
2273                .peer
2274                .forward_send(session.connection_id, connection_id, message.clone())
2275        },
2276    );
2277
2278    Ok(())
2279}
2280
2281/// Notify other participants that a project has been updated.
2282async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2283    request: T,
2284    session: Session,
2285) -> Result<()> {
2286    let project_id = ProjectId::from_proto(request.remote_entity_id());
2287    let project_connection_ids = session
2288        .db()
2289        .await
2290        .project_connection_ids(project_id, session.connection_id, false)
2291        .await?;
2292
2293    broadcast(
2294        Some(session.connection_id),
2295        project_connection_ids.iter().copied(),
2296        |connection_id| {
2297            session
2298                .peer
2299                .forward_send(session.connection_id, connection_id, request.clone())
2300        },
2301    );
2302    Ok(())
2303}
2304
2305/// Start following another user in a call.
2306async fn follow(
2307    request: proto::Follow,
2308    response: Response<proto::Follow>,
2309    session: Session,
2310) -> Result<()> {
2311    let room_id = RoomId::from_proto(request.room_id);
2312    let project_id = request.project_id.map(ProjectId::from_proto);
2313    let leader_id = request
2314        .leader_id
2315        .ok_or_else(|| anyhow!("invalid leader id"))?
2316        .into();
2317    let follower_id = session.connection_id;
2318
2319    session
2320        .db()
2321        .await
2322        .check_room_participants(room_id, leader_id, session.connection_id)
2323        .await?;
2324
2325    let response_payload = session
2326        .peer
2327        .forward_request(session.connection_id, leader_id, request)
2328        .await?;
2329    response.send(response_payload)?;
2330
2331    if let Some(project_id) = project_id {
2332        let room = session
2333            .db()
2334            .await
2335            .follow(room_id, project_id, leader_id, follower_id)
2336            .await?;
2337        room_updated(&room, &session.peer);
2338    }
2339
2340    Ok(())
2341}
2342
2343/// Stop following another user in a call.
2344async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2345    let room_id = RoomId::from_proto(request.room_id);
2346    let project_id = request.project_id.map(ProjectId::from_proto);
2347    let leader_id = request
2348        .leader_id
2349        .ok_or_else(|| anyhow!("invalid leader id"))?
2350        .into();
2351    let follower_id = session.connection_id;
2352
2353    session
2354        .db()
2355        .await
2356        .check_room_participants(room_id, leader_id, session.connection_id)
2357        .await?;
2358
2359    session
2360        .peer
2361        .forward_send(session.connection_id, leader_id, request)?;
2362
2363    if let Some(project_id) = project_id {
2364        let room = session
2365            .db()
2366            .await
2367            .unfollow(room_id, project_id, leader_id, follower_id)
2368            .await?;
2369        room_updated(&room, &session.peer);
2370    }
2371
2372    Ok(())
2373}
2374
2375/// Notify everyone following you of your current location.
2376async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2377    let room_id = RoomId::from_proto(request.room_id);
2378    let database = session.db.lock().await;
2379
2380    let connection_ids = if let Some(project_id) = request.project_id {
2381        let project_id = ProjectId::from_proto(project_id);
2382        database
2383            .project_connection_ids(project_id, session.connection_id, true)
2384            .await?
2385    } else {
2386        database
2387            .room_connection_ids(room_id, session.connection_id)
2388            .await?
2389    };
2390
2391    // For now, don't send view update messages back to that view's current leader.
2392    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2393        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2394        _ => None,
2395    });
2396
2397    for connection_id in connection_ids.iter().cloned() {
2398        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2399            session
2400                .peer
2401                .forward_send(session.connection_id, connection_id, request.clone())?;
2402        }
2403    }
2404    Ok(())
2405}
2406
2407/// Get public data about users.
2408async fn get_users(
2409    request: proto::GetUsers,
2410    response: Response<proto::GetUsers>,
2411    session: Session,
2412) -> Result<()> {
2413    let user_ids = request
2414        .user_ids
2415        .into_iter()
2416        .map(UserId::from_proto)
2417        .collect();
2418    let users = session
2419        .db()
2420        .await
2421        .get_users_by_ids(user_ids)
2422        .await?
2423        .into_iter()
2424        .map(|user| proto::User {
2425            id: user.id.to_proto(),
2426            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2427            github_login: user.github_login,
2428            email: user.email_address,
2429            name: user.name,
2430        })
2431        .collect();
2432    response.send(proto::UsersResponse { users })?;
2433    Ok(())
2434}
2435
2436/// Search for users (to invite) buy Github login
2437async fn fuzzy_search_users(
2438    request: proto::FuzzySearchUsers,
2439    response: Response<proto::FuzzySearchUsers>,
2440    session: Session,
2441) -> Result<()> {
2442    let query = request.query;
2443    let users = match query.len() {
2444        0 => vec![],
2445        1 | 2 => session
2446            .db()
2447            .await
2448            .get_user_by_github_login(&query)
2449            .await?
2450            .into_iter()
2451            .collect(),
2452        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2453    };
2454    let users = users
2455        .into_iter()
2456        .filter(|user| user.id != session.user_id())
2457        .map(|user| proto::User {
2458            id: user.id.to_proto(),
2459            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2460            github_login: user.github_login,
2461            name: user.name,
2462            email: user.email_address,
2463        })
2464        .collect();
2465    response.send(proto::UsersResponse { users })?;
2466    Ok(())
2467}
2468
2469/// Send a contact request to another user.
2470async fn request_contact(
2471    request: proto::RequestContact,
2472    response: Response<proto::RequestContact>,
2473    session: Session,
2474) -> Result<()> {
2475    let requester_id = session.user_id();
2476    let responder_id = UserId::from_proto(request.responder_id);
2477    if requester_id == responder_id {
2478        return Err(anyhow!("cannot add yourself as a contact"))?;
2479    }
2480
2481    let notifications = session
2482        .db()
2483        .await
2484        .send_contact_request(requester_id, responder_id)
2485        .await?;
2486
2487    // Update outgoing contact requests of requester
2488    let mut update = proto::UpdateContacts::default();
2489    update.outgoing_requests.push(responder_id.to_proto());
2490    for connection_id in session
2491        .connection_pool()
2492        .await
2493        .user_connection_ids(requester_id)
2494    {
2495        session.peer.send(connection_id, update.clone())?;
2496    }
2497
2498    // Update incoming contact requests of responder
2499    let mut update = proto::UpdateContacts::default();
2500    update
2501        .incoming_requests
2502        .push(proto::IncomingContactRequest {
2503            requester_id: requester_id.to_proto(),
2504        });
2505    let connection_pool = session.connection_pool().await;
2506    for connection_id in connection_pool.user_connection_ids(responder_id) {
2507        session.peer.send(connection_id, update.clone())?;
2508    }
2509
2510    send_notifications(&connection_pool, &session.peer, notifications);
2511
2512    response.send(proto::Ack {})?;
2513    Ok(())
2514}
2515
2516/// Accept or decline a contact request
2517async fn respond_to_contact_request(
2518    request: proto::RespondToContactRequest,
2519    response: Response<proto::RespondToContactRequest>,
2520    session: Session,
2521) -> Result<()> {
2522    let responder_id = session.user_id();
2523    let requester_id = UserId::from_proto(request.requester_id);
2524    let db = session.db().await;
2525    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2526        db.dismiss_contact_notification(responder_id, requester_id)
2527            .await?;
2528    } else {
2529        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2530
2531        let notifications = db
2532            .respond_to_contact_request(responder_id, requester_id, accept)
2533            .await?;
2534        let requester_busy = db.is_user_busy(requester_id).await?;
2535        let responder_busy = db.is_user_busy(responder_id).await?;
2536
2537        let pool = session.connection_pool().await;
2538        // Update responder with new contact
2539        let mut update = proto::UpdateContacts::default();
2540        if accept {
2541            update
2542                .contacts
2543                .push(contact_for_user(requester_id, requester_busy, &pool));
2544        }
2545        update
2546            .remove_incoming_requests
2547            .push(requester_id.to_proto());
2548        for connection_id in pool.user_connection_ids(responder_id) {
2549            session.peer.send(connection_id, update.clone())?;
2550        }
2551
2552        // Update requester with new contact
2553        let mut update = proto::UpdateContacts::default();
2554        if accept {
2555            update
2556                .contacts
2557                .push(contact_for_user(responder_id, responder_busy, &pool));
2558        }
2559        update
2560            .remove_outgoing_requests
2561            .push(responder_id.to_proto());
2562
2563        for connection_id in pool.user_connection_ids(requester_id) {
2564            session.peer.send(connection_id, update.clone())?;
2565        }
2566
2567        send_notifications(&pool, &session.peer, notifications);
2568    }
2569
2570    response.send(proto::Ack {})?;
2571    Ok(())
2572}
2573
2574/// Remove a contact.
2575async fn remove_contact(
2576    request: proto::RemoveContact,
2577    response: Response<proto::RemoveContact>,
2578    session: Session,
2579) -> Result<()> {
2580    let requester_id = session.user_id();
2581    let responder_id = UserId::from_proto(request.user_id);
2582    let db = session.db().await;
2583    let (contact_accepted, deleted_notification_id) =
2584        db.remove_contact(requester_id, responder_id).await?;
2585
2586    let pool = session.connection_pool().await;
2587    // Update outgoing contact requests of requester
2588    let mut update = proto::UpdateContacts::default();
2589    if contact_accepted {
2590        update.remove_contacts.push(responder_id.to_proto());
2591    } else {
2592        update
2593            .remove_outgoing_requests
2594            .push(responder_id.to_proto());
2595    }
2596    for connection_id in pool.user_connection_ids(requester_id) {
2597        session.peer.send(connection_id, update.clone())?;
2598    }
2599
2600    // Update incoming contact requests of responder
2601    let mut update = proto::UpdateContacts::default();
2602    if contact_accepted {
2603        update.remove_contacts.push(requester_id.to_proto());
2604    } else {
2605        update
2606            .remove_incoming_requests
2607            .push(requester_id.to_proto());
2608    }
2609    for connection_id in pool.user_connection_ids(responder_id) {
2610        session.peer.send(connection_id, update.clone())?;
2611        if let Some(notification_id) = deleted_notification_id {
2612            session.peer.send(
2613                connection_id,
2614                proto::DeleteNotification {
2615                    notification_id: notification_id.to_proto(),
2616                },
2617            )?;
2618        }
2619    }
2620
2621    response.send(proto::Ack {})?;
2622    Ok(())
2623}
2624
2625fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2626    version.0.minor() < 139
2627}
2628
2629async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2630    let plan = session.current_plan(&session.db().await).await?;
2631
2632    session
2633        .peer
2634        .send(
2635            session.connection_id,
2636            proto::UpdateUserPlan { plan: plan.into() },
2637        )
2638        .trace_err();
2639
2640    Ok(())
2641}
2642
2643async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2644    subscribe_user_to_channels(session.user_id(), &session).await?;
2645    Ok(())
2646}
2647
2648async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2649    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2650    let mut pool = session.connection_pool().await;
2651    for membership in &channels_for_user.channel_memberships {
2652        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2653    }
2654    session.peer.send(
2655        session.connection_id,
2656        build_update_user_channels(&channels_for_user),
2657    )?;
2658    session.peer.send(
2659        session.connection_id,
2660        build_channels_update(channels_for_user),
2661    )?;
2662    Ok(())
2663}
2664
2665/// Creates a new channel.
2666async fn create_channel(
2667    request: proto::CreateChannel,
2668    response: Response<proto::CreateChannel>,
2669    session: Session,
2670) -> Result<()> {
2671    let db = session.db().await;
2672
2673    let parent_id = request.parent_id.map(ChannelId::from_proto);
2674    let (channel, membership) = db
2675        .create_channel(&request.name, parent_id, session.user_id())
2676        .await?;
2677
2678    let root_id = channel.root_id();
2679    let channel = Channel::from_model(channel);
2680
2681    response.send(proto::CreateChannelResponse {
2682        channel: Some(channel.to_proto()),
2683        parent_id: request.parent_id,
2684    })?;
2685
2686    let mut connection_pool = session.connection_pool().await;
2687    if let Some(membership) = membership {
2688        connection_pool.subscribe_to_channel(
2689            membership.user_id,
2690            membership.channel_id,
2691            membership.role,
2692        );
2693        let update = proto::UpdateUserChannels {
2694            channel_memberships: vec![proto::ChannelMembership {
2695                channel_id: membership.channel_id.to_proto(),
2696                role: membership.role.into(),
2697            }],
2698            ..Default::default()
2699        };
2700        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2701            session.peer.send(connection_id, update.clone())?;
2702        }
2703    }
2704
2705    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2706        if !role.can_see_channel(channel.visibility) {
2707            continue;
2708        }
2709
2710        let update = proto::UpdateChannels {
2711            channels: vec![channel.to_proto()],
2712            ..Default::default()
2713        };
2714        session.peer.send(connection_id, update.clone())?;
2715    }
2716
2717    Ok(())
2718}
2719
2720/// Delete a channel
2721async fn delete_channel(
2722    request: proto::DeleteChannel,
2723    response: Response<proto::DeleteChannel>,
2724    session: Session,
2725) -> Result<()> {
2726    let db = session.db().await;
2727
2728    let channel_id = request.channel_id;
2729    let (root_channel, removed_channels) = db
2730        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2731        .await?;
2732    response.send(proto::Ack {})?;
2733
2734    // Notify members of removed channels
2735    let mut update = proto::UpdateChannels::default();
2736    update
2737        .delete_channels
2738        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2739
2740    let connection_pool = session.connection_pool().await;
2741    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2742        session.peer.send(connection_id, update.clone())?;
2743    }
2744
2745    Ok(())
2746}
2747
2748/// Invite someone to join a channel.
2749async fn invite_channel_member(
2750    request: proto::InviteChannelMember,
2751    response: Response<proto::InviteChannelMember>,
2752    session: Session,
2753) -> Result<()> {
2754    let db = session.db().await;
2755    let channel_id = ChannelId::from_proto(request.channel_id);
2756    let invitee_id = UserId::from_proto(request.user_id);
2757    let InviteMemberResult {
2758        channel,
2759        notifications,
2760    } = db
2761        .invite_channel_member(
2762            channel_id,
2763            invitee_id,
2764            session.user_id(),
2765            request.role().into(),
2766        )
2767        .await?;
2768
2769    let update = proto::UpdateChannels {
2770        channel_invitations: vec![channel.to_proto()],
2771        ..Default::default()
2772    };
2773
2774    let connection_pool = session.connection_pool().await;
2775    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2776        session.peer.send(connection_id, update.clone())?;
2777    }
2778
2779    send_notifications(&connection_pool, &session.peer, notifications);
2780
2781    response.send(proto::Ack {})?;
2782    Ok(())
2783}
2784
2785/// remove someone from a channel
2786async fn remove_channel_member(
2787    request: proto::RemoveChannelMember,
2788    response: Response<proto::RemoveChannelMember>,
2789    session: Session,
2790) -> Result<()> {
2791    let db = session.db().await;
2792    let channel_id = ChannelId::from_proto(request.channel_id);
2793    let member_id = UserId::from_proto(request.user_id);
2794
2795    let RemoveChannelMemberResult {
2796        membership_update,
2797        notification_id,
2798    } = db
2799        .remove_channel_member(channel_id, member_id, session.user_id())
2800        .await?;
2801
2802    let mut connection_pool = session.connection_pool().await;
2803    notify_membership_updated(
2804        &mut connection_pool,
2805        membership_update,
2806        member_id,
2807        &session.peer,
2808    );
2809    for connection_id in connection_pool.user_connection_ids(member_id) {
2810        if let Some(notification_id) = notification_id {
2811            session
2812                .peer
2813                .send(
2814                    connection_id,
2815                    proto::DeleteNotification {
2816                        notification_id: notification_id.to_proto(),
2817                    },
2818                )
2819                .trace_err();
2820        }
2821    }
2822
2823    response.send(proto::Ack {})?;
2824    Ok(())
2825}
2826
2827/// Toggle the channel between public and private.
2828/// Care is taken to maintain the invariant that public channels only descend from public channels,
2829/// (though members-only channels can appear at any point in the hierarchy).
2830async fn set_channel_visibility(
2831    request: proto::SetChannelVisibility,
2832    response: Response<proto::SetChannelVisibility>,
2833    session: Session,
2834) -> Result<()> {
2835    let db = session.db().await;
2836    let channel_id = ChannelId::from_proto(request.channel_id);
2837    let visibility = request.visibility().into();
2838
2839    let channel_model = db
2840        .set_channel_visibility(channel_id, visibility, session.user_id())
2841        .await?;
2842    let root_id = channel_model.root_id();
2843    let channel = Channel::from_model(channel_model);
2844
2845    let mut connection_pool = session.connection_pool().await;
2846    for (user_id, role) in connection_pool
2847        .channel_user_ids(root_id)
2848        .collect::<Vec<_>>()
2849        .into_iter()
2850    {
2851        let update = if role.can_see_channel(channel.visibility) {
2852            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2853            proto::UpdateChannels {
2854                channels: vec![channel.to_proto()],
2855                ..Default::default()
2856            }
2857        } else {
2858            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2859            proto::UpdateChannels {
2860                delete_channels: vec![channel.id.to_proto()],
2861                ..Default::default()
2862            }
2863        };
2864
2865        for connection_id in connection_pool.user_connection_ids(user_id) {
2866            session.peer.send(connection_id, update.clone())?;
2867        }
2868    }
2869
2870    response.send(proto::Ack {})?;
2871    Ok(())
2872}
2873
2874/// Alter the role for a user in the channel.
2875async fn set_channel_member_role(
2876    request: proto::SetChannelMemberRole,
2877    response: Response<proto::SetChannelMemberRole>,
2878    session: Session,
2879) -> Result<()> {
2880    let db = session.db().await;
2881    let channel_id = ChannelId::from_proto(request.channel_id);
2882    let member_id = UserId::from_proto(request.user_id);
2883    let result = db
2884        .set_channel_member_role(
2885            channel_id,
2886            session.user_id(),
2887            member_id,
2888            request.role().into(),
2889        )
2890        .await?;
2891
2892    match result {
2893        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2894            let mut connection_pool = session.connection_pool().await;
2895            notify_membership_updated(
2896                &mut connection_pool,
2897                membership_update,
2898                member_id,
2899                &session.peer,
2900            )
2901        }
2902        db::SetMemberRoleResult::InviteUpdated(channel) => {
2903            let update = proto::UpdateChannels {
2904                channel_invitations: vec![channel.to_proto()],
2905                ..Default::default()
2906            };
2907
2908            for connection_id in session
2909                .connection_pool()
2910                .await
2911                .user_connection_ids(member_id)
2912            {
2913                session.peer.send(connection_id, update.clone())?;
2914            }
2915        }
2916    }
2917
2918    response.send(proto::Ack {})?;
2919    Ok(())
2920}
2921
2922/// Change the name of a channel
2923async fn rename_channel(
2924    request: proto::RenameChannel,
2925    response: Response<proto::RenameChannel>,
2926    session: Session,
2927) -> Result<()> {
2928    let db = session.db().await;
2929    let channel_id = ChannelId::from_proto(request.channel_id);
2930    let channel_model = db
2931        .rename_channel(channel_id, session.user_id(), &request.name)
2932        .await?;
2933    let root_id = channel_model.root_id();
2934    let channel = Channel::from_model(channel_model);
2935
2936    response.send(proto::RenameChannelResponse {
2937        channel: Some(channel.to_proto()),
2938    })?;
2939
2940    let connection_pool = session.connection_pool().await;
2941    let update = proto::UpdateChannels {
2942        channels: vec![channel.to_proto()],
2943        ..Default::default()
2944    };
2945    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2946        if role.can_see_channel(channel.visibility) {
2947            session.peer.send(connection_id, update.clone())?;
2948        }
2949    }
2950
2951    Ok(())
2952}
2953
2954/// Move a channel to a new parent.
2955async fn move_channel(
2956    request: proto::MoveChannel,
2957    response: Response<proto::MoveChannel>,
2958    session: Session,
2959) -> Result<()> {
2960    let channel_id = ChannelId::from_proto(request.channel_id);
2961    let to = ChannelId::from_proto(request.to);
2962
2963    let (root_id, channels) = session
2964        .db()
2965        .await
2966        .move_channel(channel_id, to, session.user_id())
2967        .await?;
2968
2969    let connection_pool = session.connection_pool().await;
2970    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2971        let channels = channels
2972            .iter()
2973            .filter_map(|channel| {
2974                if role.can_see_channel(channel.visibility) {
2975                    Some(channel.to_proto())
2976                } else {
2977                    None
2978                }
2979            })
2980            .collect::<Vec<_>>();
2981        if channels.is_empty() {
2982            continue;
2983        }
2984
2985        let update = proto::UpdateChannels {
2986            channels,
2987            ..Default::default()
2988        };
2989
2990        session.peer.send(connection_id, update.clone())?;
2991    }
2992
2993    response.send(Ack {})?;
2994    Ok(())
2995}
2996
2997/// Get the list of channel members
2998async fn get_channel_members(
2999    request: proto::GetChannelMembers,
3000    response: Response<proto::GetChannelMembers>,
3001    session: Session,
3002) -> Result<()> {
3003    let db = session.db().await;
3004    let channel_id = ChannelId::from_proto(request.channel_id);
3005    let limit = if request.limit == 0 {
3006        u16::MAX as u64
3007    } else {
3008        request.limit
3009    };
3010    let (members, users) = db
3011        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3012        .await?;
3013    response.send(proto::GetChannelMembersResponse { members, users })?;
3014    Ok(())
3015}
3016
3017/// Accept or decline a channel invitation.
3018async fn respond_to_channel_invite(
3019    request: proto::RespondToChannelInvite,
3020    response: Response<proto::RespondToChannelInvite>,
3021    session: Session,
3022) -> Result<()> {
3023    let db = session.db().await;
3024    let channel_id = ChannelId::from_proto(request.channel_id);
3025    let RespondToChannelInvite {
3026        membership_update,
3027        notifications,
3028    } = db
3029        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3030        .await?;
3031
3032    let mut connection_pool = session.connection_pool().await;
3033    if let Some(membership_update) = membership_update {
3034        notify_membership_updated(
3035            &mut connection_pool,
3036            membership_update,
3037            session.user_id(),
3038            &session.peer,
3039        );
3040    } else {
3041        let update = proto::UpdateChannels {
3042            remove_channel_invitations: vec![channel_id.to_proto()],
3043            ..Default::default()
3044        };
3045
3046        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3047            session.peer.send(connection_id, update.clone())?;
3048        }
3049    };
3050
3051    send_notifications(&connection_pool, &session.peer, notifications);
3052
3053    response.send(proto::Ack {})?;
3054
3055    Ok(())
3056}
3057
3058/// Join the channels' room
3059async fn join_channel(
3060    request: proto::JoinChannel,
3061    response: Response<proto::JoinChannel>,
3062    session: Session,
3063) -> Result<()> {
3064    let channel_id = ChannelId::from_proto(request.channel_id);
3065    join_channel_internal(channel_id, Box::new(response), session).await
3066}
3067
3068trait JoinChannelInternalResponse {
3069    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3070}
3071impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3072    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3073        Response::<proto::JoinChannel>::send(self, result)
3074    }
3075}
3076impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3077    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3078        Response::<proto::JoinRoom>::send(self, result)
3079    }
3080}
3081
3082async fn join_channel_internal(
3083    channel_id: ChannelId,
3084    response: Box<impl JoinChannelInternalResponse>,
3085    session: Session,
3086) -> Result<()> {
3087    let joined_room = {
3088        let mut db = session.db().await;
3089        // If zed quits without leaving the room, and the user re-opens zed before the
3090        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3091        // room they were in.
3092        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3093            tracing::info!(
3094                stale_connection_id = %connection,
3095                "cleaning up stale connection",
3096            );
3097            drop(db);
3098            leave_room_for_session(&session, connection).await?;
3099            db = session.db().await;
3100        }
3101
3102        let (joined_room, membership_updated, role) = db
3103            .join_channel(channel_id, session.user_id(), session.connection_id)
3104            .await?;
3105
3106        let live_kit_connection_info =
3107            session
3108                .app_state
3109                .livekit_client
3110                .as_ref()
3111                .and_then(|live_kit| {
3112                    let (can_publish, token) = if role == ChannelRole::Guest {
3113                        (
3114                            false,
3115                            live_kit
3116                                .guest_token(
3117                                    &joined_room.room.livekit_room,
3118                                    &session.user_id().to_string(),
3119                                )
3120                                .trace_err()?,
3121                        )
3122                    } else {
3123                        (
3124                            true,
3125                            live_kit
3126                                .room_token(
3127                                    &joined_room.room.livekit_room,
3128                                    &session.user_id().to_string(),
3129                                )
3130                                .trace_err()?,
3131                        )
3132                    };
3133
3134                    Some(LiveKitConnectionInfo {
3135                        server_url: live_kit.url().into(),
3136                        token,
3137                        can_publish,
3138                    })
3139                });
3140
3141        response.send(proto::JoinRoomResponse {
3142            room: Some(joined_room.room.clone()),
3143            channel_id: joined_room
3144                .channel
3145                .as_ref()
3146                .map(|channel| channel.id.to_proto()),
3147            live_kit_connection_info,
3148        })?;
3149
3150        let mut connection_pool = session.connection_pool().await;
3151        if let Some(membership_updated) = membership_updated {
3152            notify_membership_updated(
3153                &mut connection_pool,
3154                membership_updated,
3155                session.user_id(),
3156                &session.peer,
3157            );
3158        }
3159
3160        room_updated(&joined_room.room, &session.peer);
3161
3162        joined_room
3163    };
3164
3165    channel_updated(
3166        &joined_room
3167            .channel
3168            .ok_or_else(|| anyhow!("channel not returned"))?,
3169        &joined_room.room,
3170        &session.peer,
3171        &*session.connection_pool().await,
3172    );
3173
3174    update_user_contacts(session.user_id(), &session).await?;
3175    Ok(())
3176}
3177
3178/// Start editing the channel notes
3179async fn join_channel_buffer(
3180    request: proto::JoinChannelBuffer,
3181    response: Response<proto::JoinChannelBuffer>,
3182    session: Session,
3183) -> Result<()> {
3184    let db = session.db().await;
3185    let channel_id = ChannelId::from_proto(request.channel_id);
3186
3187    let open_response = db
3188        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3189        .await?;
3190
3191    let collaborators = open_response.collaborators.clone();
3192    response.send(open_response)?;
3193
3194    let update = UpdateChannelBufferCollaborators {
3195        channel_id: channel_id.to_proto(),
3196        collaborators: collaborators.clone(),
3197    };
3198    channel_buffer_updated(
3199        session.connection_id,
3200        collaborators
3201            .iter()
3202            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3203        &update,
3204        &session.peer,
3205    );
3206
3207    Ok(())
3208}
3209
3210/// Edit the channel notes
3211async fn update_channel_buffer(
3212    request: proto::UpdateChannelBuffer,
3213    session: Session,
3214) -> Result<()> {
3215    let db = session.db().await;
3216    let channel_id = ChannelId::from_proto(request.channel_id);
3217
3218    let (collaborators, epoch, version) = db
3219        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3220        .await?;
3221
3222    channel_buffer_updated(
3223        session.connection_id,
3224        collaborators.clone(),
3225        &proto::UpdateChannelBuffer {
3226            channel_id: channel_id.to_proto(),
3227            operations: request.operations,
3228        },
3229        &session.peer,
3230    );
3231
3232    let pool = &*session.connection_pool().await;
3233
3234    let non_collaborators =
3235        pool.channel_connection_ids(channel_id)
3236            .filter_map(|(connection_id, _)| {
3237                if collaborators.contains(&connection_id) {
3238                    None
3239                } else {
3240                    Some(connection_id)
3241                }
3242            });
3243
3244    broadcast(None, non_collaborators, |peer_id| {
3245        session.peer.send(
3246            peer_id,
3247            proto::UpdateChannels {
3248                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3249                    channel_id: channel_id.to_proto(),
3250                    epoch: epoch as u64,
3251                    version: version.clone(),
3252                }],
3253                ..Default::default()
3254            },
3255        )
3256    });
3257
3258    Ok(())
3259}
3260
3261/// Rejoin the channel notes after a connection blip
3262async fn rejoin_channel_buffers(
3263    request: proto::RejoinChannelBuffers,
3264    response: Response<proto::RejoinChannelBuffers>,
3265    session: Session,
3266) -> Result<()> {
3267    let db = session.db().await;
3268    let buffers = db
3269        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3270        .await?;
3271
3272    for rejoined_buffer in &buffers {
3273        let collaborators_to_notify = rejoined_buffer
3274            .buffer
3275            .collaborators
3276            .iter()
3277            .filter_map(|c| Some(c.peer_id?.into()));
3278        channel_buffer_updated(
3279            session.connection_id,
3280            collaborators_to_notify,
3281            &proto::UpdateChannelBufferCollaborators {
3282                channel_id: rejoined_buffer.buffer.channel_id,
3283                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3284            },
3285            &session.peer,
3286        );
3287    }
3288
3289    response.send(proto::RejoinChannelBuffersResponse {
3290        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3291    })?;
3292
3293    Ok(())
3294}
3295
3296/// Stop editing the channel notes
3297async fn leave_channel_buffer(
3298    request: proto::LeaveChannelBuffer,
3299    response: Response<proto::LeaveChannelBuffer>,
3300    session: Session,
3301) -> Result<()> {
3302    let db = session.db().await;
3303    let channel_id = ChannelId::from_proto(request.channel_id);
3304
3305    let left_buffer = db
3306        .leave_channel_buffer(channel_id, session.connection_id)
3307        .await?;
3308
3309    response.send(Ack {})?;
3310
3311    channel_buffer_updated(
3312        session.connection_id,
3313        left_buffer.connections,
3314        &proto::UpdateChannelBufferCollaborators {
3315            channel_id: channel_id.to_proto(),
3316            collaborators: left_buffer.collaborators,
3317        },
3318        &session.peer,
3319    );
3320
3321    Ok(())
3322}
3323
3324fn channel_buffer_updated<T: EnvelopedMessage>(
3325    sender_id: ConnectionId,
3326    collaborators: impl IntoIterator<Item = ConnectionId>,
3327    message: &T,
3328    peer: &Peer,
3329) {
3330    broadcast(Some(sender_id), collaborators, |peer_id| {
3331        peer.send(peer_id, message.clone())
3332    });
3333}
3334
3335fn send_notifications(
3336    connection_pool: &ConnectionPool,
3337    peer: &Peer,
3338    notifications: db::NotificationBatch,
3339) {
3340    for (user_id, notification) in notifications {
3341        for connection_id in connection_pool.user_connection_ids(user_id) {
3342            if let Err(error) = peer.send(
3343                connection_id,
3344                proto::AddNotification {
3345                    notification: Some(notification.clone()),
3346                },
3347            ) {
3348                tracing::error!(
3349                    "failed to send notification to {:?} {}",
3350                    connection_id,
3351                    error
3352                );
3353            }
3354        }
3355    }
3356}
3357
3358/// Send a message to the channel
3359async fn send_channel_message(
3360    request: proto::SendChannelMessage,
3361    response: Response<proto::SendChannelMessage>,
3362    session: Session,
3363) -> Result<()> {
3364    // Validate the message body.
3365    let body = request.body.trim().to_string();
3366    if body.len() > MAX_MESSAGE_LEN {
3367        return Err(anyhow!("message is too long"))?;
3368    }
3369    if body.is_empty() {
3370        return Err(anyhow!("message can't be blank"))?;
3371    }
3372
3373    // TODO: adjust mentions if body is trimmed
3374
3375    let timestamp = OffsetDateTime::now_utc();
3376    let nonce = request
3377        .nonce
3378        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3379
3380    let channel_id = ChannelId::from_proto(request.channel_id);
3381    let CreatedChannelMessage {
3382        message_id,
3383        participant_connection_ids,
3384        notifications,
3385    } = session
3386        .db()
3387        .await
3388        .create_channel_message(
3389            channel_id,
3390            session.user_id(),
3391            &body,
3392            &request.mentions,
3393            timestamp,
3394            nonce.clone().into(),
3395            request.reply_to_message_id.map(MessageId::from_proto),
3396        )
3397        .await?;
3398
3399    let message = proto::ChannelMessage {
3400        sender_id: session.user_id().to_proto(),
3401        id: message_id.to_proto(),
3402        body,
3403        mentions: request.mentions,
3404        timestamp: timestamp.unix_timestamp() as u64,
3405        nonce: Some(nonce),
3406        reply_to_message_id: request.reply_to_message_id,
3407        edited_at: None,
3408    };
3409    broadcast(
3410        Some(session.connection_id),
3411        participant_connection_ids.clone(),
3412        |connection| {
3413            session.peer.send(
3414                connection,
3415                proto::ChannelMessageSent {
3416                    channel_id: channel_id.to_proto(),
3417                    message: Some(message.clone()),
3418                },
3419            )
3420        },
3421    );
3422    response.send(proto::SendChannelMessageResponse {
3423        message: Some(message),
3424    })?;
3425
3426    let pool = &*session.connection_pool().await;
3427    let non_participants =
3428        pool.channel_connection_ids(channel_id)
3429            .filter_map(|(connection_id, _)| {
3430                if participant_connection_ids.contains(&connection_id) {
3431                    None
3432                } else {
3433                    Some(connection_id)
3434                }
3435            });
3436    broadcast(None, non_participants, |peer_id| {
3437        session.peer.send(
3438            peer_id,
3439            proto::UpdateChannels {
3440                latest_channel_message_ids: vec![proto::ChannelMessageId {
3441                    channel_id: channel_id.to_proto(),
3442                    message_id: message_id.to_proto(),
3443                }],
3444                ..Default::default()
3445            },
3446        )
3447    });
3448    send_notifications(pool, &session.peer, notifications);
3449
3450    Ok(())
3451}
3452
3453/// Delete a channel message
3454async fn remove_channel_message(
3455    request: proto::RemoveChannelMessage,
3456    response: Response<proto::RemoveChannelMessage>,
3457    session: Session,
3458) -> Result<()> {
3459    let channel_id = ChannelId::from_proto(request.channel_id);
3460    let message_id = MessageId::from_proto(request.message_id);
3461    let (connection_ids, existing_notification_ids) = session
3462        .db()
3463        .await
3464        .remove_channel_message(channel_id, message_id, session.user_id())
3465        .await?;
3466
3467    broadcast(
3468        Some(session.connection_id),
3469        connection_ids,
3470        move |connection| {
3471            session.peer.send(connection, request.clone())?;
3472
3473            for notification_id in &existing_notification_ids {
3474                session.peer.send(
3475                    connection,
3476                    proto::DeleteNotification {
3477                        notification_id: (*notification_id).to_proto(),
3478                    },
3479                )?;
3480            }
3481
3482            Ok(())
3483        },
3484    );
3485    response.send(proto::Ack {})?;
3486    Ok(())
3487}
3488
3489async fn update_channel_message(
3490    request: proto::UpdateChannelMessage,
3491    response: Response<proto::UpdateChannelMessage>,
3492    session: Session,
3493) -> Result<()> {
3494    let channel_id = ChannelId::from_proto(request.channel_id);
3495    let message_id = MessageId::from_proto(request.message_id);
3496    let updated_at = OffsetDateTime::now_utc();
3497    let UpdatedChannelMessage {
3498        message_id,
3499        participant_connection_ids,
3500        notifications,
3501        reply_to_message_id,
3502        timestamp,
3503        deleted_mention_notification_ids,
3504        updated_mention_notifications,
3505    } = session
3506        .db()
3507        .await
3508        .update_channel_message(
3509            channel_id,
3510            message_id,
3511            session.user_id(),
3512            request.body.as_str(),
3513            &request.mentions,
3514            updated_at,
3515        )
3516        .await?;
3517
3518    let nonce = request
3519        .nonce
3520        .clone()
3521        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3522
3523    let message = proto::ChannelMessage {
3524        sender_id: session.user_id().to_proto(),
3525        id: message_id.to_proto(),
3526        body: request.body.clone(),
3527        mentions: request.mentions.clone(),
3528        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3529        nonce: Some(nonce),
3530        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3531        edited_at: Some(updated_at.unix_timestamp() as u64),
3532    };
3533
3534    response.send(proto::Ack {})?;
3535
3536    let pool = &*session.connection_pool().await;
3537    broadcast(
3538        Some(session.connection_id),
3539        participant_connection_ids,
3540        |connection| {
3541            session.peer.send(
3542                connection,
3543                proto::ChannelMessageUpdate {
3544                    channel_id: channel_id.to_proto(),
3545                    message: Some(message.clone()),
3546                },
3547            )?;
3548
3549            for notification_id in &deleted_mention_notification_ids {
3550                session.peer.send(
3551                    connection,
3552                    proto::DeleteNotification {
3553                        notification_id: (*notification_id).to_proto(),
3554                    },
3555                )?;
3556            }
3557
3558            for notification in &updated_mention_notifications {
3559                session.peer.send(
3560                    connection,
3561                    proto::UpdateNotification {
3562                        notification: Some(notification.clone()),
3563                    },
3564                )?;
3565            }
3566
3567            Ok(())
3568        },
3569    );
3570
3571    send_notifications(pool, &session.peer, notifications);
3572
3573    Ok(())
3574}
3575
3576/// Mark a channel message as read
3577async fn acknowledge_channel_message(
3578    request: proto::AckChannelMessage,
3579    session: Session,
3580) -> Result<()> {
3581    let channel_id = ChannelId::from_proto(request.channel_id);
3582    let message_id = MessageId::from_proto(request.message_id);
3583    let notifications = session
3584        .db()
3585        .await
3586        .observe_channel_message(channel_id, session.user_id(), message_id)
3587        .await?;
3588    send_notifications(
3589        &*session.connection_pool().await,
3590        &session.peer,
3591        notifications,
3592    );
3593    Ok(())
3594}
3595
3596/// Mark a buffer version as synced
3597async fn acknowledge_buffer_version(
3598    request: proto::AckBufferOperation,
3599    session: Session,
3600) -> Result<()> {
3601    let buffer_id = BufferId::from_proto(request.buffer_id);
3602    session
3603        .db()
3604        .await
3605        .observe_buffer_version(
3606            buffer_id,
3607            session.user_id(),
3608            request.epoch as i32,
3609            &request.version,
3610        )
3611        .await?;
3612    Ok(())
3613}
3614
3615async fn count_language_model_tokens(
3616    request: proto::CountLanguageModelTokens,
3617    response: Response<proto::CountLanguageModelTokens>,
3618    session: Session,
3619    config: &Config,
3620) -> Result<()> {
3621    authorize_access_to_legacy_llm_endpoints(&session).await?;
3622
3623    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3624        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3625        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3626    };
3627
3628    session
3629        .app_state
3630        .rate_limiter
3631        .check(&*rate_limit, session.user_id())
3632        .await?;
3633
3634    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3635        Some(proto::LanguageModelProvider::Google) => {
3636            let api_key = config
3637                .google_ai_api_key
3638                .as_ref()
3639                .context("no Google AI API key configured on the server")?;
3640            google_ai::count_tokens(
3641                session.http_client.as_ref(),
3642                google_ai::API_URL,
3643                api_key,
3644                serde_json::from_str(&request.request)?,
3645            )
3646            .await?
3647        }
3648        _ => return Err(anyhow!("unsupported provider"))?,
3649    };
3650
3651    response.send(proto::CountLanguageModelTokensResponse {
3652        token_count: result.total_tokens as u32,
3653    })?;
3654
3655    Ok(())
3656}
3657
3658struct ZedProCountLanguageModelTokensRateLimit;
3659
3660impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3661    fn capacity(&self) -> usize {
3662        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3663            .ok()
3664            .and_then(|v| v.parse().ok())
3665            .unwrap_or(600) // Picked arbitrarily
3666    }
3667
3668    fn refill_duration(&self) -> chrono::Duration {
3669        chrono::Duration::hours(1)
3670    }
3671
3672    fn db_name(&self) -> &'static str {
3673        "zed-pro:count-language-model-tokens"
3674    }
3675}
3676
3677struct FreeCountLanguageModelTokensRateLimit;
3678
3679impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3680    fn capacity(&self) -> usize {
3681        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3682            .ok()
3683            .and_then(|v| v.parse().ok())
3684            .unwrap_or(600 / 10) // Picked arbitrarily
3685    }
3686
3687    fn refill_duration(&self) -> chrono::Duration {
3688        chrono::Duration::hours(1)
3689    }
3690
3691    fn db_name(&self) -> &'static str {
3692        "free:count-language-model-tokens"
3693    }
3694}
3695
3696struct ZedProComputeEmbeddingsRateLimit;
3697
3698impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3699    fn capacity(&self) -> usize {
3700        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3701            .ok()
3702            .and_then(|v| v.parse().ok())
3703            .unwrap_or(5000) // Picked arbitrarily
3704    }
3705
3706    fn refill_duration(&self) -> chrono::Duration {
3707        chrono::Duration::hours(1)
3708    }
3709
3710    fn db_name(&self) -> &'static str {
3711        "zed-pro:compute-embeddings"
3712    }
3713}
3714
3715struct FreeComputeEmbeddingsRateLimit;
3716
3717impl RateLimit for FreeComputeEmbeddingsRateLimit {
3718    fn capacity(&self) -> usize {
3719        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3720            .ok()
3721            .and_then(|v| v.parse().ok())
3722            .unwrap_or(5000 / 10) // Picked arbitrarily
3723    }
3724
3725    fn refill_duration(&self) -> chrono::Duration {
3726        chrono::Duration::hours(1)
3727    }
3728
3729    fn db_name(&self) -> &'static str {
3730        "free:compute-embeddings"
3731    }
3732}
3733
3734async fn compute_embeddings(
3735    request: proto::ComputeEmbeddings,
3736    response: Response<proto::ComputeEmbeddings>,
3737    session: Session,
3738    api_key: Option<Arc<str>>,
3739) -> Result<()> {
3740    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3741    authorize_access_to_legacy_llm_endpoints(&session).await?;
3742
3743    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3744        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3745        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3746    };
3747
3748    session
3749        .app_state
3750        .rate_limiter
3751        .check(&*rate_limit, session.user_id())
3752        .await?;
3753
3754    let embeddings = match request.model.as_str() {
3755        "openai/text-embedding-3-small" => {
3756            open_ai::embed(
3757                session.http_client.as_ref(),
3758                OPEN_AI_API_URL,
3759                &api_key,
3760                OpenAiEmbeddingModel::TextEmbedding3Small,
3761                request.texts.iter().map(|text| text.as_str()),
3762            )
3763            .await?
3764        }
3765        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3766    };
3767
3768    let embeddings = request
3769        .texts
3770        .iter()
3771        .map(|text| {
3772            let mut hasher = sha2::Sha256::new();
3773            hasher.update(text.as_bytes());
3774            let result = hasher.finalize();
3775            result.to_vec()
3776        })
3777        .zip(
3778            embeddings
3779                .data
3780                .into_iter()
3781                .map(|embedding| embedding.embedding),
3782        )
3783        .collect::<HashMap<_, _>>();
3784
3785    let db = session.db().await;
3786    db.save_embeddings(&request.model, &embeddings)
3787        .await
3788        .context("failed to save embeddings")
3789        .trace_err();
3790
3791    response.send(proto::ComputeEmbeddingsResponse {
3792        embeddings: embeddings
3793            .into_iter()
3794            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3795            .collect(),
3796    })?;
3797    Ok(())
3798}
3799
3800async fn get_cached_embeddings(
3801    request: proto::GetCachedEmbeddings,
3802    response: Response<proto::GetCachedEmbeddings>,
3803    session: Session,
3804) -> Result<()> {
3805    authorize_access_to_legacy_llm_endpoints(&session).await?;
3806
3807    let db = session.db().await;
3808    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3809
3810    response.send(proto::GetCachedEmbeddingsResponse {
3811        embeddings: embeddings
3812            .into_iter()
3813            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3814            .collect(),
3815    })?;
3816    Ok(())
3817}
3818
3819/// This is leftover from before the LLM service.
3820///
3821/// The endpoints protected by this check will be moved there eventually.
3822async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3823    if session.is_staff() {
3824        Ok(())
3825    } else {
3826        Err(anyhow!("permission denied"))?
3827    }
3828}
3829
3830/// Get a Supermaven API key for the user
3831async fn get_supermaven_api_key(
3832    _request: proto::GetSupermavenApiKey,
3833    response: Response<proto::GetSupermavenApiKey>,
3834    session: Session,
3835) -> Result<()> {
3836    let user_id: String = session.user_id().to_string();
3837    if !session.is_staff() {
3838        return Err(anyhow!("supermaven not enabled for this account"))?;
3839    }
3840
3841    let email = session
3842        .email()
3843        .ok_or_else(|| anyhow!("user must have an email"))?;
3844
3845    let supermaven_admin_api = session
3846        .supermaven_client
3847        .as_ref()
3848        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3849
3850    let result = supermaven_admin_api
3851        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3852        .await?;
3853
3854    response.send(proto::GetSupermavenApiKeyResponse {
3855        api_key: result.api_key,
3856    })?;
3857
3858    Ok(())
3859}
3860
3861/// Start receiving chat updates for a channel
3862async fn join_channel_chat(
3863    request: proto::JoinChannelChat,
3864    response: Response<proto::JoinChannelChat>,
3865    session: Session,
3866) -> Result<()> {
3867    let channel_id = ChannelId::from_proto(request.channel_id);
3868
3869    let db = session.db().await;
3870    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3871        .await?;
3872    let messages = db
3873        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3874        .await?;
3875    response.send(proto::JoinChannelChatResponse {
3876        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3877        messages,
3878    })?;
3879    Ok(())
3880}
3881
3882/// Stop receiving chat updates for a channel
3883async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3884    let channel_id = ChannelId::from_proto(request.channel_id);
3885    session
3886        .db()
3887        .await
3888        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3889        .await?;
3890    Ok(())
3891}
3892
3893/// Retrieve the chat history for a channel
3894async fn get_channel_messages(
3895    request: proto::GetChannelMessages,
3896    response: Response<proto::GetChannelMessages>,
3897    session: Session,
3898) -> Result<()> {
3899    let channel_id = ChannelId::from_proto(request.channel_id);
3900    let messages = session
3901        .db()
3902        .await
3903        .get_channel_messages(
3904            channel_id,
3905            session.user_id(),
3906            MESSAGE_COUNT_PER_PAGE,
3907            Some(MessageId::from_proto(request.before_message_id)),
3908        )
3909        .await?;
3910    response.send(proto::GetChannelMessagesResponse {
3911        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3912        messages,
3913    })?;
3914    Ok(())
3915}
3916
3917/// Retrieve specific chat messages
3918async fn get_channel_messages_by_id(
3919    request: proto::GetChannelMessagesById,
3920    response: Response<proto::GetChannelMessagesById>,
3921    session: Session,
3922) -> Result<()> {
3923    let message_ids = request
3924        .message_ids
3925        .iter()
3926        .map(|id| MessageId::from_proto(*id))
3927        .collect::<Vec<_>>();
3928    let messages = session
3929        .db()
3930        .await
3931        .get_channel_messages_by_id(session.user_id(), &message_ids)
3932        .await?;
3933    response.send(proto::GetChannelMessagesResponse {
3934        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3935        messages,
3936    })?;
3937    Ok(())
3938}
3939
3940/// Retrieve the current users notifications
3941async fn get_notifications(
3942    request: proto::GetNotifications,
3943    response: Response<proto::GetNotifications>,
3944    session: Session,
3945) -> Result<()> {
3946    let notifications = session
3947        .db()
3948        .await
3949        .get_notifications(
3950            session.user_id(),
3951            NOTIFICATION_COUNT_PER_PAGE,
3952            request.before_id.map(db::NotificationId::from_proto),
3953        )
3954        .await?;
3955    response.send(proto::GetNotificationsResponse {
3956        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3957        notifications,
3958    })?;
3959    Ok(())
3960}
3961
3962/// Mark notifications as read
3963async fn mark_notification_as_read(
3964    request: proto::MarkNotificationRead,
3965    response: Response<proto::MarkNotificationRead>,
3966    session: Session,
3967) -> Result<()> {
3968    let database = &session.db().await;
3969    let notifications = database
3970        .mark_notification_as_read_by_id(
3971            session.user_id(),
3972            NotificationId::from_proto(request.notification_id),
3973        )
3974        .await?;
3975    send_notifications(
3976        &*session.connection_pool().await,
3977        &session.peer,
3978        notifications,
3979    );
3980    response.send(proto::Ack {})?;
3981    Ok(())
3982}
3983
3984/// Get the current users information
3985async fn get_private_user_info(
3986    _request: proto::GetPrivateUserInfo,
3987    response: Response<proto::GetPrivateUserInfo>,
3988    session: Session,
3989) -> Result<()> {
3990    let db = session.db().await;
3991
3992    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3993    let user = db
3994        .get_user_by_id(session.user_id())
3995        .await?
3996        .ok_or_else(|| anyhow!("user not found"))?;
3997    let flags = db.get_user_flags(session.user_id()).await?;
3998
3999    response.send(proto::GetPrivateUserInfoResponse {
4000        metrics_id,
4001        staff: user.admin,
4002        flags,
4003        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4004    })?;
4005    Ok(())
4006}
4007
4008/// Accept the terms of service (tos) on behalf of the current user
4009async fn accept_terms_of_service(
4010    _request: proto::AcceptTermsOfService,
4011    response: Response<proto::AcceptTermsOfService>,
4012    session: Session,
4013) -> Result<()> {
4014    let db = session.db().await;
4015
4016    let accepted_tos_at = Utc::now();
4017    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4018        .await?;
4019
4020    response.send(proto::AcceptTermsOfServiceResponse {
4021        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4022    })?;
4023    Ok(())
4024}
4025
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the age of the older of the Zed
/// account (`created_at`) and the linked GitHub account
/// (`github_user_created_at`). The check is skipped for users with an LLM
/// subscription or the `bypass-account-age-check` flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4028
4029async fn get_llm_api_token(
4030    _request: proto::GetLlmToken,
4031    response: Response<proto::GetLlmToken>,
4032    session: Session,
4033) -> Result<()> {
4034    let db = session.db().await;
4035
4036    let flags = db.get_user_flags(session.user_id()).await?;
4037    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4038    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4039    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4040
4041    if !session.is_staff() && !has_language_models_feature_flag {
4042        Err(anyhow!("permission denied"))?
4043    }
4044
4045    let user_id = session.user_id();
4046    let user = db
4047        .get_user_by_id(user_id)
4048        .await?
4049        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4050
4051    if user.accepted_tos_at.is_none() {
4052        Err(anyhow!("terms of service not accepted"))?
4053    }
4054
4055    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4056
4057    let bypass_account_age_check =
4058        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4059    if !bypass_account_age_check {
4060        let mut account_created_at = user.created_at;
4061        if let Some(github_created_at) = user.github_user_created_at {
4062            account_created_at = account_created_at.min(github_created_at);
4063        }
4064        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4065            Err(anyhow!("account too young"))?
4066        }
4067    }
4068
4069    let billing_preferences = db.get_billing_preferences(user.id).await?;
4070
4071    let token = LlmTokenClaims::create(
4072        &user,
4073        session.is_staff(),
4074        billing_preferences,
4075        has_llm_closed_beta_feature_flag,
4076        has_predict_edits_feature_flag,
4077        has_llm_subscription,
4078        session.current_plan(&db).await?,
4079        session.system_id.clone(),
4080        &session.app_state.config,
4081    )?;
4082    response.send(proto::GetLlmTokenResponse { token })?;
4083    Ok(())
4084}
4085
4086fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4087    let message = match message {
4088        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4089        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4090        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4091        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4092        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4093            code: frame.code.into(),
4094            reason: frame.reason,
4095        })),
4096        // We should never receive a frame while reading the message, according
4097        // to the `tungstenite` maintainers:
4098        //
4099        // > It cannot occur when you read messages from the WebSocket, but it
4100        // > can be used when you want to send the raw frames (e.g. you want to
4101        // > send the frames to the WebSocket without composing the full message first).
4102        // >
4103        // > — https://github.com/snapview/tungstenite-rs/issues/268
4104        TungsteniteMessage::Frame(_) => {
4105            bail!("received an unexpected frame while reading the message")
4106        }
4107    };
4108
4109    Ok(message)
4110}
4111
4112fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4113    match message {
4114        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4115        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4116        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4117        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4118        AxumMessage::Close(frame) => {
4119            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4120                code: frame.code.into(),
4121                reason: frame.reason,
4122            }))
4123        }
4124    }
4125}
4126
4127fn notify_membership_updated(
4128    connection_pool: &mut ConnectionPool,
4129    result: MembershipUpdated,
4130    user_id: UserId,
4131    peer: &Peer,
4132) {
4133    for membership in &result.new_channels.channel_memberships {
4134        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4135    }
4136    for channel_id in &result.removed_channels {
4137        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4138    }
4139
4140    let user_channels_update = proto::UpdateUserChannels {
4141        channel_memberships: result
4142            .new_channels
4143            .channel_memberships
4144            .iter()
4145            .map(|cm| proto::ChannelMembership {
4146                channel_id: cm.channel_id.to_proto(),
4147                role: cm.role.into(),
4148            })
4149            .collect(),
4150        ..Default::default()
4151    };
4152
4153    let mut update = build_channels_update(result.new_channels);
4154    update.delete_channels = result
4155        .removed_channels
4156        .into_iter()
4157        .map(|id| id.to_proto())
4158        .collect();
4159    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4160
4161    for connection_id in connection_pool.user_connection_ids(user_id) {
4162        peer.send(connection_id, user_channels_update.clone())
4163            .trace_err();
4164        peer.send(connection_id, update.clone()).trace_err();
4165    }
4166}
4167
4168fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4169    proto::UpdateUserChannels {
4170        channel_memberships: channels
4171            .channel_memberships
4172            .iter()
4173            .map(|m| proto::ChannelMembership {
4174                channel_id: m.channel_id.to_proto(),
4175                role: m.role.into(),
4176            })
4177            .collect(),
4178        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4179        observed_channel_message_id: channels.observed_channel_messages.clone(),
4180    }
4181}
4182
4183fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4184    let mut update = proto::UpdateChannels::default();
4185
4186    for channel in channels.channels {
4187        update.channels.push(channel.to_proto());
4188    }
4189
4190    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4191    update.latest_channel_message_ids = channels.latest_channel_messages;
4192
4193    for (channel_id, participants) in channels.channel_participants {
4194        update
4195            .channel_participants
4196            .push(proto::ChannelParticipants {
4197                channel_id: channel_id.to_proto(),
4198                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4199            });
4200    }
4201
4202    for channel in channels.invited_channels {
4203        update.channel_invitations.push(channel.to_proto());
4204    }
4205
4206    update
4207}
4208
4209fn build_initial_contacts_update(
4210    contacts: Vec<db::Contact>,
4211    pool: &ConnectionPool,
4212) -> proto::UpdateContacts {
4213    let mut update = proto::UpdateContacts::default();
4214
4215    for contact in contacts {
4216        match contact {
4217            db::Contact::Accepted { user_id, busy } => {
4218                update.contacts.push(contact_for_user(user_id, busy, pool));
4219            }
4220            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4221            db::Contact::Incoming { user_id } => {
4222                update
4223                    .incoming_requests
4224                    .push(proto::IncomingContactRequest {
4225                        requester_id: user_id.to_proto(),
4226                    })
4227            }
4228        }
4229    }
4230
4231    update
4232}
4233
4234fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4235    proto::Contact {
4236        user_id: user_id.to_proto(),
4237        online: pool.is_user_online(user_id),
4238        busy,
4239    }
4240}
4241
4242fn room_updated(room: &proto::Room, peer: &Peer) {
4243    broadcast(
4244        None,
4245        room.participants
4246            .iter()
4247            .filter_map(|participant| Some(participant.peer_id?.into())),
4248        |peer_id| {
4249            peer.send(
4250                peer_id,
4251                proto::RoomUpdated {
4252                    room: Some(room.clone()),
4253                },
4254            )
4255        },
4256    );
4257}
4258
4259fn channel_updated(
4260    channel: &db::channel::Model,
4261    room: &proto::Room,
4262    peer: &Peer,
4263    pool: &ConnectionPool,
4264) {
4265    let participants = room
4266        .participants
4267        .iter()
4268        .map(|p| p.user_id)
4269        .collect::<Vec<_>>();
4270
4271    broadcast(
4272        None,
4273        pool.channel_connection_ids(channel.root_id())
4274            .filter_map(|(channel_id, role)| {
4275                role.can_see_channel(channel.visibility)
4276                    .then_some(channel_id)
4277            }),
4278        |peer_id| {
4279            peer.send(
4280                peer_id,
4281                proto::UpdateChannels {
4282                    channel_participants: vec![proto::ChannelParticipants {
4283                        channel_id: channel.id.to_proto(),
4284                        participant_user_ids: participants.clone(),
4285                    }],
4286                    ..Default::default()
4287                },
4288            )
4289        },
4290    );
4291}
4292
4293async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4294    let db = session.db().await;
4295
4296    let contacts = db.get_contacts(user_id).await?;
4297    let busy = db.is_user_busy(user_id).await?;
4298
4299    let pool = session.connection_pool().await;
4300    let updated_contact = contact_for_user(user_id, busy, &pool);
4301    for contact in contacts {
4302        if let db::Contact::Accepted {
4303            user_id: contact_user_id,
4304            ..
4305        } = contact
4306        {
4307            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4308                session
4309                    .peer
4310                    .send(
4311                        contact_conn_id,
4312                        proto::UpdateContacts {
4313                            contacts: vec![updated_contact.clone()],
4314                            remove_contacts: Default::default(),
4315                            incoming_requests: Default::default(),
4316                            remove_incoming_requests: Default::default(),
4317                            outgoing_requests: Default::default(),
4318                            remove_outgoing_requests: Default::default(),
4319                        },
4320                    )
4321                    .trace_err();
4322            }
4323        }
4324    }
4325    Ok(())
4326}
4327
/// Removes `connection_id` from the room the database reports it in, then
/// cleans up everything that depends on that membership.
///
/// Steps, in order:
/// 1. `leave_room` in the database; if it returns `None`, there is nothing to do.
/// 2. Notify collaborators of each project left (`project_left`) and broadcast
///    the updated room to its participants (`room_updated`).
/// 3. If the room belongs to a channel, broadcast the channel's new
///    participant list (`channel_updated`).
/// 4. Send `CallCanceled` to every connection of each user whose call was
///    canceled.
/// 5. Refresh contact status for the leaving user and those canceled users.
/// 6. Remove the user from the LiveKit room, and delete that room when the
///    database reports the room as deleted.
///
/// # Errors
///
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// Peer-send and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned only in the `Some` branch below; the `else` branch returns, so
    // all of these are initialized by the time they are read.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db()` temporary in this scrutinee may stay
    // alive for the whole if/else (pre-2024 edition temporary scoping), so the
    // notifications sent inside may happen while it is still held — confirm.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        // The leaving user's own contact status changes.
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Order matters: `livekit_room` is taken out of `left_room.room`
        // before the room itself is taken on the next lines. As a result the
        // `room` broadcast below carries a defaulted `livekit_room` field.
        // NOTE(review): confirm clients don't rely on that field in RoomUpdated.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool handle is dropped before `update_user_contacts`,
    // which obtains its own via `session.connection_pool()` (presumably a lock
    // guard — confirm).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?`, which skips the
    // LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        // `left_room.deleted` was captured above.
        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4401
4402async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4403    let left_channel_buffers = session
4404        .db()
4405        .await
4406        .leave_channel_buffers(session.connection_id)
4407        .await?;
4408
4409    for left_buffer in left_channel_buffers {
4410        channel_buffer_updated(
4411            session.connection_id,
4412            left_buffer.connections,
4413            &proto::UpdateChannelBufferCollaborators {
4414                channel_id: left_buffer.channel_id.to_proto(),
4415                collaborators: left_buffer.collaborators,
4416            },
4417            &session.peer,
4418        );
4419    }
4420
4421    Ok(())
4422}
4423
4424fn project_left(project: &db::LeftProject, session: &Session) {
4425    for connection_id in &project.connection_ids {
4426        if project.should_unshare {
4427            session
4428                .peer
4429                .send(
4430                    *connection_id,
4431                    proto::UnshareProject {
4432                        project_id: project.id.to_proto(),
4433                    },
4434                )
4435                .trace_err();
4436        } else {
4437            session
4438                .peer
4439                .send(
4440                    *connection_id,
4441                    proto::RemoveProjectCollaborator {
4442                        project_id: project.id.to_proto(),
4443                        peer_id: Some(session.connection_id.into()),
4444                    },
4445                )
4446                .trace_err();
4447        }
4448    }
4449}
4450
4451pub trait ResultExt {
4452    type Ok;
4453
4454    fn trace_err(self) -> Option<Self::Ok>;
4455}
4456
4457impl<T, E> ResultExt for Result<T, E>
4458where
4459    E: std::fmt::Debug,
4460{
4461    type Ok = T;
4462
4463    #[track_caller]
4464    fn trace_err(self) -> Option<T> {
4465        match self {
4466            Ok(value) => Some(value),
4467            Err(error) => {
4468                tracing::error!("{:?}", error);
4469                None
4470            }
4471        }
4472    }
4473}