rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// Timeout related to client reconnection.
// NOTE(review): presumably the grace period a disconnected client gets before its resources are
// released — the use sites are outside this view, confirm there.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources.
/// How long `Server::start` waits before clearing resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in the visible part of the file; going
// by their names they bound channel-message page size, message length and notification page
// size — confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the `Session` of the sending connection and
/// returns a boxed future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
/// Handle passed to request handlers for replying to one request of type `R`.
///
/// The `responded` flag is shared with the dispatcher built in `Server::add_request_handler`,
/// which uses it to detect handlers that return `Ok(())` without ever replying.
struct Response<R> {
    /// Peer through which the reply is sent.
    peer: Arc<Peer>,
    /// Receipt of the request this response answers.
    receipt: Receipt<R>,
    /// Set to `true` by `send`.
    responded: Arc<AtomicBool>,
}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` is acting as `user`; `Principal::update_span` records the admin's login as the
    /// `impersonator`.
    Impersonated { user: User, admin: User },
}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Context for one connection; a `Session` is passed to every message handler.
#[derive(Clone)]
struct Session {
    /// Identity the connection acts as (possibly an admin impersonating a user).
    principal: Principal,
    /// Id of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client, if one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client for outgoing requests.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied for the connection, if any.
    system_id: Option<String>,
    /// Executor kept with the session; the leading underscore marks it as not read directly.
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The RPC server: owns the peer, the connection pool and the registered message handlers.
pub struct Server {
    /// Id of this server instance; behind a mutex because test builds replace it in `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that manages the connections and message transport.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sends `true`, and `handle_connection` bails out early while the
    /// flag is set.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard over the connection pool, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Marker only: `Rc<()>` is not `Send`, so — as the field name says — this keeps the guard
    /// itself from being `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server: the peer plus the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Guard over the connection pool; serialized through `serialize_deref`, i.e. as the value
    /// the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 333            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 341            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 342            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 344            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 345            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 346            .add_message_handler(create_buffer_for_peer)
 347            .add_request_handler(update_buffer)
 348            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 353            .add_request_handler(get_users)
 354            .add_request_handler(fuzzy_search_users)
 355            .add_request_handler(request_contact)
 356            .add_request_handler(remove_contact)
 357            .add_request_handler(respond_to_contact_request)
 358            .add_message_handler(subscribe_to_channels)
 359            .add_request_handler(create_channel)
 360            .add_request_handler(delete_channel)
 361            .add_request_handler(invite_channel_member)
 362            .add_request_handler(remove_channel_member)
 363            .add_request_handler(set_channel_member_role)
 364            .add_request_handler(set_channel_visibility)
 365            .add_request_handler(rename_channel)
 366            .add_request_handler(join_channel_buffer)
 367            .add_request_handler(leave_channel_buffer)
 368            .add_message_handler(update_channel_buffer)
 369            .add_request_handler(rejoin_channel_buffers)
 370            .add_request_handler(get_channel_members)
 371            .add_request_handler(respond_to_channel_invite)
 372            .add_request_handler(join_channel)
 373            .add_request_handler(join_channel_chat)
 374            .add_message_handler(leave_channel_chat)
 375            .add_request_handler(send_channel_message)
 376            .add_request_handler(remove_channel_message)
 377            .add_request_handler(update_channel_message)
 378            .add_request_handler(get_channel_messages)
 379            .add_request_handler(get_channel_messages_by_id)
 380            .add_request_handler(get_notifications)
 381            .add_request_handler(mark_notification_as_read)
 382            .add_request_handler(move_channel)
 383            .add_request_handler(follow)
 384            .add_message_handler(unfollow)
 385            .add_message_handler(update_followers)
 386            .add_request_handler(get_private_user_info)
 387            .add_request_handler(get_llm_api_token)
 388            .add_request_handler(accept_terms_of_service)
 389            .add_message_handler(acknowledge_channel_message)
 390            .add_message_handler(acknowledge_buffer_version)
 391            .add_request_handler(get_supermaven_api_key)
 392            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 395            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 398            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 399            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 400            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 401            .add_message_handler(update_context)
 402            .add_request_handler({
 403                let app_state = app_state.clone();
 404                move |request, response, session| {
 405                    let app_state = app_state.clone();
 406                    async move {
 407                        count_language_model_tokens(request, response, session, &app_state.config)
 408                            .await
 409                    }
 410                }
 411            })
 412            .add_request_handler(get_cached_embeddings)
 413            .add_request_handler({
 414                let app_state = app_state.clone();
 415                move |request, response, session| {
 416                    compute_embeddings(
 417                        request,
 418                        response,
 419                        session,
 420                        app_state.config.openai_api_key.clone(),
 421                    )
 422                }
 423            });
 424
 425        Arc::new(server)
 426    }
 427
 428    pub async fn start(&self) -> Result<()> {
 429        let server_id = *self.id.lock();
 430        let app_state = self.app_state.clone();
 431        let peer = self.peer.clone();
 432        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 433        let pool = self.connection_pool.clone();
 434        let livekit_client = self.app_state.livekit_client.clone();
 435
 436        let span = info_span!("start server");
 437        self.app_state.executor.spawn_detached(
 438            async move {
 439                tracing::info!("waiting for cleanup timeout");
 440                timeout.await;
 441                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 442                if let Some((room_ids, channel_ids)) = app_state
 443                    .db
 444                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 445                    .await
 446                    .trace_err()
 447                {
 448                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 449                    tracing::info!(
 450                        stale_channel_buffer_count = channel_ids.len(),
 451                        "retrieved stale channel buffers"
 452                    );
 453
 454                    for channel_id in channel_ids {
 455                        if let Some(refreshed_channel_buffer) = app_state
 456                            .db
 457                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 458                            .await
 459                            .trace_err()
 460                        {
 461                            for connection_id in refreshed_channel_buffer.connection_ids {
 462                                peer.send(
 463                                    connection_id,
 464                                    proto::UpdateChannelBufferCollaborators {
 465                                        channel_id: channel_id.to_proto(),
 466                                        collaborators: refreshed_channel_buffer
 467                                            .collaborators
 468                                            .clone(),
 469                                    },
 470                                )
 471                                .trace_err();
 472                            }
 473                        }
 474                    }
 475
 476                    for room_id in room_ids {
 477                        let mut contacts_to_update = HashSet::default();
 478                        let mut canceled_calls_to_user_ids = Vec::new();
 479                        let mut livekit_room = String::new();
 480                        let mut delete_livekit_room = false;
 481
 482                        if let Some(mut refreshed_room) = app_state
 483                            .db
 484                            .clear_stale_room_participants(room_id, server_id)
 485                            .await
 486                            .trace_err()
 487                        {
 488                            tracing::info!(
 489                                room_id = room_id.0,
 490                                new_participant_count = refreshed_room.room.participants.len(),
 491                                "refreshed room"
 492                            );
 493                            room_updated(&refreshed_room.room, &peer);
 494                            if let Some(channel) = refreshed_room.channel.as_ref() {
 495                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 496                            }
 497                            contacts_to_update
 498                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 499                            contacts_to_update
 500                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 501                            canceled_calls_to_user_ids =
 502                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 503                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 504                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 505                        }
 506
 507                        {
 508                            let pool = pool.lock();
 509                            for canceled_user_id in canceled_calls_to_user_ids {
 510                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 511                                    peer.send(
 512                                        connection_id,
 513                                        proto::CallCanceled {
 514                                            room_id: room_id.to_proto(),
 515                                        },
 516                                    )
 517                                    .trace_err();
 518                                }
 519                            }
 520                        }
 521
 522                        for user_id in contacts_to_update {
 523                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 524                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 525                            if let Some((busy, contacts)) = busy.zip(contacts) {
 526                                let pool = pool.lock();
 527                                let updated_contact = contact_for_user(user_id, busy, &pool);
 528                                for contact in contacts {
 529                                    if let db::Contact::Accepted {
 530                                        user_id: contact_user_id,
 531                                        ..
 532                                    } = contact
 533                                    {
 534                                        for contact_conn_id in
 535                                            pool.user_connection_ids(contact_user_id)
 536                                        {
 537                                            peer.send(
 538                                                contact_conn_id,
 539                                                proto::UpdateContacts {
 540                                                    contacts: vec![updated_contact.clone()],
 541                                                    remove_contacts: Default::default(),
 542                                                    incoming_requests: Default::default(),
 543                                                    remove_incoming_requests: Default::default(),
 544                                                    outgoing_requests: Default::default(),
 545                                                    remove_outgoing_requests: Default::default(),
 546                                                },
 547                                            )
 548                                            .trace_err();
 549                                        }
 550                                    }
 551                                }
 552                            }
 553                        }
 554
 555                        if let Some(live_kit) = livekit_client.as_ref() {
 556                            if delete_livekit_room {
 557                                live_kit.delete_room(livekit_room).await.trace_err();
 558                            }
 559                        }
 560                    }
 561                }
 562
 563                app_state
 564                    .db
 565                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 566                    .await
 567                    .trace_err();
 568            }
 569            .instrument(span),
 570        );
 571        Ok(())
 572    }
 573
 574    pub fn teardown(&self) {
 575        self.peer.teardown();
 576        self.connection_pool.lock().reset();
 577        let _ = self.teardown.send(true);
 578    }
 579
 580    #[cfg(test)]
 581    pub fn reset(&self, id: ServerId) {
 582        self.teardown();
 583        *self.id.lock() = id;
 584        self.peer.reset(id.0 as u32);
 585        let _ = self.teardown.send(false);
 586    }
 587
 588    #[cfg(test)]
 589    pub fn id(&self) -> ServerId {
 590        *self.id.lock()
 591    }
 592
 593    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 594    where
 595        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 596        Fut: 'static + Send + Future<Output = Result<()>>,
 597        M: EnvelopedMessage,
 598    {
 599        let prev_handler = self.handlers.insert(
 600            TypeId::of::<M>(),
 601            Box::new(move |envelope, session| {
 602                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 603                let received_at = envelope.received_at;
 604                tracing::info!("message received");
 605                let start_time = Instant::now();
 606                let future = (handler)(*envelope, session);
 607                async move {
 608                    let result = future.await;
 609                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 610                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 611                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 612                    let payload_type = M::NAME;
 613
 614                    match result {
 615                        Err(error) => {
 616                            tracing::error!(
 617                                ?error,
 618                                total_duration_ms,
 619                                processing_duration_ms,
 620                                queue_duration_ms,
 621                                payload_type,
 622                                "error handling message"
 623                            )
 624                        }
 625                        Ok(()) => tracing::info!(
 626                            total_duration_ms,
 627                            processing_duration_ms,
 628                            queue_duration_ms,
 629                            "finished handling message"
 630                        ),
 631                    }
 632                }
 633                .boxed()
 634            }),
 635        );
 636        if prev_handler.is_some() {
 637            panic!("registered a handler for the same message twice");
 638        }
 639        self
 640    }
 641
 642    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 643    where
 644        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 645        Fut: 'static + Send + Future<Output = Result<()>>,
 646        M: EnvelopedMessage,
 647    {
 648        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 649        self
 650    }
 651
 652    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 653    where
 654        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 655        Fut: Send + Future<Output = Result<()>>,
 656        M: RequestMessage,
 657    {
 658        let handler = Arc::new(handler);
 659        self.add_handler(move |envelope, session| {
 660            let receipt = envelope.receipt();
 661            let handler = handler.clone();
 662            async move {
 663                let peer = session.peer.clone();
 664                let responded = Arc::new(AtomicBool::default());
 665                let response = Response {
 666                    peer: peer.clone(),
 667                    responded: responded.clone(),
 668                    receipt,
 669                };
 670                match (handler)(envelope.payload, response, session).await {
 671                    Ok(()) => {
 672                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 673                            Ok(())
 674                        } else {
 675                            Err(anyhow!("handler did not send a response"))?
 676                        }
 677                    }
 678                    Err(error) => {
 679                        let proto_err = match &error {
 680                            Error::Internal(err) => err.to_proto(),
 681                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 682                        };
 683                        peer.respond_with_error(receipt, proto_err)?;
 684                        Err(error)
 685                    }
 686                }
 687            }
 688        })
 689    }
 690
    /// Builds the future that drives a single client connection.
    ///
    /// The returned future:
    /// 1. returns immediately if the server's teardown flag is already set;
    /// 2. registers `connection` with the peer and records the new connection id on the
    ///    tracing span;
    /// 3. creates the per-connection HTTP client, optional Supermaven client and [`Session`];
    /// 4. sends the initial client update (see `send_initial_client_update`);
    /// 5. loops, dispatching incoming messages to the handlers registered in
    ///    `self.handlers`, until I/O finishes, the incoming stream ends, or teardown is
    ///    signalled;
    /// 6. unless the loop ended because of teardown, awaits `connection_lost` to sign the
    ///    session out.
    ///
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`.
    ///
    /// NOTE(review): the early `return`s after `add_connection` (HTTP client creation failure,
    /// initial client update failure) skip `connection_lost` — confirm the connection is
    /// cleaned up elsewhere in those cases.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` here are filled in later (by `principal.update_span`, the
        // country-code record below, and the connection-id record inside the future).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly, before the returned future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure gives the peer a way to sleep via this connection's executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only constructed when an admin API key is present in the config.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            // Fused and pinned so the same I/O future can be polled on every loop iteration.
            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is acquired before the next message is awaited and released when its
            // handler completes, so at most 256 permits are outstanding for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown first, then
                // I/O completion, then in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    // Teardown exits the future directly, skipping the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drives pending foreground handlers; their output is discarded.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before it is moved into `instrument` below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Release the concurrency permit once the handler is done.
                                    drop(permit);
                                }.instrument(span);
                                // Background messages are detached onto the executor;
                                // foreground ones are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream yielded `None`.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 840
 841    async fn send_initial_client_update(
 842        &self,
 843        connection_id: ConnectionId,
 844        principal: &Principal,
 845        zed_version: ZedVersion,
 846        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 847        session: &Session,
 848    ) -> Result<()> {
 849        self.peer.send(
 850            connection_id,
 851            proto::Hello {
 852                peer_id: Some(connection_id.into()),
 853            },
 854        )?;
 855        tracing::info!("sent hello message");
 856        if let Some(send_connection_id) = send_connection_id.take() {
 857            let _ = send_connection_id.send(connection_id);
 858        }
 859
 860        match principal {
 861            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 862                if !user.connected_once {
 863                    self.peer.send(connection_id, proto::ShowContacts {})?;
 864                    self.app_state
 865                        .db
 866                        .set_user_connected_once(user.id, true)
 867                        .await?;
 868                }
 869
 870                update_user_plan(user.id, session).await?;
 871
 872                let contacts = self.app_state.db.get_contacts(user.id).await?;
 873
 874                {
 875                    let mut pool = self.connection_pool.lock();
 876                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 877                    self.peer.send(
 878                        connection_id,
 879                        build_initial_contacts_update(contacts, &pool),
 880                    )?;
 881                }
 882
 883                if should_auto_subscribe_to_channels(zed_version) {
 884                    subscribe_user_to_channels(user.id, session).await?;
 885                }
 886
 887                if let Some(incoming_call) =
 888                    self.app_state.db.incoming_call_for_user(user.id).await?
 889                {
 890                    self.peer.send(connection_id, incoming_call)?;
 891                }
 892
 893                update_user_contacts(user.id, session).await?;
 894            }
 895        }
 896
 897        Ok(())
 898    }
 899
 900    pub async fn invite_code_redeemed(
 901        self: &Arc<Self>,
 902        inviter_id: UserId,
 903        invitee_id: UserId,
 904    ) -> Result<()> {
 905        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 906            if let Some(code) = &user.invite_code {
 907                let pool = self.connection_pool.lock();
 908                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 909                for connection_id in pool.user_connection_ids(inviter_id) {
 910                    self.peer.send(
 911                        connection_id,
 912                        proto::UpdateContacts {
 913                            contacts: vec![invitee_contact.clone()],
 914                            ..Default::default()
 915                        },
 916                    )?;
 917                    self.peer.send(
 918                        connection_id,
 919                        proto::UpdateInviteInfo {
 920                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 921                            count: user.invite_count as u32,
 922                        },
 923                    )?;
 924                }
 925            }
 926        }
 927        Ok(())
 928    }
 929
 930    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 931        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 932            if let Some(invite_code) = &user.invite_code {
 933                let pool = self.connection_pool.lock();
 934                for connection_id in pool.user_connection_ids(user_id) {
 935                    self.peer.send(
 936                        connection_id,
 937                        proto::UpdateInviteInfo {
 938                            url: format!(
 939                                "{}{}",
 940                                self.app_state.config.invite_link_prefix, invite_code
 941                            ),
 942                            count: user.invite_count as u32,
 943                        },
 944                    )?;
 945                }
 946            }
 947        }
 948        Ok(())
 949    }
 950
 951    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 952        let pool = self.connection_pool.lock();
 953        for connection_id in pool.user_connection_ids(user_id) {
 954            self.peer
 955                .send(connection_id, proto::RefreshLlmToken {})
 956                .trace_err();
 957        }
 958    }
 959
 960    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 961        ServerSnapshot {
 962            connection_pool: ConnectionPoolGuard {
 963                guard: self.connection_pool.lock(),
 964                _not_send: PhantomData,
 965            },
 966            peer: &self.peer,
 967        }
 968    }
 969}
 970
 971impl<'a> Deref for ConnectionPoolGuard<'a> {
 972    type Target = ConnectionPool;
 973
 974    fn deref(&self) -> &Self::Target {
 975        &self.guard
 976    }
 977}
 978
 979impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 980    fn deref_mut(&mut self) -> &mut Self::Target {
 981        &mut self.guard
 982    }
 983}
 984
 985impl<'a> Drop for ConnectionPoolGuard<'a> {
 986    fn drop(&mut self) {
 987        #[cfg(test)]
 988        self.check_invariants();
 989    }
 990}
 991
 992fn broadcast<F>(
 993    sender_id: Option<ConnectionId>,
 994    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 995    mut f: F,
 996) where
 997    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 998{
 999    for receiver_id in receiver_ids {
1000        if Some(receiver_id) != sender_id {
1001            if let Err(error) = f(receiver_id) {
1002                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1003            }
1004        }
1005    }
1006}
1007
1008pub struct ProtocolVersion(u32);
1009
1010impl Header for ProtocolVersion {
1011    fn name() -> &'static HeaderName {
1012        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1013        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1014    }
1015
1016    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1017    where
1018        Self: Sized,
1019        I: Iterator<Item = &'i axum::http::HeaderValue>,
1020    {
1021        let version = values
1022            .next()
1023            .ok_or_else(axum::headers::Error::invalid)?
1024            .to_str()
1025            .map_err(|_| axum::headers::Error::invalid())?
1026            .parse()
1027            .map_err(|_| axum::headers::Error::invalid())?;
1028        Ok(Self(version))
1029    }
1030
1031    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1032        values.extend([self.0.to_string().parse().unwrap()]);
1033    }
1034}
1035
1036pub struct AppVersionHeader(SemanticVersion);
1037impl Header for AppVersionHeader {
1038    fn name() -> &'static HeaderName {
1039        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1040        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1041    }
1042
1043    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1044    where
1045        Self: Sized,
1046        I: Iterator<Item = &'i axum::http::HeaderValue>,
1047    {
1048        let version = values
1049            .next()
1050            .ok_or_else(axum::headers::Error::invalid)?
1051            .to_str()
1052            .map_err(|_| axum::headers::Error::invalid())?
1053            .parse()
1054            .map_err(|_| axum::headers::Error::invalid())?;
1055        Ok(Self(version))
1056    }
1057
1058    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1059        values.extend([self.0.to_string().parse().unwrap()]);
1060    }
1061}
1062
1063pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1064    Router::new()
1065        .route("/rpc", get(handle_websocket_request))
1066        .layer(
1067            ServiceBuilder::new()
1068                .layer(Extension(server.app_state.clone()))
1069                .layer(middleware::from_fn(auth::validate_header)),
1070        )
1071        .route("/metrics", get(handle_metrics))
1072        .layer(Extension(server))
1073}
1074
1075#[allow(clippy::too_many_arguments)]
1076pub async fn handle_websocket_request(
1077    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1078    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1079    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1080    Extension(server): Extension<Arc<Server>>,
1081    Extension(principal): Extension<Principal>,
1082    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1083    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1084    ws: WebSocketUpgrade,
1085) -> axum::response::Response {
1086    if protocol_version != rpc::PROTOCOL_VERSION {
1087        return (
1088            StatusCode::UPGRADE_REQUIRED,
1089            "client must be upgraded".to_string(),
1090        )
1091            .into_response();
1092    }
1093
1094    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1095        return (
1096            StatusCode::UPGRADE_REQUIRED,
1097            "no version header found".to_string(),
1098        )
1099            .into_response();
1100    };
1101
1102    if !version.can_collaborate() {
1103        return (
1104            StatusCode::UPGRADE_REQUIRED,
1105            "client must be upgraded".to_string(),
1106        )
1107            .into_response();
1108    }
1109
1110    let socket_address = socket_address.to_string();
1111    ws.on_upgrade(move |socket| {
1112        let socket = socket
1113            .map_ok(to_tungstenite_message)
1114            .err_into()
1115            .with(|message| async move { to_axum_message(message) });
1116        let connection = Connection::new(Box::pin(socket));
1117        async move {
1118            server
1119                .handle_connection(
1120                    connection,
1121                    socket_address,
1122                    principal,
1123                    version,
1124                    country_code_header.map(|header| header.to_string()),
1125                    system_id_header.map(|header| header.to_string()),
1126                    None,
1127                    Executor::Production,
1128                )
1129                .await;
1130        }
1131    })
1132}
1133
1134pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1135    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1136    let connections_metric = CONNECTIONS_METRIC
1137        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1138
1139    let connections = server
1140        .connection_pool
1141        .lock()
1142        .connections()
1143        .filter(|connection| !connection.admin)
1144        .count();
1145    connections_metric.set(connections as _);
1146
1147    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1148    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1149        register_int_gauge!(
1150            "shared_projects",
1151            "number of open projects with one or more guests"
1152        )
1153        .unwrap()
1154    });
1155
1156    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1157    shared_projects_metric.set(shared_projects as _);
1158
1159    let encoder = prometheus::TextEncoder::new();
1160    let metric_families = prometheus::gather();
1161    let encoded_metrics = encoder
1162        .encode_to_string(&metric_families)
1163        .map_err(|err| anyhow!("{}", err))?;
1164    Ok(encoded_metrics)
1165}
1166
1167#[instrument(err, skip(executor))]
1168async fn connection_lost(
1169    session: Session,
1170    mut teardown: watch::Receiver<bool>,
1171    executor: Executor,
1172) -> Result<()> {
1173    session.peer.disconnect(session.connection_id);
1174    session
1175        .connection_pool()
1176        .await
1177        .remove_connection(session.connection_id)?;
1178
1179    session
1180        .db()
1181        .await
1182        .connection_lost(session.connection_id)
1183        .await
1184        .trace_err();
1185
1186    futures::select_biased! {
1187        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1188
1189            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1190            leave_room_for_session(&session, session.connection_id).await.trace_err();
1191            leave_channel_buffers_for_session(&session)
1192                .await
1193                .trace_err();
1194
1195            if !session
1196                .connection_pool()
1197                .await
1198                .is_user_online(session.user_id())
1199            {
1200                let db = session.db().await;
1201                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1202                    room_updated(&room, &session.peer);
1203                }
1204            }
1205
1206            update_user_contacts(session.user_id(), &session).await?;
1207        },
1208        _ = teardown.changed().fuse() => {}
1209    }
1210
1211    Ok(())
1212}
1213
1214/// Acknowledges a ping from a client, used to keep the connection alive.
1215async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1216    response.send(proto::Ack {})?;
1217    Ok(())
1218}
1219
1220/// Creates a new room for calling (outside of channels)
1221async fn create_room(
1222    _request: proto::CreateRoom,
1223    response: Response<proto::CreateRoom>,
1224    session: Session,
1225) -> Result<()> {
1226    let livekit_room = nanoid::nanoid!(30);
1227
1228    let live_kit_connection_info = util::maybe!(async {
1229        let live_kit = session.app_state.livekit_client.as_ref();
1230        let live_kit = live_kit?;
1231        let user_id = session.user_id().to_string();
1232
1233        let token = live_kit
1234            .room_token(&livekit_room, &user_id.to_string())
1235            .trace_err()?;
1236
1237        Some(proto::LiveKitConnectionInfo {
1238            server_url: live_kit.url().into(),
1239            token,
1240            can_publish: true,
1241        })
1242    })
1243    .await;
1244
1245    let room = session
1246        .db()
1247        .await
1248        .create_room(session.user_id(), session.connection_id, &livekit_room)
1249        .await?;
1250
1251    response.send(proto::CreateRoomResponse {
1252        room: Some(room.clone()),
1253        live_kit_connection_info,
1254    })?;
1255
1256    update_user_contacts(session.user_id(), &session).await?;
1257    Ok(())
1258}
1259
1260/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1261async fn join_room(
1262    request: proto::JoinRoom,
1263    response: Response<proto::JoinRoom>,
1264    session: Session,
1265) -> Result<()> {
1266    let room_id = RoomId::from_proto(request.id);
1267
1268    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1269
1270    if let Some(channel_id) = channel_id {
1271        return join_channel_internal(channel_id, Box::new(response), session).await;
1272    }
1273
1274    let joined_room = {
1275        let room = session
1276            .db()
1277            .await
1278            .join_room(room_id, session.user_id(), session.connection_id)
1279            .await?;
1280        room_updated(&room.room, &session.peer);
1281        room.into_inner()
1282    };
1283
1284    for connection_id in session
1285        .connection_pool()
1286        .await
1287        .user_connection_ids(session.user_id())
1288    {
1289        session
1290            .peer
1291            .send(
1292                connection_id,
1293                proto::CallCanceled {
1294                    room_id: room_id.to_proto(),
1295                },
1296            )
1297            .trace_err();
1298    }
1299
1300    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1301    {
1302        live_kit
1303            .room_token(
1304                &joined_room.room.livekit_room,
1305                &session.user_id().to_string(),
1306            )
1307            .trace_err()
1308            .map(|token| proto::LiveKitConnectionInfo {
1309                server_url: live_kit.url().into(),
1310                token,
1311                can_publish: true,
1312            })
1313    } else {
1314        None
1315    };
1316
1317    response.send(proto::JoinRoomResponse {
1318        room: Some(joined_room.room),
1319        channel_id: None,
1320        live_kit_connection_info,
1321    })?;
1322
1323    update_user_contacts(session.user_id(), &session).await?;
1324    Ok(())
1325}
1326
1327/// Rejoin room is used to reconnect to a room after connection errors.
1328async fn rejoin_room(
1329    request: proto::RejoinRoom,
1330    response: Response<proto::RejoinRoom>,
1331    session: Session,
1332) -> Result<()> {
1333    let room;
1334    let channel;
1335    {
1336        let mut rejoined_room = session
1337            .db()
1338            .await
1339            .rejoin_room(request, session.user_id(), session.connection_id)
1340            .await?;
1341
1342        response.send(proto::RejoinRoomResponse {
1343            room: Some(rejoined_room.room.clone()),
1344            reshared_projects: rejoined_room
1345                .reshared_projects
1346                .iter()
1347                .map(|project| proto::ResharedProject {
1348                    id: project.id.to_proto(),
1349                    collaborators: project
1350                        .collaborators
1351                        .iter()
1352                        .map(|collaborator| collaborator.to_proto())
1353                        .collect(),
1354                })
1355                .collect(),
1356            rejoined_projects: rejoined_room
1357                .rejoined_projects
1358                .iter()
1359                .map(|rejoined_project| rejoined_project.to_proto())
1360                .collect(),
1361        })?;
1362        room_updated(&rejoined_room.room, &session.peer);
1363
1364        for project in &rejoined_room.reshared_projects {
1365            for collaborator in &project.collaborators {
1366                session
1367                    .peer
1368                    .send(
1369                        collaborator.connection_id,
1370                        proto::UpdateProjectCollaborator {
1371                            project_id: project.id.to_proto(),
1372                            old_peer_id: Some(project.old_connection_id.into()),
1373                            new_peer_id: Some(session.connection_id.into()),
1374                        },
1375                    )
1376                    .trace_err();
1377            }
1378
1379            broadcast(
1380                Some(session.connection_id),
1381                project
1382                    .collaborators
1383                    .iter()
1384                    .map(|collaborator| collaborator.connection_id),
1385                |connection_id| {
1386                    session.peer.forward_send(
1387                        session.connection_id,
1388                        connection_id,
1389                        proto::UpdateProject {
1390                            project_id: project.id.to_proto(),
1391                            worktrees: project.worktrees.clone(),
1392                        },
1393                    )
1394                },
1395            );
1396        }
1397
1398        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1399
1400        let rejoined_room = rejoined_room.into_inner();
1401
1402        room = rejoined_room.room;
1403        channel = rejoined_room.channel;
1404    }
1405
1406    if let Some(channel) = channel {
1407        channel_updated(
1408            &channel,
1409            &room,
1410            &session.peer,
1411            &*session.connection_pool().await,
1412        );
1413    }
1414
1415    update_user_contacts(session.user_id(), &session).await?;
1416    Ok(())
1417}
1418
1419fn notify_rejoined_projects(
1420    rejoined_projects: &mut Vec<RejoinedProject>,
1421    session: &Session,
1422) -> Result<()> {
1423    for project in rejoined_projects.iter() {
1424        for collaborator in &project.collaborators {
1425            session
1426                .peer
1427                .send(
1428                    collaborator.connection_id,
1429                    proto::UpdateProjectCollaborator {
1430                        project_id: project.id.to_proto(),
1431                        old_peer_id: Some(project.old_connection_id.into()),
1432                        new_peer_id: Some(session.connection_id.into()),
1433                    },
1434                )
1435                .trace_err();
1436        }
1437    }
1438
1439    for project in rejoined_projects {
1440        for worktree in mem::take(&mut project.worktrees) {
1441            // Stream this worktree's entries.
1442            let message = proto::UpdateWorktree {
1443                project_id: project.id.to_proto(),
1444                worktree_id: worktree.id,
1445                abs_path: worktree.abs_path.clone(),
1446                root_name: worktree.root_name,
1447                updated_entries: worktree.updated_entries,
1448                removed_entries: worktree.removed_entries,
1449                scan_id: worktree.scan_id,
1450                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1451                updated_repositories: worktree.updated_repositories,
1452                removed_repositories: worktree.removed_repositories,
1453            };
1454            for update in proto::split_worktree_update(message) {
1455                session.peer.send(session.connection_id, update.clone())?;
1456            }
1457
1458            // Stream this worktree's diagnostics.
1459            for summary in worktree.diagnostic_summaries {
1460                session.peer.send(
1461                    session.connection_id,
1462                    proto::UpdateDiagnosticSummary {
1463                        project_id: project.id.to_proto(),
1464                        worktree_id: worktree.id,
1465                        summary: Some(summary),
1466                    },
1467                )?;
1468            }
1469
1470            for settings_file in worktree.settings_files {
1471                session.peer.send(
1472                    session.connection_id,
1473                    proto::UpdateWorktreeSettings {
1474                        project_id: project.id.to_proto(),
1475                        worktree_id: worktree.id,
1476                        path: settings_file.path,
1477                        content: Some(settings_file.content),
1478                        kind: Some(settings_file.kind.to_proto().into()),
1479                    },
1480                )?;
1481            }
1482        }
1483
1484        for language_server in &project.language_servers {
1485            session.peer.send(
1486                session.connection_id,
1487                proto::UpdateLanguageServer {
1488                    project_id: project.id.to_proto(),
1489                    language_server_id: language_server.id,
1490                    variant: Some(
1491                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1492                            proto::LspDiskBasedDiagnosticsUpdated {},
1493                        ),
1494                    ),
1495                },
1496            )?;
1497        }
1498    }
1499    Ok(())
1500}
1501
1502/// leave room disconnects from the room.
1503async fn leave_room(
1504    _: proto::LeaveRoom,
1505    response: Response<proto::LeaveRoom>,
1506    session: Session,
1507) -> Result<()> {
1508    leave_room_for_session(&session, session.connection_id).await?;
1509    response.send(proto::Ack {})?;
1510    Ok(())
1511}
1512
1513/// Updates the permissions of someone else in the room.
1514async fn set_room_participant_role(
1515    request: proto::SetRoomParticipantRole,
1516    response: Response<proto::SetRoomParticipantRole>,
1517    session: Session,
1518) -> Result<()> {
1519    let user_id = UserId::from_proto(request.user_id);
1520    let role = ChannelRole::from(request.role());
1521
1522    let (livekit_room, can_publish) = {
1523        let room = session
1524            .db()
1525            .await
1526            .set_room_participant_role(
1527                session.user_id(),
1528                RoomId::from_proto(request.room_id),
1529                user_id,
1530                role,
1531            )
1532            .await?;
1533
1534        let livekit_room = room.livekit_room.clone();
1535        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1536        room_updated(&room, &session.peer);
1537        (livekit_room, can_publish)
1538    };
1539
1540    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1541        live_kit
1542            .update_participant(
1543                livekit_room.clone(),
1544                request.user_id.to_string(),
1545                livekit_server::proto::ParticipantPermission {
1546                    can_subscribe: true,
1547                    can_publish,
1548                    can_publish_data: can_publish,
1549                    hidden: false,
1550                    recorder: false,
1551                },
1552            )
1553            .await
1554            .trace_err();
1555    }
1556
1557    response.send(proto::Ack {})?;
1558    Ok(())
1559}
1560
1561/// Call someone else into the current room
1562async fn call(
1563    request: proto::Call,
1564    response: Response<proto::Call>,
1565    session: Session,
1566) -> Result<()> {
1567    let room_id = RoomId::from_proto(request.room_id);
1568    let calling_user_id = session.user_id();
1569    let calling_connection_id = session.connection_id;
1570    let called_user_id = UserId::from_proto(request.called_user_id);
1571    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1572    if !session
1573        .db()
1574        .await
1575        .has_contact(calling_user_id, called_user_id)
1576        .await?
1577    {
1578        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1579    }
1580
1581    let incoming_call = {
1582        let (room, incoming_call) = &mut *session
1583            .db()
1584            .await
1585            .call(
1586                room_id,
1587                calling_user_id,
1588                calling_connection_id,
1589                called_user_id,
1590                initial_project_id,
1591            )
1592            .await?;
1593        room_updated(room, &session.peer);
1594        mem::take(incoming_call)
1595    };
1596    update_user_contacts(called_user_id, &session).await?;
1597
1598    let mut calls = session
1599        .connection_pool()
1600        .await
1601        .user_connection_ids(called_user_id)
1602        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1603        .collect::<FuturesUnordered<_>>();
1604
1605    while let Some(call_response) = calls.next().await {
1606        match call_response.as_ref() {
1607            Ok(_) => {
1608                response.send(proto::Ack {})?;
1609                return Ok(());
1610            }
1611            Err(_) => {
1612                call_response.trace_err();
1613            }
1614        }
1615    }
1616
1617    {
1618        let room = session
1619            .db()
1620            .await
1621            .call_failed(room_id, called_user_id)
1622            .await?;
1623        room_updated(&room, &session.peer);
1624    }
1625    update_user_contacts(called_user_id, &session).await?;
1626
1627    Err(anyhow!("failed to ring user"))?
1628}
1629
1630/// Cancel an outgoing call.
1631async fn cancel_call(
1632    request: proto::CancelCall,
1633    response: Response<proto::CancelCall>,
1634    session: Session,
1635) -> Result<()> {
1636    let called_user_id = UserId::from_proto(request.called_user_id);
1637    let room_id = RoomId::from_proto(request.room_id);
1638    {
1639        let room = session
1640            .db()
1641            .await
1642            .cancel_call(room_id, session.connection_id, called_user_id)
1643            .await?;
1644        room_updated(&room, &session.peer);
1645    }
1646
1647    for connection_id in session
1648        .connection_pool()
1649        .await
1650        .user_connection_ids(called_user_id)
1651    {
1652        session
1653            .peer
1654            .send(
1655                connection_id,
1656                proto::CallCanceled {
1657                    room_id: room_id.to_proto(),
1658                },
1659            )
1660            .trace_err();
1661    }
1662    response.send(proto::Ack {})?;
1663
1664    update_user_contacts(called_user_id, &session).await?;
1665    Ok(())
1666}
1667
1668/// Decline an incoming call.
1669async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1670    let room_id = RoomId::from_proto(message.room_id);
1671    {
1672        let room = session
1673            .db()
1674            .await
1675            .decline_call(Some(room_id), session.user_id())
1676            .await?
1677            .ok_or_else(|| anyhow!("failed to decline call"))?;
1678        room_updated(&room, &session.peer);
1679    }
1680
1681    for connection_id in session
1682        .connection_pool()
1683        .await
1684        .user_connection_ids(session.user_id())
1685    {
1686        session
1687            .peer
1688            .send(
1689                connection_id,
1690                proto::CallCanceled {
1691                    room_id: room_id.to_proto(),
1692                },
1693            )
1694            .trace_err();
1695    }
1696    update_user_contacts(session.user_id(), &session).await?;
1697    Ok(())
1698}
1699
1700/// Updates other participants in the room with your current location.
1701async fn update_participant_location(
1702    request: proto::UpdateParticipantLocation,
1703    response: Response<proto::UpdateParticipantLocation>,
1704    session: Session,
1705) -> Result<()> {
1706    let room_id = RoomId::from_proto(request.room_id);
1707    let location = request
1708        .location
1709        .ok_or_else(|| anyhow!("invalid location"))?;
1710
1711    let db = session.db().await;
1712    let room = db
1713        .update_room_participant_location(room_id, session.connection_id, location)
1714        .await?;
1715
1716    room_updated(&room, &session.peer);
1717    response.send(proto::Ack {})?;
1718    Ok(())
1719}
1720
1721/// Share a project into the room.
1722async fn share_project(
1723    request: proto::ShareProject,
1724    response: Response<proto::ShareProject>,
1725    session: Session,
1726) -> Result<()> {
1727    let (project_id, room) = &*session
1728        .db()
1729        .await
1730        .share_project(
1731            RoomId::from_proto(request.room_id),
1732            session.connection_id,
1733            &request.worktrees,
1734            request.is_ssh_project,
1735        )
1736        .await?;
1737    response.send(proto::ShareProjectResponse {
1738        project_id: project_id.to_proto(),
1739    })?;
1740    room_updated(room, &session.peer);
1741
1742    Ok(())
1743}
1744
1745/// Unshare a project from the room.
1746async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1747    let project_id = ProjectId::from_proto(message.project_id);
1748    unshare_project_internal(project_id, session.connection_id, &session).await
1749}
1750
1751async fn unshare_project_internal(
1752    project_id: ProjectId,
1753    connection_id: ConnectionId,
1754    session: &Session,
1755) -> Result<()> {
1756    let delete = {
1757        let room_guard = session
1758            .db()
1759            .await
1760            .unshare_project(project_id, connection_id)
1761            .await?;
1762
1763        let (delete, room, guest_connection_ids) = &*room_guard;
1764
1765        let message = proto::UnshareProject {
1766            project_id: project_id.to_proto(),
1767        };
1768
1769        broadcast(
1770            Some(connection_id),
1771            guest_connection_ids.iter().copied(),
1772            |conn_id| session.peer.send(conn_id, message.clone()),
1773        );
1774        if let Some(room) = room {
1775            room_updated(room, &session.peer);
1776        }
1777
1778        *delete
1779    };
1780
1781    if delete {
1782        let db = session.db().await;
1783        db.delete_project(project_id).await?;
1784    }
1785
1786    Ok(())
1787}
1788
1789/// Join someone elses shared project.
1790async fn join_project(
1791    request: proto::JoinProject,
1792    response: Response<proto::JoinProject>,
1793    session: Session,
1794) -> Result<()> {
1795    let project_id = ProjectId::from_proto(request.project_id);
1796
1797    tracing::info!(%project_id, "join project");
1798
1799    let db = session.db().await;
1800    let (project, replica_id) = &mut *db
1801        .join_project(project_id, session.connection_id, session.user_id())
1802        .await?;
1803    drop(db);
1804    tracing::info!(%project_id, "join remote project");
1805    join_project_internal(response, session, project, replica_id)
1806}
1807
/// Abstraction over whatever is used to deliver a `proto::JoinProjectResponse`.
///
/// `join_project_internal` accepts any `impl JoinProjectInternalResponse`, so
/// it is not tied to one concrete response type.
trait JoinProjectInternalResponse {
    /// Consumes the responder and delivers `result` through it.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// For a `JoinProject` RPC response, sending simply forwards to
// `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1816
1817fn join_project_internal(
1818    response: impl JoinProjectInternalResponse,
1819    session: Session,
1820    project: &mut Project,
1821    replica_id: &ReplicaId,
1822) -> Result<()> {
1823    let collaborators = project
1824        .collaborators
1825        .iter()
1826        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1827        .map(|collaborator| collaborator.to_proto())
1828        .collect::<Vec<_>>();
1829    let project_id = project.id;
1830    let guest_user_id = session.user_id();
1831
1832    let worktrees = project
1833        .worktrees
1834        .iter()
1835        .map(|(id, worktree)| proto::WorktreeMetadata {
1836            id: *id,
1837            root_name: worktree.root_name.clone(),
1838            visible: worktree.visible,
1839            abs_path: worktree.abs_path.clone(),
1840        })
1841        .collect::<Vec<_>>();
1842
1843    let add_project_collaborator = proto::AddProjectCollaborator {
1844        project_id: project_id.to_proto(),
1845        collaborator: Some(proto::Collaborator {
1846            peer_id: Some(session.connection_id.into()),
1847            replica_id: replica_id.0 as u32,
1848            user_id: guest_user_id.to_proto(),
1849            is_host: false,
1850        }),
1851    };
1852
1853    for collaborator in &collaborators {
1854        session
1855            .peer
1856            .send(
1857                collaborator.peer_id.unwrap().into(),
1858                add_project_collaborator.clone(),
1859            )
1860            .trace_err();
1861    }
1862
1863    // First, we send the metadata associated with each worktree.
1864    response.send(proto::JoinProjectResponse {
1865        project_id: project.id.0 as u64,
1866        worktrees: worktrees.clone(),
1867        replica_id: replica_id.0 as u32,
1868        collaborators: collaborators.clone(),
1869        language_servers: project.language_servers.clone(),
1870        role: project.role.into(),
1871    })?;
1872
1873    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1874        // Stream this worktree's entries.
1875        let message = proto::UpdateWorktree {
1876            project_id: project_id.to_proto(),
1877            worktree_id,
1878            abs_path: worktree.abs_path.clone(),
1879            root_name: worktree.root_name,
1880            updated_entries: worktree.entries,
1881            removed_entries: Default::default(),
1882            scan_id: worktree.scan_id,
1883            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1884            updated_repositories: worktree.repository_entries.into_values().collect(),
1885            removed_repositories: Default::default(),
1886        };
1887        for update in proto::split_worktree_update(message) {
1888            session.peer.send(session.connection_id, update.clone())?;
1889        }
1890
1891        // Stream this worktree's diagnostics.
1892        for summary in worktree.diagnostic_summaries {
1893            session.peer.send(
1894                session.connection_id,
1895                proto::UpdateDiagnosticSummary {
1896                    project_id: project_id.to_proto(),
1897                    worktree_id: worktree.id,
1898                    summary: Some(summary),
1899                },
1900            )?;
1901        }
1902
1903        for settings_file in worktree.settings_files {
1904            session.peer.send(
1905                session.connection_id,
1906                proto::UpdateWorktreeSettings {
1907                    project_id: project_id.to_proto(),
1908                    worktree_id: worktree.id,
1909                    path: settings_file.path,
1910                    content: Some(settings_file.content),
1911                    kind: Some(settings_file.kind.to_proto() as i32),
1912                },
1913            )?;
1914        }
1915    }
1916
1917    for language_server in &project.language_servers {
1918        session.peer.send(
1919            session.connection_id,
1920            proto::UpdateLanguageServer {
1921                project_id: project_id.to_proto(),
1922                language_server_id: language_server.id,
1923                variant: Some(
1924                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1925                        proto::LspDiskBasedDiagnosticsUpdated {},
1926                    ),
1927                ),
1928            },
1929        )?;
1930    }
1931
1932    Ok(())
1933}
1934
1935/// Leave someone elses shared project.
1936async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1937    let sender_id = session.connection_id;
1938    let project_id = ProjectId::from_proto(request.project_id);
1939    let db = session.db().await;
1940
1941    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1942    tracing::info!(
1943        %project_id,
1944        "leave project"
1945    );
1946
1947    project_left(project, &session);
1948    if let Some(room) = room {
1949        room_updated(room, &session.peer);
1950    }
1951
1952    Ok(())
1953}
1954
1955/// Updates other participants with changes to the project
1956async fn update_project(
1957    request: proto::UpdateProject,
1958    response: Response<proto::UpdateProject>,
1959    session: Session,
1960) -> Result<()> {
1961    let project_id = ProjectId::from_proto(request.project_id);
1962    let (room, guest_connection_ids) = &*session
1963        .db()
1964        .await
1965        .update_project(project_id, session.connection_id, &request.worktrees)
1966        .await?;
1967    broadcast(
1968        Some(session.connection_id),
1969        guest_connection_ids.iter().copied(),
1970        |connection_id| {
1971            session
1972                .peer
1973                .forward_send(session.connection_id, connection_id, request.clone())
1974        },
1975    );
1976    if let Some(room) = room {
1977        room_updated(room, &session.peer);
1978    }
1979    response.send(proto::Ack {})?;
1980
1981    Ok(())
1982}
1983
1984/// Updates other participants with changes to the worktree
1985async fn update_worktree(
1986    request: proto::UpdateWorktree,
1987    response: Response<proto::UpdateWorktree>,
1988    session: Session,
1989) -> Result<()> {
1990    let guest_connection_ids = session
1991        .db()
1992        .await
1993        .update_worktree(&request, session.connection_id)
1994        .await?;
1995
1996    broadcast(
1997        Some(session.connection_id),
1998        guest_connection_ids.iter().copied(),
1999        |connection_id| {
2000            session
2001                .peer
2002                .forward_send(session.connection_id, connection_id, request.clone())
2003        },
2004    );
2005    response.send(proto::Ack {})?;
2006    Ok(())
2007}
2008
2009/// Updates other participants with changes to the diagnostics
2010async fn update_diagnostic_summary(
2011    message: proto::UpdateDiagnosticSummary,
2012    session: Session,
2013) -> Result<()> {
2014    let guest_connection_ids = session
2015        .db()
2016        .await
2017        .update_diagnostic_summary(&message, session.connection_id)
2018        .await?;
2019
2020    broadcast(
2021        Some(session.connection_id),
2022        guest_connection_ids.iter().copied(),
2023        |connection_id| {
2024            session
2025                .peer
2026                .forward_send(session.connection_id, connection_id, message.clone())
2027        },
2028    );
2029
2030    Ok(())
2031}
2032
2033/// Updates other participants with changes to the worktree settings
2034async fn update_worktree_settings(
2035    message: proto::UpdateWorktreeSettings,
2036    session: Session,
2037) -> Result<()> {
2038    let guest_connection_ids = session
2039        .db()
2040        .await
2041        .update_worktree_settings(&message, session.connection_id)
2042        .await?;
2043
2044    broadcast(
2045        Some(session.connection_id),
2046        guest_connection_ids.iter().copied(),
2047        |connection_id| {
2048            session
2049                .peer
2050                .forward_send(session.connection_id, connection_id, message.clone())
2051        },
2052    );
2053
2054    Ok(())
2055}
2056
2057/// Notify other participants that a  language server has started.
2058async fn start_language_server(
2059    request: proto::StartLanguageServer,
2060    session: Session,
2061) -> Result<()> {
2062    let guest_connection_ids = session
2063        .db()
2064        .await
2065        .start_language_server(&request, session.connection_id)
2066        .await?;
2067
2068    broadcast(
2069        Some(session.connection_id),
2070        guest_connection_ids.iter().copied(),
2071        |connection_id| {
2072            session
2073                .peer
2074                .forward_send(session.connection_id, connection_id, request.clone())
2075        },
2076    );
2077    Ok(())
2078}
2079
2080/// Notify other participants that a language server has changed.
2081async fn update_language_server(
2082    request: proto::UpdateLanguageServer,
2083    session: Session,
2084) -> Result<()> {
2085    let project_id = ProjectId::from_proto(request.project_id);
2086    let project_connection_ids = session
2087        .db()
2088        .await
2089        .project_connection_ids(project_id, session.connection_id, true)
2090        .await?;
2091    broadcast(
2092        Some(session.connection_id),
2093        project_connection_ids.iter().copied(),
2094        |connection_id| {
2095            session
2096                .peer
2097                .forward_send(session.connection_id, connection_id, request.clone())
2098        },
2099    );
2100    Ok(())
2101}
2102
2103/// forward a project request to the host. These requests should be read only
2104/// as guests are allowed to send them.
2105async fn forward_read_only_project_request<T>(
2106    request: T,
2107    response: Response<T>,
2108    session: Session,
2109) -> Result<()>
2110where
2111    T: EntityMessage + RequestMessage,
2112{
2113    let project_id = ProjectId::from_proto(request.remote_entity_id());
2114    let host_connection_id = session
2115        .db()
2116        .await
2117        .host_for_read_only_project_request(project_id, session.connection_id)
2118        .await?;
2119    let payload = session
2120        .peer
2121        .forward_request(session.connection_id, host_connection_id, request)
2122        .await?;
2123    response.send(payload)?;
2124    Ok(())
2125}
2126
2127async fn forward_find_search_candidates_request(
2128    request: proto::FindSearchCandidates,
2129    response: Response<proto::FindSearchCandidates>,
2130    session: Session,
2131) -> Result<()> {
2132    let project_id = ProjectId::from_proto(request.remote_entity_id());
2133    let host_connection_id = session
2134        .db()
2135        .await
2136        .host_for_read_only_project_request(project_id, session.connection_id)
2137        .await?;
2138    let payload = session
2139        .peer
2140        .forward_request(session.connection_id, host_connection_id, request)
2141        .await?;
2142    response.send(payload)?;
2143    Ok(())
2144}
2145
2146/// forward a project request to the host. These requests are disallowed
2147/// for guests.
2148async fn forward_mutating_project_request<T>(
2149    request: T,
2150    response: Response<T>,
2151    session: Session,
2152) -> Result<()>
2153where
2154    T: EntityMessage + RequestMessage,
2155{
2156    let project_id = ProjectId::from_proto(request.remote_entity_id());
2157
2158    let host_connection_id = session
2159        .db()
2160        .await
2161        .host_for_mutating_project_request(project_id, session.connection_id)
2162        .await?;
2163    let payload = session
2164        .peer
2165        .forward_request(session.connection_id, host_connection_id, request)
2166        .await?;
2167    response.send(payload)?;
2168    Ok(())
2169}
2170
2171/// Notify other participants that a new buffer has been created
2172async fn create_buffer_for_peer(
2173    request: proto::CreateBufferForPeer,
2174    session: Session,
2175) -> Result<()> {
2176    session
2177        .db()
2178        .await
2179        .check_user_is_project_host(
2180            ProjectId::from_proto(request.project_id),
2181            session.connection_id,
2182        )
2183        .await?;
2184    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2185    session
2186        .peer
2187        .forward_send(session.connection_id, peer_id.into(), request)?;
2188    Ok(())
2189}
2190
2191/// Notify other participants that a buffer has been updated. This is
2192/// allowed for guests as long as the update is limited to selections.
2193async fn update_buffer(
2194    request: proto::UpdateBuffer,
2195    response: Response<proto::UpdateBuffer>,
2196    session: Session,
2197) -> Result<()> {
2198    let project_id = ProjectId::from_proto(request.project_id);
2199    let mut capability = Capability::ReadOnly;
2200
2201    for op in request.operations.iter() {
2202        match op.variant {
2203            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2204            Some(_) => capability = Capability::ReadWrite,
2205        }
2206    }
2207
2208    let host = {
2209        let guard = session
2210            .db()
2211            .await
2212            .connections_for_buffer_update(project_id, session.connection_id, capability)
2213            .await?;
2214
2215        let (host, guests) = &*guard;
2216
2217        broadcast(
2218            Some(session.connection_id),
2219            guests.clone(),
2220            |connection_id| {
2221                session
2222                    .peer
2223                    .forward_send(session.connection_id, connection_id, request.clone())
2224            },
2225        );
2226
2227        *host
2228    };
2229
2230    if host != session.connection_id {
2231        session
2232            .peer
2233            .forward_request(session.connection_id, host, request.clone())
2234            .await?;
2235    }
2236
2237    response.send(proto::Ack {})?;
2238    Ok(())
2239}
2240
2241async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2242    let project_id = ProjectId::from_proto(message.project_id);
2243
2244    let operation = message.operation.as_ref().context("invalid operation")?;
2245    let capability = match operation.variant.as_ref() {
2246        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2247            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2248                match buffer_op.variant {
2249                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2250                        Capability::ReadOnly
2251                    }
2252                    _ => Capability::ReadWrite,
2253                }
2254            } else {
2255                Capability::ReadWrite
2256            }
2257        }
2258        Some(_) => Capability::ReadWrite,
2259        None => Capability::ReadOnly,
2260    };
2261
2262    let guard = session
2263        .db()
2264        .await
2265        .connections_for_buffer_update(project_id, session.connection_id, capability)
2266        .await?;
2267
2268    let (host, guests) = &*guard;
2269
2270    broadcast(
2271        Some(session.connection_id),
2272        guests.iter().chain([host]).copied(),
2273        |connection_id| {
2274            session
2275                .peer
2276                .forward_send(session.connection_id, connection_id, message.clone())
2277        },
2278    );
2279
2280    Ok(())
2281}
2282
2283/// Notify other participants that a project has been updated.
2284async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2285    request: T,
2286    session: Session,
2287) -> Result<()> {
2288    let project_id = ProjectId::from_proto(request.remote_entity_id());
2289    let project_connection_ids = session
2290        .db()
2291        .await
2292        .project_connection_ids(project_id, session.connection_id, false)
2293        .await?;
2294
2295    broadcast(
2296        Some(session.connection_id),
2297        project_connection_ids.iter().copied(),
2298        |connection_id| {
2299            session
2300                .peer
2301                .forward_send(session.connection_id, connection_id, request.clone())
2302        },
2303    );
2304    Ok(())
2305}
2306
2307/// Start following another user in a call.
2308async fn follow(
2309    request: proto::Follow,
2310    response: Response<proto::Follow>,
2311    session: Session,
2312) -> Result<()> {
2313    let room_id = RoomId::from_proto(request.room_id);
2314    let project_id = request.project_id.map(ProjectId::from_proto);
2315    let leader_id = request
2316        .leader_id
2317        .ok_or_else(|| anyhow!("invalid leader id"))?
2318        .into();
2319    let follower_id = session.connection_id;
2320
2321    session
2322        .db()
2323        .await
2324        .check_room_participants(room_id, leader_id, session.connection_id)
2325        .await?;
2326
2327    let response_payload = session
2328        .peer
2329        .forward_request(session.connection_id, leader_id, request)
2330        .await?;
2331    response.send(response_payload)?;
2332
2333    if let Some(project_id) = project_id {
2334        let room = session
2335            .db()
2336            .await
2337            .follow(room_id, project_id, leader_id, follower_id)
2338            .await?;
2339        room_updated(&room, &session.peer);
2340    }
2341
2342    Ok(())
2343}
2344
2345/// Stop following another user in a call.
2346async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2347    let room_id = RoomId::from_proto(request.room_id);
2348    let project_id = request.project_id.map(ProjectId::from_proto);
2349    let leader_id = request
2350        .leader_id
2351        .ok_or_else(|| anyhow!("invalid leader id"))?
2352        .into();
2353    let follower_id = session.connection_id;
2354
2355    session
2356        .db()
2357        .await
2358        .check_room_participants(room_id, leader_id, session.connection_id)
2359        .await?;
2360
2361    session
2362        .peer
2363        .forward_send(session.connection_id, leader_id, request)?;
2364
2365    if let Some(project_id) = project_id {
2366        let room = session
2367            .db()
2368            .await
2369            .unfollow(room_id, project_id, leader_id, follower_id)
2370            .await?;
2371        room_updated(&room, &session.peer);
2372    }
2373
2374    Ok(())
2375}
2376
2377/// Notify everyone following you of your current location.
2378async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2379    let room_id = RoomId::from_proto(request.room_id);
2380    let database = session.db.lock().await;
2381
2382    let connection_ids = if let Some(project_id) = request.project_id {
2383        let project_id = ProjectId::from_proto(project_id);
2384        database
2385            .project_connection_ids(project_id, session.connection_id, true)
2386            .await?
2387    } else {
2388        database
2389            .room_connection_ids(room_id, session.connection_id)
2390            .await?
2391    };
2392
2393    // For now, don't send view update messages back to that view's current leader.
2394    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2395        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2396        _ => None,
2397    });
2398
2399    for connection_id in connection_ids.iter().cloned() {
2400        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2401            session
2402                .peer
2403                .forward_send(session.connection_id, connection_id, request.clone())?;
2404        }
2405    }
2406    Ok(())
2407}
2408
2409/// Get public data about users.
2410async fn get_users(
2411    request: proto::GetUsers,
2412    response: Response<proto::GetUsers>,
2413    session: Session,
2414) -> Result<()> {
2415    let user_ids = request
2416        .user_ids
2417        .into_iter()
2418        .map(UserId::from_proto)
2419        .collect();
2420    let users = session
2421        .db()
2422        .await
2423        .get_users_by_ids(user_ids)
2424        .await?
2425        .into_iter()
2426        .map(|user| proto::User {
2427            id: user.id.to_proto(),
2428            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2429            github_login: user.github_login,
2430            email: user.email_address,
2431            name: user.name,
2432        })
2433        .collect();
2434    response.send(proto::UsersResponse { users })?;
2435    Ok(())
2436}
2437
2438/// Search for users (to invite) buy Github login
2439async fn fuzzy_search_users(
2440    request: proto::FuzzySearchUsers,
2441    response: Response<proto::FuzzySearchUsers>,
2442    session: Session,
2443) -> Result<()> {
2444    let query = request.query;
2445    let users = match query.len() {
2446        0 => vec![],
2447        1 | 2 => session
2448            .db()
2449            .await
2450            .get_user_by_github_login(&query)
2451            .await?
2452            .into_iter()
2453            .collect(),
2454        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2455    };
2456    let users = users
2457        .into_iter()
2458        .filter(|user| user.id != session.user_id())
2459        .map(|user| proto::User {
2460            id: user.id.to_proto(),
2461            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2462            github_login: user.github_login,
2463            name: user.name,
2464            email: user.email_address,
2465        })
2466        .collect();
2467    response.send(proto::UsersResponse { users })?;
2468    Ok(())
2469}
2470
2471/// Send a contact request to another user.
2472async fn request_contact(
2473    request: proto::RequestContact,
2474    response: Response<proto::RequestContact>,
2475    session: Session,
2476) -> Result<()> {
2477    let requester_id = session.user_id();
2478    let responder_id = UserId::from_proto(request.responder_id);
2479    if requester_id == responder_id {
2480        return Err(anyhow!("cannot add yourself as a contact"))?;
2481    }
2482
2483    let notifications = session
2484        .db()
2485        .await
2486        .send_contact_request(requester_id, responder_id)
2487        .await?;
2488
2489    // Update outgoing contact requests of requester
2490    let mut update = proto::UpdateContacts::default();
2491    update.outgoing_requests.push(responder_id.to_proto());
2492    for connection_id in session
2493        .connection_pool()
2494        .await
2495        .user_connection_ids(requester_id)
2496    {
2497        session.peer.send(connection_id, update.clone())?;
2498    }
2499
2500    // Update incoming contact requests of responder
2501    let mut update = proto::UpdateContacts::default();
2502    update
2503        .incoming_requests
2504        .push(proto::IncomingContactRequest {
2505            requester_id: requester_id.to_proto(),
2506        });
2507    let connection_pool = session.connection_pool().await;
2508    for connection_id in connection_pool.user_connection_ids(responder_id) {
2509        session.peer.send(connection_id, update.clone())?;
2510    }
2511
2512    send_notifications(&connection_pool, &session.peer, notifications);
2513
2514    response.send(proto::Ack {})?;
2515    Ok(())
2516}
2517
2518/// Accept or decline a contact request
2519async fn respond_to_contact_request(
2520    request: proto::RespondToContactRequest,
2521    response: Response<proto::RespondToContactRequest>,
2522    session: Session,
2523) -> Result<()> {
2524    let responder_id = session.user_id();
2525    let requester_id = UserId::from_proto(request.requester_id);
2526    let db = session.db().await;
2527    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2528        db.dismiss_contact_notification(responder_id, requester_id)
2529            .await?;
2530    } else {
2531        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2532
2533        let notifications = db
2534            .respond_to_contact_request(responder_id, requester_id, accept)
2535            .await?;
2536        let requester_busy = db.is_user_busy(requester_id).await?;
2537        let responder_busy = db.is_user_busy(responder_id).await?;
2538
2539        let pool = session.connection_pool().await;
2540        // Update responder with new contact
2541        let mut update = proto::UpdateContacts::default();
2542        if accept {
2543            update
2544                .contacts
2545                .push(contact_for_user(requester_id, requester_busy, &pool));
2546        }
2547        update
2548            .remove_incoming_requests
2549            .push(requester_id.to_proto());
2550        for connection_id in pool.user_connection_ids(responder_id) {
2551            session.peer.send(connection_id, update.clone())?;
2552        }
2553
2554        // Update requester with new contact
2555        let mut update = proto::UpdateContacts::default();
2556        if accept {
2557            update
2558                .contacts
2559                .push(contact_for_user(responder_id, responder_busy, &pool));
2560        }
2561        update
2562            .remove_outgoing_requests
2563            .push(responder_id.to_proto());
2564
2565        for connection_id in pool.user_connection_ids(requester_id) {
2566            session.peer.send(connection_id, update.clone())?;
2567        }
2568
2569        send_notifications(&pool, &session.peer, notifications);
2570    }
2571
2572    response.send(proto::Ack {})?;
2573    Ok(())
2574}
2575
2576/// Remove a contact.
2577async fn remove_contact(
2578    request: proto::RemoveContact,
2579    response: Response<proto::RemoveContact>,
2580    session: Session,
2581) -> Result<()> {
2582    let requester_id = session.user_id();
2583    let responder_id = UserId::from_proto(request.user_id);
2584    let db = session.db().await;
2585    let (contact_accepted, deleted_notification_id) =
2586        db.remove_contact(requester_id, responder_id).await?;
2587
2588    let pool = session.connection_pool().await;
2589    // Update outgoing contact requests of requester
2590    let mut update = proto::UpdateContacts::default();
2591    if contact_accepted {
2592        update.remove_contacts.push(responder_id.to_proto());
2593    } else {
2594        update
2595            .remove_outgoing_requests
2596            .push(responder_id.to_proto());
2597    }
2598    for connection_id in pool.user_connection_ids(requester_id) {
2599        session.peer.send(connection_id, update.clone())?;
2600    }
2601
2602    // Update incoming contact requests of responder
2603    let mut update = proto::UpdateContacts::default();
2604    if contact_accepted {
2605        update.remove_contacts.push(requester_id.to_proto());
2606    } else {
2607        update
2608            .remove_incoming_requests
2609            .push(requester_id.to_proto());
2610    }
2611    for connection_id in pool.user_connection_ids(responder_id) {
2612        session.peer.send(connection_id, update.clone())?;
2613        if let Some(notification_id) = deleted_notification_id {
2614            session.peer.send(
2615                connection_id,
2616                proto::DeleteNotification {
2617                    notification_id: notification_id.to_proto(),
2618                },
2619            )?;
2620        }
2621    }
2622
2623    response.send(proto::Ack {})?;
2624    Ok(())
2625}
2626
2627fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2628    version.0.minor() < 139
2629}
2630
2631async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2632    let plan = session.current_plan(&session.db().await).await?;
2633
2634    session
2635        .peer
2636        .send(
2637            session.connection_id,
2638            proto::UpdateUserPlan { plan: plan.into() },
2639        )
2640        .trace_err();
2641
2642    Ok(())
2643}
2644
2645async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2646    subscribe_user_to_channels(session.user_id(), &session).await?;
2647    Ok(())
2648}
2649
2650async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2651    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2652    let mut pool = session.connection_pool().await;
2653    for membership in &channels_for_user.channel_memberships {
2654        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2655    }
2656    session.peer.send(
2657        session.connection_id,
2658        build_update_user_channels(&channels_for_user),
2659    )?;
2660    session.peer.send(
2661        session.connection_id,
2662        build_channels_update(channels_for_user),
2663    )?;
2664    Ok(())
2665}
2666
2667/// Creates a new channel.
2668async fn create_channel(
2669    request: proto::CreateChannel,
2670    response: Response<proto::CreateChannel>,
2671    session: Session,
2672) -> Result<()> {
2673    let db = session.db().await;
2674
2675    let parent_id = request.parent_id.map(ChannelId::from_proto);
2676    let (channel, membership) = db
2677        .create_channel(&request.name, parent_id, session.user_id())
2678        .await?;
2679
2680    let root_id = channel.root_id();
2681    let channel = Channel::from_model(channel);
2682
2683    response.send(proto::CreateChannelResponse {
2684        channel: Some(channel.to_proto()),
2685        parent_id: request.parent_id,
2686    })?;
2687
2688    let mut connection_pool = session.connection_pool().await;
2689    if let Some(membership) = membership {
2690        connection_pool.subscribe_to_channel(
2691            membership.user_id,
2692            membership.channel_id,
2693            membership.role,
2694        );
2695        let update = proto::UpdateUserChannels {
2696            channel_memberships: vec![proto::ChannelMembership {
2697                channel_id: membership.channel_id.to_proto(),
2698                role: membership.role.into(),
2699            }],
2700            ..Default::default()
2701        };
2702        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2703            session.peer.send(connection_id, update.clone())?;
2704        }
2705    }
2706
2707    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2708        if !role.can_see_channel(channel.visibility) {
2709            continue;
2710        }
2711
2712        let update = proto::UpdateChannels {
2713            channels: vec![channel.to_proto()],
2714            ..Default::default()
2715        };
2716        session.peer.send(connection_id, update.clone())?;
2717    }
2718
2719    Ok(())
2720}
2721
2722/// Delete a channel
2723async fn delete_channel(
2724    request: proto::DeleteChannel,
2725    response: Response<proto::DeleteChannel>,
2726    session: Session,
2727) -> Result<()> {
2728    let db = session.db().await;
2729
2730    let channel_id = request.channel_id;
2731    let (root_channel, removed_channels) = db
2732        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2733        .await?;
2734    response.send(proto::Ack {})?;
2735
2736    // Notify members of removed channels
2737    let mut update = proto::UpdateChannels::default();
2738    update
2739        .delete_channels
2740        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2741
2742    let connection_pool = session.connection_pool().await;
2743    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2744        session.peer.send(connection_id, update.clone())?;
2745    }
2746
2747    Ok(())
2748}
2749
2750/// Invite someone to join a channel.
2751async fn invite_channel_member(
2752    request: proto::InviteChannelMember,
2753    response: Response<proto::InviteChannelMember>,
2754    session: Session,
2755) -> Result<()> {
2756    let db = session.db().await;
2757    let channel_id = ChannelId::from_proto(request.channel_id);
2758    let invitee_id = UserId::from_proto(request.user_id);
2759    let InviteMemberResult {
2760        channel,
2761        notifications,
2762    } = db
2763        .invite_channel_member(
2764            channel_id,
2765            invitee_id,
2766            session.user_id(),
2767            request.role().into(),
2768        )
2769        .await?;
2770
2771    let update = proto::UpdateChannels {
2772        channel_invitations: vec![channel.to_proto()],
2773        ..Default::default()
2774    };
2775
2776    let connection_pool = session.connection_pool().await;
2777    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2778        session.peer.send(connection_id, update.clone())?;
2779    }
2780
2781    send_notifications(&connection_pool, &session.peer, notifications);
2782
2783    response.send(proto::Ack {})?;
2784    Ok(())
2785}
2786
2787/// remove someone from a channel
2788async fn remove_channel_member(
2789    request: proto::RemoveChannelMember,
2790    response: Response<proto::RemoveChannelMember>,
2791    session: Session,
2792) -> Result<()> {
2793    let db = session.db().await;
2794    let channel_id = ChannelId::from_proto(request.channel_id);
2795    let member_id = UserId::from_proto(request.user_id);
2796
2797    let RemoveChannelMemberResult {
2798        membership_update,
2799        notification_id,
2800    } = db
2801        .remove_channel_member(channel_id, member_id, session.user_id())
2802        .await?;
2803
2804    let mut connection_pool = session.connection_pool().await;
2805    notify_membership_updated(
2806        &mut connection_pool,
2807        membership_update,
2808        member_id,
2809        &session.peer,
2810    );
2811    for connection_id in connection_pool.user_connection_ids(member_id) {
2812        if let Some(notification_id) = notification_id {
2813            session
2814                .peer
2815                .send(
2816                    connection_id,
2817                    proto::DeleteNotification {
2818                        notification_id: notification_id.to_proto(),
2819                    },
2820                )
2821                .trace_err();
2822        }
2823    }
2824
2825    response.send(proto::Ack {})?;
2826    Ok(())
2827}
2828
2829/// Toggle the channel between public and private.
2830/// Care is taken to maintain the invariant that public channels only descend from public channels,
2831/// (though members-only channels can appear at any point in the hierarchy).
2832async fn set_channel_visibility(
2833    request: proto::SetChannelVisibility,
2834    response: Response<proto::SetChannelVisibility>,
2835    session: Session,
2836) -> Result<()> {
2837    let db = session.db().await;
2838    let channel_id = ChannelId::from_proto(request.channel_id);
2839    let visibility = request.visibility().into();
2840
2841    let channel_model = db
2842        .set_channel_visibility(channel_id, visibility, session.user_id())
2843        .await?;
2844    let root_id = channel_model.root_id();
2845    let channel = Channel::from_model(channel_model);
2846
2847    let mut connection_pool = session.connection_pool().await;
2848    for (user_id, role) in connection_pool
2849        .channel_user_ids(root_id)
2850        .collect::<Vec<_>>()
2851        .into_iter()
2852    {
2853        let update = if role.can_see_channel(channel.visibility) {
2854            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2855            proto::UpdateChannels {
2856                channels: vec![channel.to_proto()],
2857                ..Default::default()
2858            }
2859        } else {
2860            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2861            proto::UpdateChannels {
2862                delete_channels: vec![channel.id.to_proto()],
2863                ..Default::default()
2864            }
2865        };
2866
2867        for connection_id in connection_pool.user_connection_ids(user_id) {
2868            session.peer.send(connection_id, update.clone())?;
2869        }
2870    }
2871
2872    response.send(proto::Ack {})?;
2873    Ok(())
2874}
2875
2876/// Alter the role for a user in the channel.
2877async fn set_channel_member_role(
2878    request: proto::SetChannelMemberRole,
2879    response: Response<proto::SetChannelMemberRole>,
2880    session: Session,
2881) -> Result<()> {
2882    let db = session.db().await;
2883    let channel_id = ChannelId::from_proto(request.channel_id);
2884    let member_id = UserId::from_proto(request.user_id);
2885    let result = db
2886        .set_channel_member_role(
2887            channel_id,
2888            session.user_id(),
2889            member_id,
2890            request.role().into(),
2891        )
2892        .await?;
2893
2894    match result {
2895        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2896            let mut connection_pool = session.connection_pool().await;
2897            notify_membership_updated(
2898                &mut connection_pool,
2899                membership_update,
2900                member_id,
2901                &session.peer,
2902            )
2903        }
2904        db::SetMemberRoleResult::InviteUpdated(channel) => {
2905            let update = proto::UpdateChannels {
2906                channel_invitations: vec![channel.to_proto()],
2907                ..Default::default()
2908            };
2909
2910            for connection_id in session
2911                .connection_pool()
2912                .await
2913                .user_connection_ids(member_id)
2914            {
2915                session.peer.send(connection_id, update.clone())?;
2916            }
2917        }
2918    }
2919
2920    response.send(proto::Ack {})?;
2921    Ok(())
2922}
2923
2924/// Change the name of a channel
2925async fn rename_channel(
2926    request: proto::RenameChannel,
2927    response: Response<proto::RenameChannel>,
2928    session: Session,
2929) -> Result<()> {
2930    let db = session.db().await;
2931    let channel_id = ChannelId::from_proto(request.channel_id);
2932    let channel_model = db
2933        .rename_channel(channel_id, session.user_id(), &request.name)
2934        .await?;
2935    let root_id = channel_model.root_id();
2936    let channel = Channel::from_model(channel_model);
2937
2938    response.send(proto::RenameChannelResponse {
2939        channel: Some(channel.to_proto()),
2940    })?;
2941
2942    let connection_pool = session.connection_pool().await;
2943    let update = proto::UpdateChannels {
2944        channels: vec![channel.to_proto()],
2945        ..Default::default()
2946    };
2947    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2948        if role.can_see_channel(channel.visibility) {
2949            session.peer.send(connection_id, update.clone())?;
2950        }
2951    }
2952
2953    Ok(())
2954}
2955
2956/// Move a channel to a new parent.
2957async fn move_channel(
2958    request: proto::MoveChannel,
2959    response: Response<proto::MoveChannel>,
2960    session: Session,
2961) -> Result<()> {
2962    let channel_id = ChannelId::from_proto(request.channel_id);
2963    let to = ChannelId::from_proto(request.to);
2964
2965    let (root_id, channels) = session
2966        .db()
2967        .await
2968        .move_channel(channel_id, to, session.user_id())
2969        .await?;
2970
2971    let connection_pool = session.connection_pool().await;
2972    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2973        let channels = channels
2974            .iter()
2975            .filter_map(|channel| {
2976                if role.can_see_channel(channel.visibility) {
2977                    Some(channel.to_proto())
2978                } else {
2979                    None
2980                }
2981            })
2982            .collect::<Vec<_>>();
2983        if channels.is_empty() {
2984            continue;
2985        }
2986
2987        let update = proto::UpdateChannels {
2988            channels,
2989            ..Default::default()
2990        };
2991
2992        session.peer.send(connection_id, update.clone())?;
2993    }
2994
2995    response.send(Ack {})?;
2996    Ok(())
2997}
2998
2999/// Get the list of channel members
3000async fn get_channel_members(
3001    request: proto::GetChannelMembers,
3002    response: Response<proto::GetChannelMembers>,
3003    session: Session,
3004) -> Result<()> {
3005    let db = session.db().await;
3006    let channel_id = ChannelId::from_proto(request.channel_id);
3007    let limit = if request.limit == 0 {
3008        u16::MAX as u64
3009    } else {
3010        request.limit
3011    };
3012    let (members, users) = db
3013        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3014        .await?;
3015    response.send(proto::GetChannelMembersResponse { members, users })?;
3016    Ok(())
3017}
3018
3019/// Accept or decline a channel invitation.
3020async fn respond_to_channel_invite(
3021    request: proto::RespondToChannelInvite,
3022    response: Response<proto::RespondToChannelInvite>,
3023    session: Session,
3024) -> Result<()> {
3025    let db = session.db().await;
3026    let channel_id = ChannelId::from_proto(request.channel_id);
3027    let RespondToChannelInvite {
3028        membership_update,
3029        notifications,
3030    } = db
3031        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3032        .await?;
3033
3034    let mut connection_pool = session.connection_pool().await;
3035    if let Some(membership_update) = membership_update {
3036        notify_membership_updated(
3037            &mut connection_pool,
3038            membership_update,
3039            session.user_id(),
3040            &session.peer,
3041        );
3042    } else {
3043        let update = proto::UpdateChannels {
3044            remove_channel_invitations: vec![channel_id.to_proto()],
3045            ..Default::default()
3046        };
3047
3048        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3049            session.peer.send(connection_id, update.clone())?;
3050        }
3051    };
3052
3053    send_notifications(&connection_pool, &session.peer, notifications);
3054
3055    response.send(proto::Ack {})?;
3056
3057    Ok(())
3058}
3059
3060/// Join the channels' room
3061async fn join_channel(
3062    request: proto::JoinChannel,
3063    response: Response<proto::JoinChannel>,
3064    session: Session,
3065) -> Result<()> {
3066    let channel_id = ChannelId::from_proto(request.channel_id);
3067    join_channel_internal(channel_id, Box::new(response), session).await
3068}
3069
/// A response handle that can answer with a `proto::JoinRoomResponse`.
///
/// Implemented below for the response handles of both `proto::JoinChannel`
/// and `proto::JoinRoom`, so `join_channel_internal` can reply to either.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl forwards to `send` on the concrete `Response<_>` type using a
// fully qualified path.
// NOTE(review): presumably the explicit path is there to select the inherent
// `Response::send` rather than this trait method — confirm before simplifying.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3083
/// Shared implementation for joining the room of a channel.
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Record the join in the database.
/// 3. If a LiveKit client is configured, mint a token (a guest token without
///    publish rights for `ChannelRole::Guest`, a room token with publish
///    rights otherwise).
/// 4. Send the `JoinRoomResponse` through `response`.
/// 5. Broadcast the membership change (if any), the room update, the channel
///    update, and finally refresh the user's contacts.
///
/// Returns an error if any database call or the response send fails, or if
/// the joined room comes back without a channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The database handle is released before leaving the old room and
            // re-acquired afterwards.
            // NOTE(review): presumably `leave_room_for_session` acquires the
            // handle itself — confirm before reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when `trace_err()?`
        // short-circuits the closure because token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every
                    // other role gets a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` go out of scope at the end of this block,
        // before the connection pool is acquired again below.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3179
3180/// Start editing the channel notes
3181async fn join_channel_buffer(
3182    request: proto::JoinChannelBuffer,
3183    response: Response<proto::JoinChannelBuffer>,
3184    session: Session,
3185) -> Result<()> {
3186    let db = session.db().await;
3187    let channel_id = ChannelId::from_proto(request.channel_id);
3188
3189    let open_response = db
3190        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3191        .await?;
3192
3193    let collaborators = open_response.collaborators.clone();
3194    response.send(open_response)?;
3195
3196    let update = UpdateChannelBufferCollaborators {
3197        channel_id: channel_id.to_proto(),
3198        collaborators: collaborators.clone(),
3199    };
3200    channel_buffer_updated(
3201        session.connection_id,
3202        collaborators
3203            .iter()
3204            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3205        &update,
3206        &session.peer,
3207    );
3208
3209    Ok(())
3210}
3211
3212/// Edit the channel notes
3213async fn update_channel_buffer(
3214    request: proto::UpdateChannelBuffer,
3215    session: Session,
3216) -> Result<()> {
3217    let db = session.db().await;
3218    let channel_id = ChannelId::from_proto(request.channel_id);
3219
3220    let (collaborators, epoch, version) = db
3221        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3222        .await?;
3223
3224    channel_buffer_updated(
3225        session.connection_id,
3226        collaborators.clone(),
3227        &proto::UpdateChannelBuffer {
3228            channel_id: channel_id.to_proto(),
3229            operations: request.operations,
3230        },
3231        &session.peer,
3232    );
3233
3234    let pool = &*session.connection_pool().await;
3235
3236    let non_collaborators =
3237        pool.channel_connection_ids(channel_id)
3238            .filter_map(|(connection_id, _)| {
3239                if collaborators.contains(&connection_id) {
3240                    None
3241                } else {
3242                    Some(connection_id)
3243                }
3244            });
3245
3246    broadcast(None, non_collaborators, |peer_id| {
3247        session.peer.send(
3248            peer_id,
3249            proto::UpdateChannels {
3250                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3251                    channel_id: channel_id.to_proto(),
3252                    epoch: epoch as u64,
3253                    version: version.clone(),
3254                }],
3255                ..Default::default()
3256            },
3257        )
3258    });
3259
3260    Ok(())
3261}
3262
3263/// Rejoin the channel notes after a connection blip
3264async fn rejoin_channel_buffers(
3265    request: proto::RejoinChannelBuffers,
3266    response: Response<proto::RejoinChannelBuffers>,
3267    session: Session,
3268) -> Result<()> {
3269    let db = session.db().await;
3270    let buffers = db
3271        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3272        .await?;
3273
3274    for rejoined_buffer in &buffers {
3275        let collaborators_to_notify = rejoined_buffer
3276            .buffer
3277            .collaborators
3278            .iter()
3279            .filter_map(|c| Some(c.peer_id?.into()));
3280        channel_buffer_updated(
3281            session.connection_id,
3282            collaborators_to_notify,
3283            &proto::UpdateChannelBufferCollaborators {
3284                channel_id: rejoined_buffer.buffer.channel_id,
3285                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3286            },
3287            &session.peer,
3288        );
3289    }
3290
3291    response.send(proto::RejoinChannelBuffersResponse {
3292        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3293    })?;
3294
3295    Ok(())
3296}
3297
3298/// Stop editing the channel notes
3299async fn leave_channel_buffer(
3300    request: proto::LeaveChannelBuffer,
3301    response: Response<proto::LeaveChannelBuffer>,
3302    session: Session,
3303) -> Result<()> {
3304    let db = session.db().await;
3305    let channel_id = ChannelId::from_proto(request.channel_id);
3306
3307    let left_buffer = db
3308        .leave_channel_buffer(channel_id, session.connection_id)
3309        .await?;
3310
3311    response.send(Ack {})?;
3312
3313    channel_buffer_updated(
3314        session.connection_id,
3315        left_buffer.connections,
3316        &proto::UpdateChannelBufferCollaborators {
3317            channel_id: channel_id.to_proto(),
3318            collaborators: left_buffer.collaborators,
3319        },
3320        &session.peer,
3321    );
3322
3323    Ok(())
3324}
3325
3326fn channel_buffer_updated<T: EnvelopedMessage>(
3327    sender_id: ConnectionId,
3328    collaborators: impl IntoIterator<Item = ConnectionId>,
3329    message: &T,
3330    peer: &Peer,
3331) {
3332    broadcast(Some(sender_id), collaborators, |peer_id| {
3333        peer.send(peer_id, message.clone())
3334    });
3335}
3336
3337fn send_notifications(
3338    connection_pool: &ConnectionPool,
3339    peer: &Peer,
3340    notifications: db::NotificationBatch,
3341) {
3342    for (user_id, notification) in notifications {
3343        for connection_id in connection_pool.user_connection_ids(user_id) {
3344            if let Err(error) = peer.send(
3345                connection_id,
3346                proto::AddNotification {
3347                    notification: Some(notification.clone()),
3348                },
3349            ) {
3350                tracing::error!(
3351                    "failed to send notification to {:?} {}",
3352                    connection_id,
3353                    error
3354                );
3355            }
3356        }
3357    }
3358}
3359
3360/// Send a message to the channel
3361async fn send_channel_message(
3362    request: proto::SendChannelMessage,
3363    response: Response<proto::SendChannelMessage>,
3364    session: Session,
3365) -> Result<()> {
3366    // Validate the message body.
3367    let body = request.body.trim().to_string();
3368    if body.len() > MAX_MESSAGE_LEN {
3369        return Err(anyhow!("message is too long"))?;
3370    }
3371    if body.is_empty() {
3372        return Err(anyhow!("message can't be blank"))?;
3373    }
3374
3375    // TODO: adjust mentions if body is trimmed
3376
3377    let timestamp = OffsetDateTime::now_utc();
3378    let nonce = request
3379        .nonce
3380        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3381
3382    let channel_id = ChannelId::from_proto(request.channel_id);
3383    let CreatedChannelMessage {
3384        message_id,
3385        participant_connection_ids,
3386        notifications,
3387    } = session
3388        .db()
3389        .await
3390        .create_channel_message(
3391            channel_id,
3392            session.user_id(),
3393            &body,
3394            &request.mentions,
3395            timestamp,
3396            nonce.clone().into(),
3397            request.reply_to_message_id.map(MessageId::from_proto),
3398        )
3399        .await?;
3400
3401    let message = proto::ChannelMessage {
3402        sender_id: session.user_id().to_proto(),
3403        id: message_id.to_proto(),
3404        body,
3405        mentions: request.mentions,
3406        timestamp: timestamp.unix_timestamp() as u64,
3407        nonce: Some(nonce),
3408        reply_to_message_id: request.reply_to_message_id,
3409        edited_at: None,
3410    };
3411    broadcast(
3412        Some(session.connection_id),
3413        participant_connection_ids.clone(),
3414        |connection| {
3415            session.peer.send(
3416                connection,
3417                proto::ChannelMessageSent {
3418                    channel_id: channel_id.to_proto(),
3419                    message: Some(message.clone()),
3420                },
3421            )
3422        },
3423    );
3424    response.send(proto::SendChannelMessageResponse {
3425        message: Some(message),
3426    })?;
3427
3428    let pool = &*session.connection_pool().await;
3429    let non_participants =
3430        pool.channel_connection_ids(channel_id)
3431            .filter_map(|(connection_id, _)| {
3432                if participant_connection_ids.contains(&connection_id) {
3433                    None
3434                } else {
3435                    Some(connection_id)
3436                }
3437            });
3438    broadcast(None, non_participants, |peer_id| {
3439        session.peer.send(
3440            peer_id,
3441            proto::UpdateChannels {
3442                latest_channel_message_ids: vec![proto::ChannelMessageId {
3443                    channel_id: channel_id.to_proto(),
3444                    message_id: message_id.to_proto(),
3445                }],
3446                ..Default::default()
3447            },
3448        )
3449    });
3450    send_notifications(pool, &session.peer, notifications);
3451
3452    Ok(())
3453}
3454
3455/// Delete a channel message
3456async fn remove_channel_message(
3457    request: proto::RemoveChannelMessage,
3458    response: Response<proto::RemoveChannelMessage>,
3459    session: Session,
3460) -> Result<()> {
3461    let channel_id = ChannelId::from_proto(request.channel_id);
3462    let message_id = MessageId::from_proto(request.message_id);
3463    let (connection_ids, existing_notification_ids) = session
3464        .db()
3465        .await
3466        .remove_channel_message(channel_id, message_id, session.user_id())
3467        .await?;
3468
3469    broadcast(
3470        Some(session.connection_id),
3471        connection_ids,
3472        move |connection| {
3473            session.peer.send(connection, request.clone())?;
3474
3475            for notification_id in &existing_notification_ids {
3476                session.peer.send(
3477                    connection,
3478                    proto::DeleteNotification {
3479                        notification_id: (*notification_id).to_proto(),
3480                    },
3481                )?;
3482            }
3483
3484            Ok(())
3485        },
3486    );
3487    response.send(proto::Ack {})?;
3488    Ok(())
3489}
3490
3491async fn update_channel_message(
3492    request: proto::UpdateChannelMessage,
3493    response: Response<proto::UpdateChannelMessage>,
3494    session: Session,
3495) -> Result<()> {
3496    let channel_id = ChannelId::from_proto(request.channel_id);
3497    let message_id = MessageId::from_proto(request.message_id);
3498    let updated_at = OffsetDateTime::now_utc();
3499    let UpdatedChannelMessage {
3500        message_id,
3501        participant_connection_ids,
3502        notifications,
3503        reply_to_message_id,
3504        timestamp,
3505        deleted_mention_notification_ids,
3506        updated_mention_notifications,
3507    } = session
3508        .db()
3509        .await
3510        .update_channel_message(
3511            channel_id,
3512            message_id,
3513            session.user_id(),
3514            request.body.as_str(),
3515            &request.mentions,
3516            updated_at,
3517        )
3518        .await?;
3519
3520    let nonce = request
3521        .nonce
3522        .clone()
3523        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3524
3525    let message = proto::ChannelMessage {
3526        sender_id: session.user_id().to_proto(),
3527        id: message_id.to_proto(),
3528        body: request.body.clone(),
3529        mentions: request.mentions.clone(),
3530        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3531        nonce: Some(nonce),
3532        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3533        edited_at: Some(updated_at.unix_timestamp() as u64),
3534    };
3535
3536    response.send(proto::Ack {})?;
3537
3538    let pool = &*session.connection_pool().await;
3539    broadcast(
3540        Some(session.connection_id),
3541        participant_connection_ids,
3542        |connection| {
3543            session.peer.send(
3544                connection,
3545                proto::ChannelMessageUpdate {
3546                    channel_id: channel_id.to_proto(),
3547                    message: Some(message.clone()),
3548                },
3549            )?;
3550
3551            for notification_id in &deleted_mention_notification_ids {
3552                session.peer.send(
3553                    connection,
3554                    proto::DeleteNotification {
3555                        notification_id: (*notification_id).to_proto(),
3556                    },
3557                )?;
3558            }
3559
3560            for notification in &updated_mention_notifications {
3561                session.peer.send(
3562                    connection,
3563                    proto::UpdateNotification {
3564                        notification: Some(notification.clone()),
3565                    },
3566                )?;
3567            }
3568
3569            Ok(())
3570        },
3571    );
3572
3573    send_notifications(pool, &session.peer, notifications);
3574
3575    Ok(())
3576}
3577
3578/// Mark a channel message as read
3579async fn acknowledge_channel_message(
3580    request: proto::AckChannelMessage,
3581    session: Session,
3582) -> Result<()> {
3583    let channel_id = ChannelId::from_proto(request.channel_id);
3584    let message_id = MessageId::from_proto(request.message_id);
3585    let notifications = session
3586        .db()
3587        .await
3588        .observe_channel_message(channel_id, session.user_id(), message_id)
3589        .await?;
3590    send_notifications(
3591        &*session.connection_pool().await,
3592        &session.peer,
3593        notifications,
3594    );
3595    Ok(())
3596}
3597
3598/// Mark a buffer version as synced
3599async fn acknowledge_buffer_version(
3600    request: proto::AckBufferOperation,
3601    session: Session,
3602) -> Result<()> {
3603    let buffer_id = BufferId::from_proto(request.buffer_id);
3604    session
3605        .db()
3606        .await
3607        .observe_buffer_version(
3608            buffer_id,
3609            session.user_id(),
3610            request.epoch as i32,
3611            &request.version,
3612        )
3613        .await?;
3614    Ok(())
3615}
3616
3617async fn count_language_model_tokens(
3618    request: proto::CountLanguageModelTokens,
3619    response: Response<proto::CountLanguageModelTokens>,
3620    session: Session,
3621    config: &Config,
3622) -> Result<()> {
3623    authorize_access_to_legacy_llm_endpoints(&session).await?;
3624
3625    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3626        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3627        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3628    };
3629
3630    session
3631        .app_state
3632        .rate_limiter
3633        .check(&*rate_limit, session.user_id())
3634        .await?;
3635
3636    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3637        Some(proto::LanguageModelProvider::Google) => {
3638            let api_key = config
3639                .google_ai_api_key
3640                .as_ref()
3641                .context("no Google AI API key configured on the server")?;
3642            google_ai::count_tokens(
3643                session.http_client.as_ref(),
3644                google_ai::API_URL,
3645                api_key,
3646                serde_json::from_str(&request.request)?,
3647            )
3648            .await?
3649        }
3650        _ => return Err(anyhow!("unsupported provider"))?,
3651    };
3652
3653    response.send(proto::CountLanguageModelTokensResponse {
3654        token_count: result.total_tokens as u32,
3655    })?;
3656
3657    Ok(())
3658}
3659
3660struct ZedProCountLanguageModelTokensRateLimit;
3661
3662impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3663    fn capacity(&self) -> usize {
3664        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3665            .ok()
3666            .and_then(|v| v.parse().ok())
3667            .unwrap_or(600) // Picked arbitrarily
3668    }
3669
3670    fn refill_duration(&self) -> chrono::Duration {
3671        chrono::Duration::hours(1)
3672    }
3673
3674    fn db_name(&self) -> &'static str {
3675        "zed-pro:count-language-model-tokens"
3676    }
3677}
3678
3679struct FreeCountLanguageModelTokensRateLimit;
3680
3681impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3682    fn capacity(&self) -> usize {
3683        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3684            .ok()
3685            .and_then(|v| v.parse().ok())
3686            .unwrap_or(600 / 10) // Picked arbitrarily
3687    }
3688
3689    fn refill_duration(&self) -> chrono::Duration {
3690        chrono::Duration::hours(1)
3691    }
3692
3693    fn db_name(&self) -> &'static str {
3694        "free:count-language-model-tokens"
3695    }
3696}
3697
3698struct ZedProComputeEmbeddingsRateLimit;
3699
3700impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3701    fn capacity(&self) -> usize {
3702        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3703            .ok()
3704            .and_then(|v| v.parse().ok())
3705            .unwrap_or(5000) // Picked arbitrarily
3706    }
3707
3708    fn refill_duration(&self) -> chrono::Duration {
3709        chrono::Duration::hours(1)
3710    }
3711
3712    fn db_name(&self) -> &'static str {
3713        "zed-pro:compute-embeddings"
3714    }
3715}
3716
3717struct FreeComputeEmbeddingsRateLimit;
3718
3719impl RateLimit for FreeComputeEmbeddingsRateLimit {
3720    fn capacity(&self) -> usize {
3721        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3722            .ok()
3723            .and_then(|v| v.parse().ok())
3724            .unwrap_or(5000 / 10) // Picked arbitrarily
3725    }
3726
3727    fn refill_duration(&self) -> chrono::Duration {
3728        chrono::Duration::hours(1)
3729    }
3730
3731    fn db_name(&self) -> &'static str {
3732        "free:compute-embeddings"
3733    }
3734}
3735
3736async fn compute_embeddings(
3737    request: proto::ComputeEmbeddings,
3738    response: Response<proto::ComputeEmbeddings>,
3739    session: Session,
3740    api_key: Option<Arc<str>>,
3741) -> Result<()> {
3742    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3743    authorize_access_to_legacy_llm_endpoints(&session).await?;
3744
3745    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3746        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3747        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3748    };
3749
3750    session
3751        .app_state
3752        .rate_limiter
3753        .check(&*rate_limit, session.user_id())
3754        .await?;
3755
3756    let embeddings = match request.model.as_str() {
3757        "openai/text-embedding-3-small" => {
3758            open_ai::embed(
3759                session.http_client.as_ref(),
3760                OPEN_AI_API_URL,
3761                &api_key,
3762                OpenAiEmbeddingModel::TextEmbedding3Small,
3763                request.texts.iter().map(|text| text.as_str()),
3764            )
3765            .await?
3766        }
3767        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3768    };
3769
3770    let embeddings = request
3771        .texts
3772        .iter()
3773        .map(|text| {
3774            let mut hasher = sha2::Sha256::new();
3775            hasher.update(text.as_bytes());
3776            let result = hasher.finalize();
3777            result.to_vec()
3778        })
3779        .zip(
3780            embeddings
3781                .data
3782                .into_iter()
3783                .map(|embedding| embedding.embedding),
3784        )
3785        .collect::<HashMap<_, _>>();
3786
3787    let db = session.db().await;
3788    db.save_embeddings(&request.model, &embeddings)
3789        .await
3790        .context("failed to save embeddings")
3791        .trace_err();
3792
3793    response.send(proto::ComputeEmbeddingsResponse {
3794        embeddings: embeddings
3795            .into_iter()
3796            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3797            .collect(),
3798    })?;
3799    Ok(())
3800}
3801
3802async fn get_cached_embeddings(
3803    request: proto::GetCachedEmbeddings,
3804    response: Response<proto::GetCachedEmbeddings>,
3805    session: Session,
3806) -> Result<()> {
3807    authorize_access_to_legacy_llm_endpoints(&session).await?;
3808
3809    let db = session.db().await;
3810    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3811
3812    response.send(proto::GetCachedEmbeddingsResponse {
3813        embeddings: embeddings
3814            .into_iter()
3815            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3816            .collect(),
3817    })?;
3818    Ok(())
3819}
3820
3821/// This is leftover from before the LLM service.
3822///
3823/// The endpoints protected by this check will be moved there eventually.
3824async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3825    if session.is_staff() {
3826        Ok(())
3827    } else {
3828        Err(anyhow!("permission denied"))?
3829    }
3830}
3831
3832/// Get a Supermaven API key for the user
3833async fn get_supermaven_api_key(
3834    _request: proto::GetSupermavenApiKey,
3835    response: Response<proto::GetSupermavenApiKey>,
3836    session: Session,
3837) -> Result<()> {
3838    let user_id: String = session.user_id().to_string();
3839    if !session.is_staff() {
3840        return Err(anyhow!("supermaven not enabled for this account"))?;
3841    }
3842
3843    let email = session
3844        .email()
3845        .ok_or_else(|| anyhow!("user must have an email"))?;
3846
3847    let supermaven_admin_api = session
3848        .supermaven_client
3849        .as_ref()
3850        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3851
3852    let result = supermaven_admin_api
3853        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3854        .await?;
3855
3856    response.send(proto::GetSupermavenApiKeyResponse {
3857        api_key: result.api_key,
3858    })?;
3859
3860    Ok(())
3861}
3862
3863/// Start receiving chat updates for a channel
3864async fn join_channel_chat(
3865    request: proto::JoinChannelChat,
3866    response: Response<proto::JoinChannelChat>,
3867    session: Session,
3868) -> Result<()> {
3869    let channel_id = ChannelId::from_proto(request.channel_id);
3870
3871    let db = session.db().await;
3872    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3873        .await?;
3874    let messages = db
3875        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3876        .await?;
3877    response.send(proto::JoinChannelChatResponse {
3878        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3879        messages,
3880    })?;
3881    Ok(())
3882}
3883
3884/// Stop receiving chat updates for a channel
3885async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3886    let channel_id = ChannelId::from_proto(request.channel_id);
3887    session
3888        .db()
3889        .await
3890        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3891        .await?;
3892    Ok(())
3893}
3894
3895/// Retrieve the chat history for a channel
3896async fn get_channel_messages(
3897    request: proto::GetChannelMessages,
3898    response: Response<proto::GetChannelMessages>,
3899    session: Session,
3900) -> Result<()> {
3901    let channel_id = ChannelId::from_proto(request.channel_id);
3902    let messages = session
3903        .db()
3904        .await
3905        .get_channel_messages(
3906            channel_id,
3907            session.user_id(),
3908            MESSAGE_COUNT_PER_PAGE,
3909            Some(MessageId::from_proto(request.before_message_id)),
3910        )
3911        .await?;
3912    response.send(proto::GetChannelMessagesResponse {
3913        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914        messages,
3915    })?;
3916    Ok(())
3917}
3918
3919/// Retrieve specific chat messages
3920async fn get_channel_messages_by_id(
3921    request: proto::GetChannelMessagesById,
3922    response: Response<proto::GetChannelMessagesById>,
3923    session: Session,
3924) -> Result<()> {
3925    let message_ids = request
3926        .message_ids
3927        .iter()
3928        .map(|id| MessageId::from_proto(*id))
3929        .collect::<Vec<_>>();
3930    let messages = session
3931        .db()
3932        .await
3933        .get_channel_messages_by_id(session.user_id(), &message_ids)
3934        .await?;
3935    response.send(proto::GetChannelMessagesResponse {
3936        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3937        messages,
3938    })?;
3939    Ok(())
3940}
3941
3942/// Retrieve the current users notifications
3943async fn get_notifications(
3944    request: proto::GetNotifications,
3945    response: Response<proto::GetNotifications>,
3946    session: Session,
3947) -> Result<()> {
3948    let notifications = session
3949        .db()
3950        .await
3951        .get_notifications(
3952            session.user_id(),
3953            NOTIFICATION_COUNT_PER_PAGE,
3954            request.before_id.map(db::NotificationId::from_proto),
3955        )
3956        .await?;
3957    response.send(proto::GetNotificationsResponse {
3958        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3959        notifications,
3960    })?;
3961    Ok(())
3962}
3963
3964/// Mark notifications as read
3965async fn mark_notification_as_read(
3966    request: proto::MarkNotificationRead,
3967    response: Response<proto::MarkNotificationRead>,
3968    session: Session,
3969) -> Result<()> {
3970    let database = &session.db().await;
3971    let notifications = database
3972        .mark_notification_as_read_by_id(
3973            session.user_id(),
3974            NotificationId::from_proto(request.notification_id),
3975        )
3976        .await?;
3977    send_notifications(
3978        &*session.connection_pool().await,
3979        &session.peer,
3980        notifications,
3981    );
3982    response.send(proto::Ack {})?;
3983    Ok(())
3984}
3985
3986/// Get the current users information
3987async fn get_private_user_info(
3988    _request: proto::GetPrivateUserInfo,
3989    response: Response<proto::GetPrivateUserInfo>,
3990    session: Session,
3991) -> Result<()> {
3992    let db = session.db().await;
3993
3994    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3995    let user = db
3996        .get_user_by_id(session.user_id())
3997        .await?
3998        .ok_or_else(|| anyhow!("user not found"))?;
3999    let flags = db.get_user_flags(session.user_id()).await?;
4000
4001    response.send(proto::GetPrivateUserInfoResponse {
4002        metrics_id,
4003        staff: user.admin,
4004        flags,
4005        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4006    })?;
4007    Ok(())
4008}
4009
4010/// Accept the terms of service (tos) on behalf of the current user
4011async fn accept_terms_of_service(
4012    _request: proto::AcceptTermsOfService,
4013    response: Response<proto::AcceptTermsOfService>,
4014    session: Session,
4015) -> Result<()> {
4016    let db = session.db().await;
4017
4018    let accepted_tos_at = Utc::now();
4019    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4020        .await?;
4021
4022    response.send(proto::AcceptTermsOfServiceResponse {
4023        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4024    })?;
4025    Ok(())
4026}
4027
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures this against the earlier of the user's account
/// creation time and their GitHub account creation time, unless the user has
/// an LLM subscription or the `bypass-account-age-check` flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4030
4031async fn get_llm_api_token(
4032    _request: proto::GetLlmToken,
4033    response: Response<proto::GetLlmToken>,
4034    session: Session,
4035) -> Result<()> {
4036    let db = session.db().await;
4037
4038    let flags = db.get_user_flags(session.user_id()).await?;
4039    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4040    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4041    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4042
4043    if !session.is_staff() && !has_language_models_feature_flag {
4044        Err(anyhow!("permission denied"))?
4045    }
4046
4047    let user_id = session.user_id();
4048    let user = db
4049        .get_user_by_id(user_id)
4050        .await?
4051        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4052
4053    if user.accepted_tos_at.is_none() {
4054        Err(anyhow!("terms of service not accepted"))?
4055    }
4056
4057    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4058
4059    let bypass_account_age_check =
4060        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4061    if !bypass_account_age_check {
4062        let mut account_created_at = user.created_at;
4063        if let Some(github_created_at) = user.github_user_created_at {
4064            account_created_at = account_created_at.min(github_created_at);
4065        }
4066        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4067            Err(anyhow!("account too young"))?
4068        }
4069    }
4070
4071    let billing_preferences = db.get_billing_preferences(user.id).await?;
4072
4073    let token = LlmTokenClaims::create(
4074        &user,
4075        session.is_staff(),
4076        billing_preferences,
4077        has_llm_closed_beta_feature_flag,
4078        has_predict_edits_feature_flag,
4079        has_llm_subscription,
4080        session.current_plan(&db).await?,
4081        session.system_id.clone(),
4082        &session.app_state.config,
4083    )?;
4084    response.send(proto::GetLlmTokenResponse { token })?;
4085    Ok(())
4086}
4087
4088fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4089    let message = match message {
4090        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4091        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4092        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4093        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4094        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4095            code: frame.code.into(),
4096            reason: frame.reason,
4097        })),
4098        // We should never receive a frame while reading the message, according
4099        // to the `tungstenite` maintainers:
4100        //
4101        // > It cannot occur when you read messages from the WebSocket, but it
4102        // > can be used when you want to send the raw frames (e.g. you want to
4103        // > send the frames to the WebSocket without composing the full message first).
4104        // >
4105        // > — https://github.com/snapview/tungstenite-rs/issues/268
4106        TungsteniteMessage::Frame(_) => {
4107            bail!("received an unexpected frame while reading the message")
4108        }
4109    };
4110
4111    Ok(message)
4112}
4113
4114fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4115    match message {
4116        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4117        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4118        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4119        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4120        AxumMessage::Close(frame) => {
4121            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4122                code: frame.code.into(),
4123                reason: frame.reason,
4124            }))
4125        }
4126    }
4127}
4128
4129fn notify_membership_updated(
4130    connection_pool: &mut ConnectionPool,
4131    result: MembershipUpdated,
4132    user_id: UserId,
4133    peer: &Peer,
4134) {
4135    for membership in &result.new_channels.channel_memberships {
4136        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4137    }
4138    for channel_id in &result.removed_channels {
4139        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4140    }
4141
4142    let user_channels_update = proto::UpdateUserChannels {
4143        channel_memberships: result
4144            .new_channels
4145            .channel_memberships
4146            .iter()
4147            .map(|cm| proto::ChannelMembership {
4148                channel_id: cm.channel_id.to_proto(),
4149                role: cm.role.into(),
4150            })
4151            .collect(),
4152        ..Default::default()
4153    };
4154
4155    let mut update = build_channels_update(result.new_channels);
4156    update.delete_channels = result
4157        .removed_channels
4158        .into_iter()
4159        .map(|id| id.to_proto())
4160        .collect();
4161    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4162
4163    for connection_id in connection_pool.user_connection_ids(user_id) {
4164        peer.send(connection_id, user_channels_update.clone())
4165            .trace_err();
4166        peer.send(connection_id, update.clone()).trace_err();
4167    }
4168}
4169
4170fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4171    proto::UpdateUserChannels {
4172        channel_memberships: channels
4173            .channel_memberships
4174            .iter()
4175            .map(|m| proto::ChannelMembership {
4176                channel_id: m.channel_id.to_proto(),
4177                role: m.role.into(),
4178            })
4179            .collect(),
4180        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4181        observed_channel_message_id: channels.observed_channel_messages.clone(),
4182    }
4183}
4184
4185fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4186    let mut update = proto::UpdateChannels::default();
4187
4188    for channel in channels.channels {
4189        update.channels.push(channel.to_proto());
4190    }
4191
4192    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4193    update.latest_channel_message_ids = channels.latest_channel_messages;
4194
4195    for (channel_id, participants) in channels.channel_participants {
4196        update
4197            .channel_participants
4198            .push(proto::ChannelParticipants {
4199                channel_id: channel_id.to_proto(),
4200                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4201            });
4202    }
4203
4204    for channel in channels.invited_channels {
4205        update.channel_invitations.push(channel.to_proto());
4206    }
4207
4208    update
4209}
4210
4211fn build_initial_contacts_update(
4212    contacts: Vec<db::Contact>,
4213    pool: &ConnectionPool,
4214) -> proto::UpdateContacts {
4215    let mut update = proto::UpdateContacts::default();
4216
4217    for contact in contacts {
4218        match contact {
4219            db::Contact::Accepted { user_id, busy } => {
4220                update.contacts.push(contact_for_user(user_id, busy, pool));
4221            }
4222            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4223            db::Contact::Incoming { user_id } => {
4224                update
4225                    .incoming_requests
4226                    .push(proto::IncomingContactRequest {
4227                        requester_id: user_id.to_proto(),
4228                    })
4229            }
4230        }
4231    }
4232
4233    update
4234}
4235
4236fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4237    proto::Contact {
4238        user_id: user_id.to_proto(),
4239        online: pool.is_user_online(user_id),
4240        busy,
4241    }
4242}
4243
4244fn room_updated(room: &proto::Room, peer: &Peer) {
4245    broadcast(
4246        None,
4247        room.participants
4248            .iter()
4249            .filter_map(|participant| Some(participant.peer_id?.into())),
4250        |peer_id| {
4251            peer.send(
4252                peer_id,
4253                proto::RoomUpdated {
4254                    room: Some(room.clone()),
4255                },
4256            )
4257        },
4258    );
4259}
4260
4261fn channel_updated(
4262    channel: &db::channel::Model,
4263    room: &proto::Room,
4264    peer: &Peer,
4265    pool: &ConnectionPool,
4266) {
4267    let participants = room
4268        .participants
4269        .iter()
4270        .map(|p| p.user_id)
4271        .collect::<Vec<_>>();
4272
4273    broadcast(
4274        None,
4275        pool.channel_connection_ids(channel.root_id())
4276            .filter_map(|(channel_id, role)| {
4277                role.can_see_channel(channel.visibility)
4278                    .then_some(channel_id)
4279            }),
4280        |peer_id| {
4281            peer.send(
4282                peer_id,
4283                proto::UpdateChannels {
4284                    channel_participants: vec![proto::ChannelParticipants {
4285                        channel_id: channel.id.to_proto(),
4286                        participant_user_ids: participants.clone(),
4287                    }],
4288                    ..Default::default()
4289                },
4290            )
4291        },
4292    );
4293}
4294
4295async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4296    let db = session.db().await;
4297
4298    let contacts = db.get_contacts(user_id).await?;
4299    let busy = db.is_user_busy(user_id).await?;
4300
4301    let pool = session.connection_pool().await;
4302    let updated_contact = contact_for_user(user_id, busy, &pool);
4303    for contact in contacts {
4304        if let db::Contact::Accepted {
4305            user_id: contact_user_id,
4306            ..
4307        } = contact
4308        {
4309            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4310                session
4311                    .peer
4312                    .send(
4313                        contact_conn_id,
4314                        proto::UpdateContacts {
4315                            contacts: vec![updated_contact.clone()],
4316                            remove_contacts: Default::default(),
4317                            incoming_requests: Default::default(),
4318                            remove_incoming_requests: Default::default(),
4319                            outgoing_requests: Default::default(),
4320                            remove_outgoing_requests: Default::default(),
4321                        },
4322                    )
4323                    .trace_err();
4324            }
4325        }
4326    }
4327    Ok(())
4328}
4329
4330async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4331    let mut contacts_to_update = HashSet::default();
4332
4333    let room_id;
4334    let canceled_calls_to_user_ids;
4335    let livekit_room;
4336    let delete_livekit_room;
4337    let room;
4338    let channel;
4339
4340    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4341        contacts_to_update.insert(session.user_id());
4342
4343        for project in left_room.left_projects.values() {
4344            project_left(project, session);
4345        }
4346
4347        room_id = RoomId::from_proto(left_room.room.id);
4348        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4349        livekit_room = mem::take(&mut left_room.room.livekit_room);
4350        delete_livekit_room = left_room.deleted;
4351        room = mem::take(&mut left_room.room);
4352        channel = mem::take(&mut left_room.channel);
4353
4354        room_updated(&room, &session.peer);
4355    } else {
4356        return Ok(());
4357    }
4358
4359    if let Some(channel) = channel {
4360        channel_updated(
4361            &channel,
4362            &room,
4363            &session.peer,
4364            &*session.connection_pool().await,
4365        );
4366    }
4367
4368    {
4369        let pool = session.connection_pool().await;
4370        for canceled_user_id in canceled_calls_to_user_ids {
4371            for connection_id in pool.user_connection_ids(canceled_user_id) {
4372                session
4373                    .peer
4374                    .send(
4375                        connection_id,
4376                        proto::CallCanceled {
4377                            room_id: room_id.to_proto(),
4378                        },
4379                    )
4380                    .trace_err();
4381            }
4382            contacts_to_update.insert(canceled_user_id);
4383        }
4384    }
4385
4386    for contact_user_id in contacts_to_update {
4387        update_user_contacts(contact_user_id, session).await?;
4388    }
4389
4390    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4391        live_kit
4392            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4393            .await
4394            .trace_err();
4395
4396        if delete_livekit_room {
4397            live_kit.delete_room(livekit_room).await.trace_err();
4398        }
4399    }
4400
4401    Ok(())
4402}
4403
4404async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4405    let left_channel_buffers = session
4406        .db()
4407        .await
4408        .leave_channel_buffers(session.connection_id)
4409        .await?;
4410
4411    for left_buffer in left_channel_buffers {
4412        channel_buffer_updated(
4413            session.connection_id,
4414            left_buffer.connections,
4415            &proto::UpdateChannelBufferCollaborators {
4416                channel_id: left_buffer.channel_id.to_proto(),
4417                collaborators: left_buffer.collaborators,
4418            },
4419            &session.peer,
4420        );
4421    }
4422
4423    Ok(())
4424}
4425
4426fn project_left(project: &db::LeftProject, session: &Session) {
4427    for connection_id in &project.connection_ids {
4428        if project.should_unshare {
4429            session
4430                .peer
4431                .send(
4432                    *connection_id,
4433                    proto::UnshareProject {
4434                        project_id: project.id.to_proto(),
4435                    },
4436                )
4437                .trace_err();
4438        } else {
4439            session
4440                .peer
4441                .send(
4442                    *connection_id,
4443                    proto::RemoveProjectCollaborator {
4444                        project_id: project.id.to_proto(),
4445                        peer_id: Some(session.connection_id.into()),
4446                    },
4447                )
4448                .trace_err();
4449        }
4450    }
4451}
4452
/// Extension trait for turning a fallible value into an `Option`.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Consumes `self` and returns the success value, if any.
    fn trace_err(self) -> Option<Self::Ok>;
}
4458
4459impl<T, E> ResultExt for Result<T, E>
4460where
4461    E: std::fmt::Debug,
4462{
4463    type Ok = T;
4464
4465    #[track_caller]
4466    fn trace_err(self) -> Option<T> {
4467        match self {
4468            Ok(value) => Some(value),
4469            Err(error) => {
4470                tracing::error!("{:?}", error);
4471                None
4472            }
4473        }
4474    }
4475}