rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state passed to every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    /// Shared database handle behind an async mutex; lock it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool behind a synchronous lock; access it via
    /// `Session::connection_pool`, which returns a `!Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in this excerpt, hence `allow(unused)`;
    // confirm whether it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The collaboration RPC server: owns the peer, the connection pool, and the
/// table of per-message-type handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can change it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; `handle_connection` subscribes to it and refuses to
    /// serve a connection while it is `true`.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard for the connection pool.
///
/// `_not_send` makes this type `!Send`, so any future holding the guard across
/// an `.await` is itself `!Send`. Handler futures must be `Send`, which keeps
/// this synchronous `parking_lot` lock from being held across awaits in handlers.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable snapshot of the server's peer and its connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Holds the pool lock for as long as the snapshot lives. It is serialized
    /// as the value the guard dereferences to, because the guard type itself
    /// does not derive `Serialize`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
    /// Creates a server with the given `id` and registers a handler for every
    /// supported message type.
    ///
    /// # Panics
    ///
    /// Panics if the same message type is registered twice (enforced by
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped right away; `handle_connection`
            // obtains receivers via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Handlers that need configuration capture their own clone of
            // `app_state`, and clone it again per request so each returned
            // future owns its data.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 426
    /// Spawns a detached background task that cleans up resources left behind
    /// by stale servers, then returns `Ok(())` immediately.
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task:
    /// - asks the database for the stale room and channel-buffer ids for this
    ///   environment and `server_id`;
    /// - clears stale channel-buffer collaborators and notifies the remaining
    ///   connections;
    /// - clears stale room participants, broadcasts the room and channel
    ///   updates, and cancels pending calls;
    /// - refreshes contact status, and deletes the LiveKit room once it is empty;
    /// - finally deletes the stale server records.
    ///
    /// Every fallible step goes through `trace_err()` and is otherwise ignored,
    /// so one failure does not stop the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        // The cleanup future is moved into `spawn_detached`, so clone
        // everything it needs up front.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // For each stale channel buffer: clear stale collaborators,
                    // then send the refreshed collaborator list to the
                    // connections the database still reports.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // If clearing the room fails, these defaults mean no
                        // notifications are sent and no LiveKit room is deleted.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released
                        // before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's refreshed contact entry to
                        // every connection of their accepted contacts. The pool
                        // is locked only after both database awaits complete.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when the refreshed room
                        // has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 572
 573    pub fn teardown(&self) {
 574        self.peer.teardown();
 575        self.connection_pool.lock().reset();
 576        let _ = self.teardown.send(true);
 577    }
 578
 579    #[cfg(test)]
 580    pub fn reset(&self, id: ServerId) {
 581        self.teardown();
 582        *self.id.lock() = id;
 583        self.peer.reset(id.0 as u32);
 584        let _ = self.teardown.send(false);
 585    }
 586
    /// Test-only accessor for the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
 591
 592    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 593    where
 594        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 595        Fut: 'static + Send + Future<Output = Result<()>>,
 596        M: EnvelopedMessage,
 597    {
 598        let prev_handler = self.handlers.insert(
 599            TypeId::of::<M>(),
 600            Box::new(move |envelope, session| {
 601                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 602                let received_at = envelope.received_at;
 603                tracing::info!("message received");
 604                let start_time = Instant::now();
 605                let future = (handler)(*envelope, session);
 606                async move {
 607                    let result = future.await;
 608                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 609                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 610                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 611                    let payload_type = M::NAME;
 612
 613                    match result {
 614                        Err(error) => {
 615                            tracing::error!(
 616                                ?error,
 617                                total_duration_ms,
 618                                processing_duration_ms,
 619                                queue_duration_ms,
 620                                payload_type,
 621                                "error handling message"
 622                            )
 623                        }
 624                        Ok(()) => tracing::info!(
 625                            total_duration_ms,
 626                            processing_duration_ms,
 627                            queue_duration_ms,
 628                            "finished handling message"
 629                        ),
 630                    }
 631                }
 632                .boxed()
 633            }),
 634        );
 635        if prev_handler.is_some() {
 636            panic!("registered a handler for the same message twice");
 637        }
 638        self
 639    }
 640
 641    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 642    where
 643        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 644        Fut: 'static + Send + Future<Output = Result<()>>,
 645        M: EnvelopedMessage,
 646    {
 647        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 648        self
 649    }
 650
 651    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 652    where
 653        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 654        Fut: Send + Future<Output = Result<()>>,
 655        M: RequestMessage,
 656    {
 657        let handler = Arc::new(handler);
 658        self.add_handler(move |envelope, session| {
 659            let receipt = envelope.receipt();
 660            let handler = handler.clone();
 661            async move {
 662                let peer = session.peer.clone();
 663                let responded = Arc::new(AtomicBool::default());
 664                let response = Response {
 665                    peer: peer.clone(),
 666                    responded: responded.clone(),
 667                    receipt,
 668                };
 669                match (handler)(envelope.payload, response, session).await {
 670                    Ok(()) => {
 671                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 672                            Ok(())
 673                        } else {
 674                            Err(anyhow!("handler did not send a response"))?
 675                        }
 676                    }
 677                    Err(error) => {
 678                        let proto_err = match &error {
 679                            Error::Internal(err) => err.to_proto(),
 680                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 681                        };
 682                        peer.respond_with_error(receipt, proto_err)?;
 683                        Err(error)
 684                    }
 685                }
 686            }
 687        })
 688    }
 689
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future is instrumented with a "handle connection" span and:
    /// 1. returns immediately if the server's teardown flag is already set;
    /// 2. registers `connection` with the peer and records its `connection_id`
    ///    on the span;
    /// 3. builds the per-connection HTTP client, an optional Supermaven admin
    ///    client (only when `supermaven_admin_api_key` is configured), and the
    ///    [`Session`];
    /// 4. sends the initial client update via `send_initial_client_update`,
    ///    which also forwards the new `ConnectionId` through
    ///    `send_connection_id` when one is provided;
    /// 5. dispatches incoming messages to the registered handlers until the
    ///    I/O future completes or the incoming stream ends;
    /// 6. runs `connection_lost` to clean up.
    ///
    /// If teardown is signalled while the message loop is running, the future
    /// returns without running `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in as they become known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve the connection once the server is tearing down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // Register the connection with the peer. `handle_io` must be polled to
            // drive the socket; `incoming_rx` yields the messages received on it.
            // The closure hands the peer a sleep function backed by `executor`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): this early return (like the one after the initial
                    // client update below) skips `connection_lost`, so nothing in this
                    // block undoes `add_connection` on this path — confirm intended.
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256: a permit is taken
            // before reading each message and released when its handler finishes
            // (or right away if no handler is registered for the message).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown first,
                // then I/O completion, then progress on pending foreground handlers,
                // and only then reading the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so it is held
                                // until that handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background handlers run detached from this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Any foreground handlers still pending are dropped (cancelled) here.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 839
 840    async fn send_initial_client_update(
 841        &self,
 842        connection_id: ConnectionId,
 843        principal: &Principal,
 844        zed_version: ZedVersion,
 845        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 846        session: &Session,
 847    ) -> Result<()> {
 848        self.peer.send(
 849            connection_id,
 850            proto::Hello {
 851                peer_id: Some(connection_id.into()),
 852            },
 853        )?;
 854        tracing::info!("sent hello message");
 855        if let Some(send_connection_id) = send_connection_id.take() {
 856            let _ = send_connection_id.send(connection_id);
 857        }
 858
 859        match principal {
 860            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 861                if !user.connected_once {
 862                    self.peer.send(connection_id, proto::ShowContacts {})?;
 863                    self.app_state
 864                        .db
 865                        .set_user_connected_once(user.id, true)
 866                        .await?;
 867                }
 868
 869                update_user_plan(user.id, session).await?;
 870
 871                let contacts = self.app_state.db.get_contacts(user.id).await?;
 872
 873                {
 874                    let mut pool = self.connection_pool.lock();
 875                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 876                    self.peer.send(
 877                        connection_id,
 878                        build_initial_contacts_update(contacts, &pool),
 879                    )?;
 880                }
 881
 882                if should_auto_subscribe_to_channels(zed_version) {
 883                    subscribe_user_to_channels(user.id, session).await?;
 884                }
 885
 886                if let Some(incoming_call) =
 887                    self.app_state.db.incoming_call_for_user(user.id).await?
 888                {
 889                    self.peer.send(connection_id, incoming_call)?;
 890                }
 891
 892                update_user_contacts(user.id, session).await?;
 893            }
 894        }
 895
 896        Ok(())
 897    }
 898
 899    pub async fn invite_code_redeemed(
 900        self: &Arc<Self>,
 901        inviter_id: UserId,
 902        invitee_id: UserId,
 903    ) -> Result<()> {
 904        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 905            if let Some(code) = &user.invite_code {
 906                let pool = self.connection_pool.lock();
 907                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 908                for connection_id in pool.user_connection_ids(inviter_id) {
 909                    self.peer.send(
 910                        connection_id,
 911                        proto::UpdateContacts {
 912                            contacts: vec![invitee_contact.clone()],
 913                            ..Default::default()
 914                        },
 915                    )?;
 916                    self.peer.send(
 917                        connection_id,
 918                        proto::UpdateInviteInfo {
 919                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 920                            count: user.invite_count as u32,
 921                        },
 922                    )?;
 923                }
 924            }
 925        }
 926        Ok(())
 927    }
 928
 929    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 930        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 931            if let Some(invite_code) = &user.invite_code {
 932                let pool = self.connection_pool.lock();
 933                for connection_id in pool.user_connection_ids(user_id) {
 934                    self.peer.send(
 935                        connection_id,
 936                        proto::UpdateInviteInfo {
 937                            url: format!(
 938                                "{}{}",
 939                                self.app_state.config.invite_link_prefix, invite_code
 940                            ),
 941                            count: user.invite_count as u32,
 942                        },
 943                    )?;
 944                }
 945            }
 946        }
 947        Ok(())
 948    }
 949
 950    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 951        let pool = self.connection_pool.lock();
 952        for connection_id in pool.user_connection_ids(user_id) {
 953            self.peer
 954                .send(connection_id, proto::RefreshLlmToken {})
 955                .trace_err();
 956        }
 957    }
 958
 959    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 960        ServerSnapshot {
 961            connection_pool: ConnectionPoolGuard {
 962                guard: self.connection_pool.lock(),
 963                _not_send: PhantomData,
 964            },
 965            peer: &self.peer,
 966        }
 967    }
 968}
 969
 970impl<'a> Deref for ConnectionPoolGuard<'a> {
 971    type Target = ConnectionPool;
 972
 973    fn deref(&self) -> &Self::Target {
 974        &self.guard
 975    }
 976}
 977
 978impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 979    fn deref_mut(&mut self) -> &mut Self::Target {
 980        &mut self.guard
 981    }
 982}
 983
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs `check_invariants` every time a guard is released,
    /// so an inconsistent pool is detected right after the code that mutated
    /// it. In non-test builds this does nothing beyond dropping the guard.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 990
 991fn broadcast<F>(
 992    sender_id: Option<ConnectionId>,
 993    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 994    mut f: F,
 995) where
 996    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 997{
 998    for receiver_id in receiver_ids {
 999        if Some(receiver_id) != sender_id {
1000            if let Err(error) = f(receiver_id) {
1001                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1002            }
1003        }
1004    }
1005}
1006
1007pub struct ProtocolVersion(u32);
1008
1009impl Header for ProtocolVersion {
1010    fn name() -> &'static HeaderName {
1011        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1012        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1013    }
1014
1015    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1016    where
1017        Self: Sized,
1018        I: Iterator<Item = &'i axum::http::HeaderValue>,
1019    {
1020        let version = values
1021            .next()
1022            .ok_or_else(axum::headers::Error::invalid)?
1023            .to_str()
1024            .map_err(|_| axum::headers::Error::invalid())?
1025            .parse()
1026            .map_err(|_| axum::headers::Error::invalid())?;
1027        Ok(Self(version))
1028    }
1029
1030    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1031        values.extend([self.0.to_string().parse().unwrap()]);
1032    }
1033}
1034
1035pub struct AppVersionHeader(SemanticVersion);
1036impl Header for AppVersionHeader {
1037    fn name() -> &'static HeaderName {
1038        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1039        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1040    }
1041
1042    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1043    where
1044        Self: Sized,
1045        I: Iterator<Item = &'i axum::http::HeaderValue>,
1046    {
1047        let version = values
1048            .next()
1049            .ok_or_else(axum::headers::Error::invalid)?
1050            .to_str()
1051            .map_err(|_| axum::headers::Error::invalid())?
1052            .parse()
1053            .map_err(|_| axum::headers::Error::invalid())?;
1054        Ok(Self(version))
1055    }
1056
1057    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1058        values.extend([self.0.to_string().parse().unwrap()]);
1059    }
1060}
1061
/// Builds the router for the RPC WebSocket endpoint and the metrics endpoint.
///
/// The order of the calls matters: axum's `Router::layer` only wraps routes
/// that already exist when it is called. `/rpc` is therefore served behind
/// `auth::validate_header`, with `AppState` available as an extension. `/metrics`
/// is added afterwards and is not wrapped by that middleware. The final
/// `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1073
1074#[allow(clippy::too_many_arguments)]
1075pub async fn handle_websocket_request(
1076    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1077    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1078    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1079    Extension(server): Extension<Arc<Server>>,
1080    Extension(principal): Extension<Principal>,
1081    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1082    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1083    ws: WebSocketUpgrade,
1084) -> axum::response::Response {
1085    if protocol_version != rpc::PROTOCOL_VERSION {
1086        return (
1087            StatusCode::UPGRADE_REQUIRED,
1088            "client must be upgraded".to_string(),
1089        )
1090            .into_response();
1091    }
1092
1093    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1094        return (
1095            StatusCode::UPGRADE_REQUIRED,
1096            "no version header found".to_string(),
1097        )
1098            .into_response();
1099    };
1100
1101    if !version.can_collaborate() {
1102        return (
1103            StatusCode::UPGRADE_REQUIRED,
1104            "client must be upgraded".to_string(),
1105        )
1106            .into_response();
1107    }
1108
1109    let socket_address = socket_address.to_string();
1110    ws.on_upgrade(move |socket| {
1111        let socket = socket
1112            .map_ok(to_tungstenite_message)
1113            .err_into()
1114            .with(|message| async move { to_axum_message(message) });
1115        let connection = Connection::new(Box::pin(socket));
1116        async move {
1117            server
1118                .handle_connection(
1119                    connection,
1120                    socket_address,
1121                    principal,
1122                    version,
1123                    country_code_header.map(|header| header.to_string()),
1124                    system_id_header.map(|header| header.to_string()),
1125                    None,
1126                    Executor::Production,
1127                )
1128                .await;
1129        }
1130    })
1131}
1132
1133pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1134    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1135    let connections_metric = CONNECTIONS_METRIC
1136        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1137
1138    let connections = server
1139        .connection_pool
1140        .lock()
1141        .connections()
1142        .filter(|connection| !connection.admin)
1143        .count();
1144    connections_metric.set(connections as _);
1145
1146    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1147    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1148        register_int_gauge!(
1149            "shared_projects",
1150            "number of open projects with one or more guests"
1151        )
1152        .unwrap()
1153    });
1154
1155    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1156    shared_projects_metric.set(shared_projects as _);
1157
1158    let encoder = prometheus::TextEncoder::new();
1159    let metric_families = prometheus::gather();
1160    let encoded_metrics = encoder
1161        .encode_to_string(&metric_families)
1162        .map_err(|err| anyhow!("{}", err))?;
1163    Ok(encoded_metrics)
1164}
1165
1166#[instrument(err, skip(executor))]
1167async fn connection_lost(
1168    session: Session,
1169    mut teardown: watch::Receiver<bool>,
1170    executor: Executor,
1171) -> Result<()> {
1172    session.peer.disconnect(session.connection_id);
1173    session
1174        .connection_pool()
1175        .await
1176        .remove_connection(session.connection_id)?;
1177
1178    session
1179        .db()
1180        .await
1181        .connection_lost(session.connection_id)
1182        .await
1183        .trace_err();
1184
1185    futures::select_biased! {
1186        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1187
1188            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1189            leave_room_for_session(&session, session.connection_id).await.trace_err();
1190            leave_channel_buffers_for_session(&session)
1191                .await
1192                .trace_err();
1193
1194            if !session
1195                .connection_pool()
1196                .await
1197                .is_user_online(session.user_id())
1198            {
1199                let db = session.db().await;
1200                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1201                    room_updated(&room, &session.peer);
1202                }
1203            }
1204
1205            update_user_contacts(session.user_id(), &session).await?;
1206        },
1207        _ = teardown.changed().fuse() => {}
1208    }
1209
1210    Ok(())
1211}
1212
1213/// Acknowledges a ping from a client, used to keep the connection alive.
1214async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1215    response.send(proto::Ack {})?;
1216    Ok(())
1217}
1218
1219/// Creates a new room for calling (outside of channels)
1220async fn create_room(
1221    _request: proto::CreateRoom,
1222    response: Response<proto::CreateRoom>,
1223    session: Session,
1224) -> Result<()> {
1225    let livekit_room = nanoid::nanoid!(30);
1226
1227    let live_kit_connection_info = util::maybe!(async {
1228        let live_kit = session.app_state.livekit_client.as_ref();
1229        let live_kit = live_kit?;
1230        let user_id = session.user_id().to_string();
1231
1232        let token = live_kit
1233            .room_token(&livekit_room, &user_id.to_string())
1234            .trace_err()?;
1235
1236        Some(proto::LiveKitConnectionInfo {
1237            server_url: live_kit.url().into(),
1238            token,
1239            can_publish: true,
1240        })
1241    })
1242    .await;
1243
1244    let room = session
1245        .db()
1246        .await
1247        .create_room(session.user_id(), session.connection_id, &livekit_room)
1248        .await?;
1249
1250    response.send(proto::CreateRoomResponse {
1251        room: Some(room.clone()),
1252        live_kit_connection_info,
1253    })?;
1254
1255    update_user_contacts(session.user_id(), &session).await?;
1256    Ok(())
1257}
1258
1259/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1260async fn join_room(
1261    request: proto::JoinRoom,
1262    response: Response<proto::JoinRoom>,
1263    session: Session,
1264) -> Result<()> {
1265    let room_id = RoomId::from_proto(request.id);
1266
1267    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1268
1269    if let Some(channel_id) = channel_id {
1270        return join_channel_internal(channel_id, Box::new(response), session).await;
1271    }
1272
1273    let joined_room = {
1274        let room = session
1275            .db()
1276            .await
1277            .join_room(room_id, session.user_id(), session.connection_id)
1278            .await?;
1279        room_updated(&room.room, &session.peer);
1280        room.into_inner()
1281    };
1282
1283    for connection_id in session
1284        .connection_pool()
1285        .await
1286        .user_connection_ids(session.user_id())
1287    {
1288        session
1289            .peer
1290            .send(
1291                connection_id,
1292                proto::CallCanceled {
1293                    room_id: room_id.to_proto(),
1294                },
1295            )
1296            .trace_err();
1297    }
1298
1299    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1300    {
1301        live_kit
1302            .room_token(
1303                &joined_room.room.livekit_room,
1304                &session.user_id().to_string(),
1305            )
1306            .trace_err()
1307            .map(|token| proto::LiveKitConnectionInfo {
1308                server_url: live_kit.url().into(),
1309                token,
1310                can_publish: true,
1311            })
1312    } else {
1313        None
1314    };
1315
1316    response.send(proto::JoinRoomResponse {
1317        room: Some(joined_room.room),
1318        channel_id: None,
1319        live_kit_connection_info,
1320    })?;
1321
1322    update_user_contacts(session.user_id(), &session).await?;
1323    Ok(())
1324}
1325
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-attaches the caller's current connection to the room through the
/// database, then:
/// - replies with the room and its `reshared_projects` / `rejoined_projects`
///   (presumably projects the caller hosts vs. projects it had joined as a
///   guest — confirm against the db layer);
/// - broadcasts the updated room;
/// - for each re-shared project, tells its collaborators about the caller's
///   new peer id and forwards the project's worktrees to them;
/// - resynchronizes the rejoined projects via `notify_rejoined_projects`;
/// - sends a channel update if the room belongs to a channel.
///
/// Finally the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // `rejoined_room` (the value returned by `rejoin_room`) is confined to this
    // scope and consumed by `into_inner()`; only `room` and `channel` escape.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Map the caller's old connection to its new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to everyone except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // `channel` is only present when the room belongs to a channel.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1417
1418fn notify_rejoined_projects(
1419    rejoined_projects: &mut Vec<RejoinedProject>,
1420    session: &Session,
1421) -> Result<()> {
1422    for project in rejoined_projects.iter() {
1423        for collaborator in &project.collaborators {
1424            session
1425                .peer
1426                .send(
1427                    collaborator.connection_id,
1428                    proto::UpdateProjectCollaborator {
1429                        project_id: project.id.to_proto(),
1430                        old_peer_id: Some(project.old_connection_id.into()),
1431                        new_peer_id: Some(session.connection_id.into()),
1432                    },
1433                )
1434                .trace_err();
1435        }
1436    }
1437
1438    for project in rejoined_projects {
1439        for worktree in mem::take(&mut project.worktrees) {
1440            // Stream this worktree's entries.
1441            let message = proto::UpdateWorktree {
1442                project_id: project.id.to_proto(),
1443                worktree_id: worktree.id,
1444                abs_path: worktree.abs_path.clone(),
1445                root_name: worktree.root_name,
1446                updated_entries: worktree.updated_entries,
1447                removed_entries: worktree.removed_entries,
1448                scan_id: worktree.scan_id,
1449                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1450                updated_repositories: worktree.updated_repositories,
1451                removed_repositories: worktree.removed_repositories,
1452            };
1453            for update in proto::split_worktree_update(message) {
1454                session.peer.send(session.connection_id, update.clone())?;
1455            }
1456
1457            // Stream this worktree's diagnostics.
1458            for summary in worktree.diagnostic_summaries {
1459                session.peer.send(
1460                    session.connection_id,
1461                    proto::UpdateDiagnosticSummary {
1462                        project_id: project.id.to_proto(),
1463                        worktree_id: worktree.id,
1464                        summary: Some(summary),
1465                    },
1466                )?;
1467            }
1468
1469            for settings_file in worktree.settings_files {
1470                session.peer.send(
1471                    session.connection_id,
1472                    proto::UpdateWorktreeSettings {
1473                        project_id: project.id.to_proto(),
1474                        worktree_id: worktree.id,
1475                        path: settings_file.path,
1476                        content: Some(settings_file.content),
1477                        kind: Some(settings_file.kind.to_proto().into()),
1478                    },
1479                )?;
1480            }
1481        }
1482
1483        for language_server in &project.language_servers {
1484            session.peer.send(
1485                session.connection_id,
1486                proto::UpdateLanguageServer {
1487                    project_id: project.id.to_proto(),
1488                    language_server_id: language_server.id,
1489                    variant: Some(
1490                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1491                            proto::LspDiskBasedDiagnosticsUpdated {},
1492                        ),
1493                    ),
1494                },
1495            )?;
1496        }
1497    }
1498    Ok(())
1499}
1500
1501/// leave room disconnects from the room.
1502async fn leave_room(
1503    _: proto::LeaveRoom,
1504    response: Response<proto::LeaveRoom>,
1505    session: Session,
1506) -> Result<()> {
1507    leave_room_for_session(&session, session.connection_id).await?;
1508    response.send(proto::Ack {})?;
1509    Ok(())
1510}
1511
1512/// Updates the permissions of someone else in the room.
1513async fn set_room_participant_role(
1514    request: proto::SetRoomParticipantRole,
1515    response: Response<proto::SetRoomParticipantRole>,
1516    session: Session,
1517) -> Result<()> {
1518    let user_id = UserId::from_proto(request.user_id);
1519    let role = ChannelRole::from(request.role());
1520
1521    let (livekit_room, can_publish) = {
1522        let room = session
1523            .db()
1524            .await
1525            .set_room_participant_role(
1526                session.user_id(),
1527                RoomId::from_proto(request.room_id),
1528                user_id,
1529                role,
1530            )
1531            .await?;
1532
1533        let livekit_room = room.livekit_room.clone();
1534        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1535        room_updated(&room, &session.peer);
1536        (livekit_room, can_publish)
1537    };
1538
1539    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1540        live_kit
1541            .update_participant(
1542                livekit_room.clone(),
1543                request.user_id.to_string(),
1544                livekit_server::proto::ParticipantPermission {
1545                    can_subscribe: true,
1546                    can_publish,
1547                    can_publish_data: can_publish,
1548                    hidden: false,
1549                    recorder: false,
1550                },
1551            )
1552            .await
1553            .trace_err();
1554    }
1555
1556    response.send(proto::Ack {})?;
1557    Ok(())
1558}
1559
1560/// Call someone else into the current room
1561async fn call(
1562    request: proto::Call,
1563    response: Response<proto::Call>,
1564    session: Session,
1565) -> Result<()> {
1566    let room_id = RoomId::from_proto(request.room_id);
1567    let calling_user_id = session.user_id();
1568    let calling_connection_id = session.connection_id;
1569    let called_user_id = UserId::from_proto(request.called_user_id);
1570    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1571    if !session
1572        .db()
1573        .await
1574        .has_contact(calling_user_id, called_user_id)
1575        .await?
1576    {
1577        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1578    }
1579
1580    let incoming_call = {
1581        let (room, incoming_call) = &mut *session
1582            .db()
1583            .await
1584            .call(
1585                room_id,
1586                calling_user_id,
1587                calling_connection_id,
1588                called_user_id,
1589                initial_project_id,
1590            )
1591            .await?;
1592        room_updated(room, &session.peer);
1593        mem::take(incoming_call)
1594    };
1595    update_user_contacts(called_user_id, &session).await?;
1596
1597    let mut calls = session
1598        .connection_pool()
1599        .await
1600        .user_connection_ids(called_user_id)
1601        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1602        .collect::<FuturesUnordered<_>>();
1603
1604    while let Some(call_response) = calls.next().await {
1605        match call_response.as_ref() {
1606            Ok(_) => {
1607                response.send(proto::Ack {})?;
1608                return Ok(());
1609            }
1610            Err(_) => {
1611                call_response.trace_err();
1612            }
1613        }
1614    }
1615
1616    {
1617        let room = session
1618            .db()
1619            .await
1620            .call_failed(room_id, called_user_id)
1621            .await?;
1622        room_updated(&room, &session.peer);
1623    }
1624    update_user_contacts(called_user_id, &session).await?;
1625
1626    Err(anyhow!("failed to ring user"))?
1627}
1628
1629/// Cancel an outgoing call.
1630async fn cancel_call(
1631    request: proto::CancelCall,
1632    response: Response<proto::CancelCall>,
1633    session: Session,
1634) -> Result<()> {
1635    let called_user_id = UserId::from_proto(request.called_user_id);
1636    let room_id = RoomId::from_proto(request.room_id);
1637    {
1638        let room = session
1639            .db()
1640            .await
1641            .cancel_call(room_id, session.connection_id, called_user_id)
1642            .await?;
1643        room_updated(&room, &session.peer);
1644    }
1645
1646    for connection_id in session
1647        .connection_pool()
1648        .await
1649        .user_connection_ids(called_user_id)
1650    {
1651        session
1652            .peer
1653            .send(
1654                connection_id,
1655                proto::CallCanceled {
1656                    room_id: room_id.to_proto(),
1657                },
1658            )
1659            .trace_err();
1660    }
1661    response.send(proto::Ack {})?;
1662
1663    update_user_contacts(called_user_id, &session).await?;
1664    Ok(())
1665}
1666
1667/// Decline an incoming call.
1668async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1669    let room_id = RoomId::from_proto(message.room_id);
1670    {
1671        let room = session
1672            .db()
1673            .await
1674            .decline_call(Some(room_id), session.user_id())
1675            .await?
1676            .ok_or_else(|| anyhow!("failed to decline call"))?;
1677        room_updated(&room, &session.peer);
1678    }
1679
1680    for connection_id in session
1681        .connection_pool()
1682        .await
1683        .user_connection_ids(session.user_id())
1684    {
1685        session
1686            .peer
1687            .send(
1688                connection_id,
1689                proto::CallCanceled {
1690                    room_id: room_id.to_proto(),
1691                },
1692            )
1693            .trace_err();
1694    }
1695    update_user_contacts(session.user_id(), &session).await?;
1696    Ok(())
1697}
1698
1699/// Updates other participants in the room with your current location.
1700async fn update_participant_location(
1701    request: proto::UpdateParticipantLocation,
1702    response: Response<proto::UpdateParticipantLocation>,
1703    session: Session,
1704) -> Result<()> {
1705    let room_id = RoomId::from_proto(request.room_id);
1706    let location = request
1707        .location
1708        .ok_or_else(|| anyhow!("invalid location"))?;
1709
1710    let db = session.db().await;
1711    let room = db
1712        .update_room_participant_location(room_id, session.connection_id, location)
1713        .await?;
1714
1715    room_updated(&room, &session.peer);
1716    response.send(proto::Ack {})?;
1717    Ok(())
1718}
1719
1720/// Share a project into the room.
1721async fn share_project(
1722    request: proto::ShareProject,
1723    response: Response<proto::ShareProject>,
1724    session: Session,
1725) -> Result<()> {
1726    let (project_id, room) = &*session
1727        .db()
1728        .await
1729        .share_project(
1730            RoomId::from_proto(request.room_id),
1731            session.connection_id,
1732            &request.worktrees,
1733            request.is_ssh_project,
1734        )
1735        .await?;
1736    response.send(proto::ShareProjectResponse {
1737        project_id: project_id.to_proto(),
1738    })?;
1739    room_updated(room, &session.peer);
1740
1741    Ok(())
1742}
1743
1744/// Unshare a project from the room.
1745async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1746    let project_id = ProjectId::from_proto(message.project_id);
1747    unshare_project_internal(project_id, session.connection_id, &session).await
1748}
1749
1750async fn unshare_project_internal(
1751    project_id: ProjectId,
1752    connection_id: ConnectionId,
1753    session: &Session,
1754) -> Result<()> {
1755    let delete = {
1756        let room_guard = session
1757            .db()
1758            .await
1759            .unshare_project(project_id, connection_id)
1760            .await?;
1761
1762        let (delete, room, guest_connection_ids) = &*room_guard;
1763
1764        let message = proto::UnshareProject {
1765            project_id: project_id.to_proto(),
1766        };
1767
1768        broadcast(
1769            Some(connection_id),
1770            guest_connection_ids.iter().copied(),
1771            |conn_id| session.peer.send(conn_id, message.clone()),
1772        );
1773        if let Some(room) = room {
1774            room_updated(room, &session.peer);
1775        }
1776
1777        *delete
1778    };
1779
1780    if delete {
1781        let db = session.db().await;
1782        db.delete_project(project_id).await?;
1783    }
1784
1785    Ok(())
1786}
1787
/// Joins a project that someone else has shared.
///
/// Adds the calling connection to the project in the database, drops the
/// `session.db()` guard, and hands the rest of the join (notifying existing
/// collaborators, replying, streaming worktree state) to the synchronous
/// `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow the guard returned by `join_project`,
    // not `db`, which is why `db` can be dropped before they are used.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the `session.db()` guard before streaming the project state.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1806
1807trait JoinProjectInternalResponse {
1808    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1809}
1810impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1811    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1812        Response::<proto::JoinProject>::send(self, result)
1813    }
1814}
1815
1816fn join_project_internal(
1817    response: impl JoinProjectInternalResponse,
1818    session: Session,
1819    project: &mut Project,
1820    replica_id: &ReplicaId,
1821) -> Result<()> {
1822    let collaborators = project
1823        .collaborators
1824        .iter()
1825        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1826        .map(|collaborator| collaborator.to_proto())
1827        .collect::<Vec<_>>();
1828    let project_id = project.id;
1829    let guest_user_id = session.user_id();
1830
1831    let worktrees = project
1832        .worktrees
1833        .iter()
1834        .map(|(id, worktree)| proto::WorktreeMetadata {
1835            id: *id,
1836            root_name: worktree.root_name.clone(),
1837            visible: worktree.visible,
1838            abs_path: worktree.abs_path.clone(),
1839        })
1840        .collect::<Vec<_>>();
1841
1842    let add_project_collaborator = proto::AddProjectCollaborator {
1843        project_id: project_id.to_proto(),
1844        collaborator: Some(proto::Collaborator {
1845            peer_id: Some(session.connection_id.into()),
1846            replica_id: replica_id.0 as u32,
1847            user_id: guest_user_id.to_proto(),
1848            is_host: false,
1849        }),
1850    };
1851
1852    for collaborator in &collaborators {
1853        session
1854            .peer
1855            .send(
1856                collaborator.peer_id.unwrap().into(),
1857                add_project_collaborator.clone(),
1858            )
1859            .trace_err();
1860    }
1861
1862    // First, we send the metadata associated with each worktree.
1863    response.send(proto::JoinProjectResponse {
1864        project_id: project.id.0 as u64,
1865        worktrees: worktrees.clone(),
1866        replica_id: replica_id.0 as u32,
1867        collaborators: collaborators.clone(),
1868        language_servers: project.language_servers.clone(),
1869        role: project.role.into(),
1870    })?;
1871
1872    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1873        // Stream this worktree's entries.
1874        let message = proto::UpdateWorktree {
1875            project_id: project_id.to_proto(),
1876            worktree_id,
1877            abs_path: worktree.abs_path.clone(),
1878            root_name: worktree.root_name,
1879            updated_entries: worktree.entries,
1880            removed_entries: Default::default(),
1881            scan_id: worktree.scan_id,
1882            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1883            updated_repositories: worktree.repository_entries.into_values().collect(),
1884            removed_repositories: Default::default(),
1885        };
1886        for update in proto::split_worktree_update(message) {
1887            session.peer.send(session.connection_id, update.clone())?;
1888        }
1889
1890        // Stream this worktree's diagnostics.
1891        for summary in worktree.diagnostic_summaries {
1892            session.peer.send(
1893                session.connection_id,
1894                proto::UpdateDiagnosticSummary {
1895                    project_id: project_id.to_proto(),
1896                    worktree_id: worktree.id,
1897                    summary: Some(summary),
1898                },
1899            )?;
1900        }
1901
1902        for settings_file in worktree.settings_files {
1903            session.peer.send(
1904                session.connection_id,
1905                proto::UpdateWorktreeSettings {
1906                    project_id: project_id.to_proto(),
1907                    worktree_id: worktree.id,
1908                    path: settings_file.path,
1909                    content: Some(settings_file.content),
1910                    kind: Some(settings_file.kind.to_proto() as i32),
1911                },
1912            )?;
1913        }
1914    }
1915
1916    for language_server in &project.language_servers {
1917        session.peer.send(
1918            session.connection_id,
1919            proto::UpdateLanguageServer {
1920                project_id: project_id.to_proto(),
1921                language_server_id: language_server.id,
1922                variant: Some(
1923                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1924                        proto::LspDiskBasedDiagnosticsUpdated {},
1925                    ),
1926                ),
1927            },
1928        )?;
1929    }
1930
1931    Ok(())
1932}
1933
1934/// Leave someone elses shared project.
1935async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1936    let sender_id = session.connection_id;
1937    let project_id = ProjectId::from_proto(request.project_id);
1938    let db = session.db().await;
1939
1940    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1941    tracing::info!(
1942        %project_id,
1943        "leave project"
1944    );
1945
1946    project_left(project, &session);
1947    if let Some(room) = room {
1948        room_updated(room, &session.peer);
1949    }
1950
1951    Ok(())
1952}
1953
1954/// Updates other participants with changes to the project
1955async fn update_project(
1956    request: proto::UpdateProject,
1957    response: Response<proto::UpdateProject>,
1958    session: Session,
1959) -> Result<()> {
1960    let project_id = ProjectId::from_proto(request.project_id);
1961    let (room, guest_connection_ids) = &*session
1962        .db()
1963        .await
1964        .update_project(project_id, session.connection_id, &request.worktrees)
1965        .await?;
1966    broadcast(
1967        Some(session.connection_id),
1968        guest_connection_ids.iter().copied(),
1969        |connection_id| {
1970            session
1971                .peer
1972                .forward_send(session.connection_id, connection_id, request.clone())
1973        },
1974    );
1975    if let Some(room) = room {
1976        room_updated(room, &session.peer);
1977    }
1978    response.send(proto::Ack {})?;
1979
1980    Ok(())
1981}
1982
1983/// Updates other participants with changes to the worktree
1984async fn update_worktree(
1985    request: proto::UpdateWorktree,
1986    response: Response<proto::UpdateWorktree>,
1987    session: Session,
1988) -> Result<()> {
1989    let guest_connection_ids = session
1990        .db()
1991        .await
1992        .update_worktree(&request, session.connection_id)
1993        .await?;
1994
1995    broadcast(
1996        Some(session.connection_id),
1997        guest_connection_ids.iter().copied(),
1998        |connection_id| {
1999            session
2000                .peer
2001                .forward_send(session.connection_id, connection_id, request.clone())
2002        },
2003    );
2004    response.send(proto::Ack {})?;
2005    Ok(())
2006}
2007
2008/// Updates other participants with changes to the diagnostics
2009async fn update_diagnostic_summary(
2010    message: proto::UpdateDiagnosticSummary,
2011    session: Session,
2012) -> Result<()> {
2013    let guest_connection_ids = session
2014        .db()
2015        .await
2016        .update_diagnostic_summary(&message, session.connection_id)
2017        .await?;
2018
2019    broadcast(
2020        Some(session.connection_id),
2021        guest_connection_ids.iter().copied(),
2022        |connection_id| {
2023            session
2024                .peer
2025                .forward_send(session.connection_id, connection_id, message.clone())
2026        },
2027    );
2028
2029    Ok(())
2030}
2031
2032/// Updates other participants with changes to the worktree settings
2033async fn update_worktree_settings(
2034    message: proto::UpdateWorktreeSettings,
2035    session: Session,
2036) -> Result<()> {
2037    let guest_connection_ids = session
2038        .db()
2039        .await
2040        .update_worktree_settings(&message, session.connection_id)
2041        .await?;
2042
2043    broadcast(
2044        Some(session.connection_id),
2045        guest_connection_ids.iter().copied(),
2046        |connection_id| {
2047            session
2048                .peer
2049                .forward_send(session.connection_id, connection_id, message.clone())
2050        },
2051    );
2052
2053    Ok(())
2054}
2055
2056/// Notify other participants that a  language server has started.
2057async fn start_language_server(
2058    request: proto::StartLanguageServer,
2059    session: Session,
2060) -> Result<()> {
2061    let guest_connection_ids = session
2062        .db()
2063        .await
2064        .start_language_server(&request, session.connection_id)
2065        .await?;
2066
2067    broadcast(
2068        Some(session.connection_id),
2069        guest_connection_ids.iter().copied(),
2070        |connection_id| {
2071            session
2072                .peer
2073                .forward_send(session.connection_id, connection_id, request.clone())
2074        },
2075    );
2076    Ok(())
2077}
2078
2079/// Notify other participants that a language server has changed.
2080async fn update_language_server(
2081    request: proto::UpdateLanguageServer,
2082    session: Session,
2083) -> Result<()> {
2084    let project_id = ProjectId::from_proto(request.project_id);
2085    let project_connection_ids = session
2086        .db()
2087        .await
2088        .project_connection_ids(project_id, session.connection_id, true)
2089        .await?;
2090    broadcast(
2091        Some(session.connection_id),
2092        project_connection_ids.iter().copied(),
2093        |connection_id| {
2094            session
2095                .peer
2096                .forward_send(session.connection_id, connection_id, request.clone())
2097        },
2098    );
2099    Ok(())
2100}
2101
2102/// forward a project request to the host. These requests should be read only
2103/// as guests are allowed to send them.
2104async fn forward_read_only_project_request<T>(
2105    request: T,
2106    response: Response<T>,
2107    session: Session,
2108) -> Result<()>
2109where
2110    T: EntityMessage + RequestMessage,
2111{
2112    let project_id = ProjectId::from_proto(request.remote_entity_id());
2113    let host_connection_id = session
2114        .db()
2115        .await
2116        .host_for_read_only_project_request(project_id, session.connection_id)
2117        .await?;
2118    let payload = session
2119        .peer
2120        .forward_request(session.connection_id, host_connection_id, request)
2121        .await?;
2122    response.send(payload)?;
2123    Ok(())
2124}
2125
2126async fn forward_find_search_candidates_request(
2127    request: proto::FindSearchCandidates,
2128    response: Response<proto::FindSearchCandidates>,
2129    session: Session,
2130) -> Result<()> {
2131    let project_id = ProjectId::from_proto(request.remote_entity_id());
2132    let host_connection_id = session
2133        .db()
2134        .await
2135        .host_for_read_only_project_request(project_id, session.connection_id)
2136        .await?;
2137    let payload = session
2138        .peer
2139        .forward_request(session.connection_id, host_connection_id, request)
2140        .await?;
2141    response.send(payload)?;
2142    Ok(())
2143}
2144
2145/// forward a project request to the host. These requests are disallowed
2146/// for guests.
2147async fn forward_mutating_project_request<T>(
2148    request: T,
2149    response: Response<T>,
2150    session: Session,
2151) -> Result<()>
2152where
2153    T: EntityMessage + RequestMessage,
2154{
2155    let project_id = ProjectId::from_proto(request.remote_entity_id());
2156
2157    let host_connection_id = session
2158        .db()
2159        .await
2160        .host_for_mutating_project_request(project_id, session.connection_id)
2161        .await?;
2162    let payload = session
2163        .peer
2164        .forward_request(session.connection_id, host_connection_id, request)
2165        .await?;
2166    response.send(payload)?;
2167    Ok(())
2168}
2169
2170/// Notify other participants that a new buffer has been created
2171async fn create_buffer_for_peer(
2172    request: proto::CreateBufferForPeer,
2173    session: Session,
2174) -> Result<()> {
2175    session
2176        .db()
2177        .await
2178        .check_user_is_project_host(
2179            ProjectId::from_proto(request.project_id),
2180            session.connection_id,
2181        )
2182        .await?;
2183    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2184    session
2185        .peer
2186        .forward_send(session.connection_id, peer_id.into(), request)?;
2187    Ok(())
2188}
2189
2190/// Notify other participants that a buffer has been updated. This is
2191/// allowed for guests as long as the update is limited to selections.
2192async fn update_buffer(
2193    request: proto::UpdateBuffer,
2194    response: Response<proto::UpdateBuffer>,
2195    session: Session,
2196) -> Result<()> {
2197    let project_id = ProjectId::from_proto(request.project_id);
2198    let mut capability = Capability::ReadOnly;
2199
2200    for op in request.operations.iter() {
2201        match op.variant {
2202            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2203            Some(_) => capability = Capability::ReadWrite,
2204        }
2205    }
2206
2207    let host = {
2208        let guard = session
2209            .db()
2210            .await
2211            .connections_for_buffer_update(project_id, session.connection_id, capability)
2212            .await?;
2213
2214        let (host, guests) = &*guard;
2215
2216        broadcast(
2217            Some(session.connection_id),
2218            guests.clone(),
2219            |connection_id| {
2220                session
2221                    .peer
2222                    .forward_send(session.connection_id, connection_id, request.clone())
2223            },
2224        );
2225
2226        *host
2227    };
2228
2229    if host != session.connection_id {
2230        session
2231            .peer
2232            .forward_request(session.connection_id, host, request.clone())
2233            .await?;
2234    }
2235
2236    response.send(proto::Ack {})?;
2237    Ok(())
2238}
2239
2240async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2241    let project_id = ProjectId::from_proto(message.project_id);
2242
2243    let operation = message.operation.as_ref().context("invalid operation")?;
2244    let capability = match operation.variant.as_ref() {
2245        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2246            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2247                match buffer_op.variant {
2248                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2249                        Capability::ReadOnly
2250                    }
2251                    _ => Capability::ReadWrite,
2252                }
2253            } else {
2254                Capability::ReadWrite
2255            }
2256        }
2257        Some(_) => Capability::ReadWrite,
2258        None => Capability::ReadOnly,
2259    };
2260
2261    let guard = session
2262        .db()
2263        .await
2264        .connections_for_buffer_update(project_id, session.connection_id, capability)
2265        .await?;
2266
2267    let (host, guests) = &*guard;
2268
2269    broadcast(
2270        Some(session.connection_id),
2271        guests.iter().chain([host]).copied(),
2272        |connection_id| {
2273            session
2274                .peer
2275                .forward_send(session.connection_id, connection_id, message.clone())
2276        },
2277    );
2278
2279    Ok(())
2280}
2281
2282/// Notify other participants that a project has been updated.
2283async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2284    request: T,
2285    session: Session,
2286) -> Result<()> {
2287    let project_id = ProjectId::from_proto(request.remote_entity_id());
2288    let project_connection_ids = session
2289        .db()
2290        .await
2291        .project_connection_ids(project_id, session.connection_id, false)
2292        .await?;
2293
2294    broadcast(
2295        Some(session.connection_id),
2296        project_connection_ids.iter().copied(),
2297        |connection_id| {
2298            session
2299                .peer
2300                .forward_send(session.connection_id, connection_id, request.clone())
2301        },
2302    );
2303    Ok(())
2304}
2305
2306/// Start following another user in a call.
2307async fn follow(
2308    request: proto::Follow,
2309    response: Response<proto::Follow>,
2310    session: Session,
2311) -> Result<()> {
2312    let room_id = RoomId::from_proto(request.room_id);
2313    let project_id = request.project_id.map(ProjectId::from_proto);
2314    let leader_id = request
2315        .leader_id
2316        .ok_or_else(|| anyhow!("invalid leader id"))?
2317        .into();
2318    let follower_id = session.connection_id;
2319
2320    session
2321        .db()
2322        .await
2323        .check_room_participants(room_id, leader_id, session.connection_id)
2324        .await?;
2325
2326    let response_payload = session
2327        .peer
2328        .forward_request(session.connection_id, leader_id, request)
2329        .await?;
2330    response.send(response_payload)?;
2331
2332    if let Some(project_id) = project_id {
2333        let room = session
2334            .db()
2335            .await
2336            .follow(room_id, project_id, leader_id, follower_id)
2337            .await?;
2338        room_updated(&room, &session.peer);
2339    }
2340
2341    Ok(())
2342}
2343
2344/// Stop following another user in a call.
2345async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2346    let room_id = RoomId::from_proto(request.room_id);
2347    let project_id = request.project_id.map(ProjectId::from_proto);
2348    let leader_id = request
2349        .leader_id
2350        .ok_or_else(|| anyhow!("invalid leader id"))?
2351        .into();
2352    let follower_id = session.connection_id;
2353
2354    session
2355        .db()
2356        .await
2357        .check_room_participants(room_id, leader_id, session.connection_id)
2358        .await?;
2359
2360    session
2361        .peer
2362        .forward_send(session.connection_id, leader_id, request)?;
2363
2364    if let Some(project_id) = project_id {
2365        let room = session
2366            .db()
2367            .await
2368            .unfollow(room_id, project_id, leader_id, follower_id)
2369            .await?;
2370        room_updated(&room, &session.peer);
2371    }
2372
2373    Ok(())
2374}
2375
2376/// Notify everyone following you of your current location.
2377async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2378    let room_id = RoomId::from_proto(request.room_id);
2379    let database = session.db.lock().await;
2380
2381    let connection_ids = if let Some(project_id) = request.project_id {
2382        let project_id = ProjectId::from_proto(project_id);
2383        database
2384            .project_connection_ids(project_id, session.connection_id, true)
2385            .await?
2386    } else {
2387        database
2388            .room_connection_ids(room_id, session.connection_id)
2389            .await?
2390    };
2391
2392    // For now, don't send view update messages back to that view's current leader.
2393    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2394        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2395        _ => None,
2396    });
2397
2398    for connection_id in connection_ids.iter().cloned() {
2399        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2400            session
2401                .peer
2402                .forward_send(session.connection_id, connection_id, request.clone())?;
2403        }
2404    }
2405    Ok(())
2406}
2407
2408/// Get public data about users.
2409async fn get_users(
2410    request: proto::GetUsers,
2411    response: Response<proto::GetUsers>,
2412    session: Session,
2413) -> Result<()> {
2414    let user_ids = request
2415        .user_ids
2416        .into_iter()
2417        .map(UserId::from_proto)
2418        .collect();
2419    let users = session
2420        .db()
2421        .await
2422        .get_users_by_ids(user_ids)
2423        .await?
2424        .into_iter()
2425        .map(|user| proto::User {
2426            id: user.id.to_proto(),
2427            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2428            github_login: user.github_login,
2429            email: user.email_address,
2430            name: user.name,
2431        })
2432        .collect();
2433    response.send(proto::UsersResponse { users })?;
2434    Ok(())
2435}
2436
2437/// Search for users (to invite) buy Github login
2438async fn fuzzy_search_users(
2439    request: proto::FuzzySearchUsers,
2440    response: Response<proto::FuzzySearchUsers>,
2441    session: Session,
2442) -> Result<()> {
2443    let query = request.query;
2444    let users = match query.len() {
2445        0 => vec![],
2446        1 | 2 => session
2447            .db()
2448            .await
2449            .get_user_by_github_login(&query)
2450            .await?
2451            .into_iter()
2452            .collect(),
2453        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2454    };
2455    let users = users
2456        .into_iter()
2457        .filter(|user| user.id != session.user_id())
2458        .map(|user| proto::User {
2459            id: user.id.to_proto(),
2460            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2461            github_login: user.github_login,
2462            name: user.name,
2463            email: user.email_address,
2464        })
2465        .collect();
2466    response.send(proto::UsersResponse { users })?;
2467    Ok(())
2468}
2469
2470/// Send a contact request to another user.
2471async fn request_contact(
2472    request: proto::RequestContact,
2473    response: Response<proto::RequestContact>,
2474    session: Session,
2475) -> Result<()> {
2476    let requester_id = session.user_id();
2477    let responder_id = UserId::from_proto(request.responder_id);
2478    if requester_id == responder_id {
2479        return Err(anyhow!("cannot add yourself as a contact"))?;
2480    }
2481
2482    let notifications = session
2483        .db()
2484        .await
2485        .send_contact_request(requester_id, responder_id)
2486        .await?;
2487
2488    // Update outgoing contact requests of requester
2489    let mut update = proto::UpdateContacts::default();
2490    update.outgoing_requests.push(responder_id.to_proto());
2491    for connection_id in session
2492        .connection_pool()
2493        .await
2494        .user_connection_ids(requester_id)
2495    {
2496        session.peer.send(connection_id, update.clone())?;
2497    }
2498
2499    // Update incoming contact requests of responder
2500    let mut update = proto::UpdateContacts::default();
2501    update
2502        .incoming_requests
2503        .push(proto::IncomingContactRequest {
2504            requester_id: requester_id.to_proto(),
2505        });
2506    let connection_pool = session.connection_pool().await;
2507    for connection_id in connection_pool.user_connection_ids(responder_id) {
2508        session.peer.send(connection_id, update.clone())?;
2509    }
2510
2511    send_notifications(&connection_pool, &session.peer, notifications);
2512
2513    response.send(proto::Ack {})?;
2514    Ok(())
2515}
2516
2517/// Accept or decline a contact request
2518async fn respond_to_contact_request(
2519    request: proto::RespondToContactRequest,
2520    response: Response<proto::RespondToContactRequest>,
2521    session: Session,
2522) -> Result<()> {
2523    let responder_id = session.user_id();
2524    let requester_id = UserId::from_proto(request.requester_id);
2525    let db = session.db().await;
2526    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2527        db.dismiss_contact_notification(responder_id, requester_id)
2528            .await?;
2529    } else {
2530        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2531
2532        let notifications = db
2533            .respond_to_contact_request(responder_id, requester_id, accept)
2534            .await?;
2535        let requester_busy = db.is_user_busy(requester_id).await?;
2536        let responder_busy = db.is_user_busy(responder_id).await?;
2537
2538        let pool = session.connection_pool().await;
2539        // Update responder with new contact
2540        let mut update = proto::UpdateContacts::default();
2541        if accept {
2542            update
2543                .contacts
2544                .push(contact_for_user(requester_id, requester_busy, &pool));
2545        }
2546        update
2547            .remove_incoming_requests
2548            .push(requester_id.to_proto());
2549        for connection_id in pool.user_connection_ids(responder_id) {
2550            session.peer.send(connection_id, update.clone())?;
2551        }
2552
2553        // Update requester with new contact
2554        let mut update = proto::UpdateContacts::default();
2555        if accept {
2556            update
2557                .contacts
2558                .push(contact_for_user(responder_id, responder_busy, &pool));
2559        }
2560        update
2561            .remove_outgoing_requests
2562            .push(responder_id.to_proto());
2563
2564        for connection_id in pool.user_connection_ids(requester_id) {
2565            session.peer.send(connection_id, update.clone())?;
2566        }
2567
2568        send_notifications(&pool, &session.peer, notifications);
2569    }
2570
2571    response.send(proto::Ack {})?;
2572    Ok(())
2573}
2574
2575/// Remove a contact.
2576async fn remove_contact(
2577    request: proto::RemoveContact,
2578    response: Response<proto::RemoveContact>,
2579    session: Session,
2580) -> Result<()> {
2581    let requester_id = session.user_id();
2582    let responder_id = UserId::from_proto(request.user_id);
2583    let db = session.db().await;
2584    let (contact_accepted, deleted_notification_id) =
2585        db.remove_contact(requester_id, responder_id).await?;
2586
2587    let pool = session.connection_pool().await;
2588    // Update outgoing contact requests of requester
2589    let mut update = proto::UpdateContacts::default();
2590    if contact_accepted {
2591        update.remove_contacts.push(responder_id.to_proto());
2592    } else {
2593        update
2594            .remove_outgoing_requests
2595            .push(responder_id.to_proto());
2596    }
2597    for connection_id in pool.user_connection_ids(requester_id) {
2598        session.peer.send(connection_id, update.clone())?;
2599    }
2600
2601    // Update incoming contact requests of responder
2602    let mut update = proto::UpdateContacts::default();
2603    if contact_accepted {
2604        update.remove_contacts.push(requester_id.to_proto());
2605    } else {
2606        update
2607            .remove_incoming_requests
2608            .push(requester_id.to_proto());
2609    }
2610    for connection_id in pool.user_connection_ids(responder_id) {
2611        session.peer.send(connection_id, update.clone())?;
2612        if let Some(notification_id) = deleted_notification_id {
2613            session.peer.send(
2614                connection_id,
2615                proto::DeleteNotification {
2616                    notification_id: notification_id.to_proto(),
2617                },
2618            )?;
2619        }
2620    }
2621
2622    response.send(proto::Ack {})?;
2623    Ok(())
2624}
2625
2626fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2627    version.0.minor() < 139
2628}
2629
2630async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2631    let plan = session.current_plan(&session.db().await).await?;
2632
2633    session
2634        .peer
2635        .send(
2636            session.connection_id,
2637            proto::UpdateUserPlan { plan: plan.into() },
2638        )
2639        .trace_err();
2640
2641    Ok(())
2642}
2643
2644async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2645    subscribe_user_to_channels(session.user_id(), &session).await?;
2646    Ok(())
2647}
2648
2649async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2650    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2651    let mut pool = session.connection_pool().await;
2652    for membership in &channels_for_user.channel_memberships {
2653        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2654    }
2655    session.peer.send(
2656        session.connection_id,
2657        build_update_user_channels(&channels_for_user),
2658    )?;
2659    session.peer.send(
2660        session.connection_id,
2661        build_channels_update(channels_for_user),
2662    )?;
2663    Ok(())
2664}
2665
2666/// Creates a new channel.
2667async fn create_channel(
2668    request: proto::CreateChannel,
2669    response: Response<proto::CreateChannel>,
2670    session: Session,
2671) -> Result<()> {
2672    let db = session.db().await;
2673
2674    let parent_id = request.parent_id.map(ChannelId::from_proto);
2675    let (channel, membership) = db
2676        .create_channel(&request.name, parent_id, session.user_id())
2677        .await?;
2678
2679    let root_id = channel.root_id();
2680    let channel = Channel::from_model(channel);
2681
2682    response.send(proto::CreateChannelResponse {
2683        channel: Some(channel.to_proto()),
2684        parent_id: request.parent_id,
2685    })?;
2686
2687    let mut connection_pool = session.connection_pool().await;
2688    if let Some(membership) = membership {
2689        connection_pool.subscribe_to_channel(
2690            membership.user_id,
2691            membership.channel_id,
2692            membership.role,
2693        );
2694        let update = proto::UpdateUserChannels {
2695            channel_memberships: vec![proto::ChannelMembership {
2696                channel_id: membership.channel_id.to_proto(),
2697                role: membership.role.into(),
2698            }],
2699            ..Default::default()
2700        };
2701        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2702            session.peer.send(connection_id, update.clone())?;
2703        }
2704    }
2705
2706    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2707        if !role.can_see_channel(channel.visibility) {
2708            continue;
2709        }
2710
2711        let update = proto::UpdateChannels {
2712            channels: vec![channel.to_proto()],
2713            ..Default::default()
2714        };
2715        session.peer.send(connection_id, update.clone())?;
2716    }
2717
2718    Ok(())
2719}
2720
2721/// Delete a channel
2722async fn delete_channel(
2723    request: proto::DeleteChannel,
2724    response: Response<proto::DeleteChannel>,
2725    session: Session,
2726) -> Result<()> {
2727    let db = session.db().await;
2728
2729    let channel_id = request.channel_id;
2730    let (root_channel, removed_channels) = db
2731        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2732        .await?;
2733    response.send(proto::Ack {})?;
2734
2735    // Notify members of removed channels
2736    let mut update = proto::UpdateChannels::default();
2737    update
2738        .delete_channels
2739        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2740
2741    let connection_pool = session.connection_pool().await;
2742    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2743        session.peer.send(connection_id, update.clone())?;
2744    }
2745
2746    Ok(())
2747}
2748
2749/// Invite someone to join a channel.
2750async fn invite_channel_member(
2751    request: proto::InviteChannelMember,
2752    response: Response<proto::InviteChannelMember>,
2753    session: Session,
2754) -> Result<()> {
2755    let db = session.db().await;
2756    let channel_id = ChannelId::from_proto(request.channel_id);
2757    let invitee_id = UserId::from_proto(request.user_id);
2758    let InviteMemberResult {
2759        channel,
2760        notifications,
2761    } = db
2762        .invite_channel_member(
2763            channel_id,
2764            invitee_id,
2765            session.user_id(),
2766            request.role().into(),
2767        )
2768        .await?;
2769
2770    let update = proto::UpdateChannels {
2771        channel_invitations: vec![channel.to_proto()],
2772        ..Default::default()
2773    };
2774
2775    let connection_pool = session.connection_pool().await;
2776    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2777        session.peer.send(connection_id, update.clone())?;
2778    }
2779
2780    send_notifications(&connection_pool, &session.peer, notifications);
2781
2782    response.send(proto::Ack {})?;
2783    Ok(())
2784}
2785
2786/// remove someone from a channel
2787async fn remove_channel_member(
2788    request: proto::RemoveChannelMember,
2789    response: Response<proto::RemoveChannelMember>,
2790    session: Session,
2791) -> Result<()> {
2792    let db = session.db().await;
2793    let channel_id = ChannelId::from_proto(request.channel_id);
2794    let member_id = UserId::from_proto(request.user_id);
2795
2796    let RemoveChannelMemberResult {
2797        membership_update,
2798        notification_id,
2799    } = db
2800        .remove_channel_member(channel_id, member_id, session.user_id())
2801        .await?;
2802
2803    let mut connection_pool = session.connection_pool().await;
2804    notify_membership_updated(
2805        &mut connection_pool,
2806        membership_update,
2807        member_id,
2808        &session.peer,
2809    );
2810    for connection_id in connection_pool.user_connection_ids(member_id) {
2811        if let Some(notification_id) = notification_id {
2812            session
2813                .peer
2814                .send(
2815                    connection_id,
2816                    proto::DeleteNotification {
2817                        notification_id: notification_id.to_proto(),
2818                    },
2819                )
2820                .trace_err();
2821        }
2822    }
2823
2824    response.send(proto::Ack {})?;
2825    Ok(())
2826}
2827
2828/// Toggle the channel between public and private.
2829/// Care is taken to maintain the invariant that public channels only descend from public channels,
2830/// (though members-only channels can appear at any point in the hierarchy).
2831async fn set_channel_visibility(
2832    request: proto::SetChannelVisibility,
2833    response: Response<proto::SetChannelVisibility>,
2834    session: Session,
2835) -> Result<()> {
2836    let db = session.db().await;
2837    let channel_id = ChannelId::from_proto(request.channel_id);
2838    let visibility = request.visibility().into();
2839
2840    let channel_model = db
2841        .set_channel_visibility(channel_id, visibility, session.user_id())
2842        .await?;
2843    let root_id = channel_model.root_id();
2844    let channel = Channel::from_model(channel_model);
2845
2846    let mut connection_pool = session.connection_pool().await;
2847    for (user_id, role) in connection_pool
2848        .channel_user_ids(root_id)
2849        .collect::<Vec<_>>()
2850        .into_iter()
2851    {
2852        let update = if role.can_see_channel(channel.visibility) {
2853            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2854            proto::UpdateChannels {
2855                channels: vec![channel.to_proto()],
2856                ..Default::default()
2857            }
2858        } else {
2859            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2860            proto::UpdateChannels {
2861                delete_channels: vec![channel.id.to_proto()],
2862                ..Default::default()
2863            }
2864        };
2865
2866        for connection_id in connection_pool.user_connection_ids(user_id) {
2867            session.peer.send(connection_id, update.clone())?;
2868        }
2869    }
2870
2871    response.send(proto::Ack {})?;
2872    Ok(())
2873}
2874
2875/// Alter the role for a user in the channel.
2876async fn set_channel_member_role(
2877    request: proto::SetChannelMemberRole,
2878    response: Response<proto::SetChannelMemberRole>,
2879    session: Session,
2880) -> Result<()> {
2881    let db = session.db().await;
2882    let channel_id = ChannelId::from_proto(request.channel_id);
2883    let member_id = UserId::from_proto(request.user_id);
2884    let result = db
2885        .set_channel_member_role(
2886            channel_id,
2887            session.user_id(),
2888            member_id,
2889            request.role().into(),
2890        )
2891        .await?;
2892
2893    match result {
2894        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2895            let mut connection_pool = session.connection_pool().await;
2896            notify_membership_updated(
2897                &mut connection_pool,
2898                membership_update,
2899                member_id,
2900                &session.peer,
2901            )
2902        }
2903        db::SetMemberRoleResult::InviteUpdated(channel) => {
2904            let update = proto::UpdateChannels {
2905                channel_invitations: vec![channel.to_proto()],
2906                ..Default::default()
2907            };
2908
2909            for connection_id in session
2910                .connection_pool()
2911                .await
2912                .user_connection_ids(member_id)
2913            {
2914                session.peer.send(connection_id, update.clone())?;
2915            }
2916        }
2917    }
2918
2919    response.send(proto::Ack {})?;
2920    Ok(())
2921}
2922
2923/// Change the name of a channel
2924async fn rename_channel(
2925    request: proto::RenameChannel,
2926    response: Response<proto::RenameChannel>,
2927    session: Session,
2928) -> Result<()> {
2929    let db = session.db().await;
2930    let channel_id = ChannelId::from_proto(request.channel_id);
2931    let channel_model = db
2932        .rename_channel(channel_id, session.user_id(), &request.name)
2933        .await?;
2934    let root_id = channel_model.root_id();
2935    let channel = Channel::from_model(channel_model);
2936
2937    response.send(proto::RenameChannelResponse {
2938        channel: Some(channel.to_proto()),
2939    })?;
2940
2941    let connection_pool = session.connection_pool().await;
2942    let update = proto::UpdateChannels {
2943        channels: vec![channel.to_proto()],
2944        ..Default::default()
2945    };
2946    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2947        if role.can_see_channel(channel.visibility) {
2948            session.peer.send(connection_id, update.clone())?;
2949        }
2950    }
2951
2952    Ok(())
2953}
2954
2955/// Move a channel to a new parent.
2956async fn move_channel(
2957    request: proto::MoveChannel,
2958    response: Response<proto::MoveChannel>,
2959    session: Session,
2960) -> Result<()> {
2961    let channel_id = ChannelId::from_proto(request.channel_id);
2962    let to = ChannelId::from_proto(request.to);
2963
2964    let (root_id, channels) = session
2965        .db()
2966        .await
2967        .move_channel(channel_id, to, session.user_id())
2968        .await?;
2969
2970    let connection_pool = session.connection_pool().await;
2971    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2972        let channels = channels
2973            .iter()
2974            .filter_map(|channel| {
2975                if role.can_see_channel(channel.visibility) {
2976                    Some(channel.to_proto())
2977                } else {
2978                    None
2979                }
2980            })
2981            .collect::<Vec<_>>();
2982        if channels.is_empty() {
2983            continue;
2984        }
2985
2986        let update = proto::UpdateChannels {
2987            channels,
2988            ..Default::default()
2989        };
2990
2991        session.peer.send(connection_id, update.clone())?;
2992    }
2993
2994    response.send(Ack {})?;
2995    Ok(())
2996}
2997
2998/// Get the list of channel members
2999async fn get_channel_members(
3000    request: proto::GetChannelMembers,
3001    response: Response<proto::GetChannelMembers>,
3002    session: Session,
3003) -> Result<()> {
3004    let db = session.db().await;
3005    let channel_id = ChannelId::from_proto(request.channel_id);
3006    let limit = if request.limit == 0 {
3007        u16::MAX as u64
3008    } else {
3009        request.limit
3010    };
3011    let (members, users) = db
3012        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3013        .await?;
3014    response.send(proto::GetChannelMembersResponse { members, users })?;
3015    Ok(())
3016}
3017
3018/// Accept or decline a channel invitation.
3019async fn respond_to_channel_invite(
3020    request: proto::RespondToChannelInvite,
3021    response: Response<proto::RespondToChannelInvite>,
3022    session: Session,
3023) -> Result<()> {
3024    let db = session.db().await;
3025    let channel_id = ChannelId::from_proto(request.channel_id);
3026    let RespondToChannelInvite {
3027        membership_update,
3028        notifications,
3029    } = db
3030        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3031        .await?;
3032
3033    let mut connection_pool = session.connection_pool().await;
3034    if let Some(membership_update) = membership_update {
3035        notify_membership_updated(
3036            &mut connection_pool,
3037            membership_update,
3038            session.user_id(),
3039            &session.peer,
3040        );
3041    } else {
3042        let update = proto::UpdateChannels {
3043            remove_channel_invitations: vec![channel_id.to_proto()],
3044            ..Default::default()
3045        };
3046
3047        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3048            session.peer.send(connection_id, update.clone())?;
3049        }
3050    };
3051
3052    send_notifications(&connection_pool, &session.peer, notifications);
3053
3054    response.send(proto::Ack {})?;
3055
3056    Ok(())
3057}
3058
3059/// Join the channels' room
3060async fn join_channel(
3061    request: proto::JoinChannel,
3062    response: Response<proto::JoinChannel>,
3063    session: Session,
3064) -> Result<()> {
3065    let channel_id = ChannelId::from_proto(request.channel_id);
3066    join_channel_internal(channel_id, Box::new(response), session).await
3067}
3068
/// A response handle that can complete a channel join.
///
/// Implemented for both `JoinChannel` and `JoinRoom` responses. Both reply
/// with a `proto::JoinRoomResponse`, so `join_channel_internal` can serve
/// either request kind.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3082
/// Shared implementation for joining a channel's room.
///
/// It is called by `join_channel`. Given the `Response<proto::JoinRoom>` impl
/// of `JoinChannelInternalResponse`, it is presumably also reused by a
/// room-join handler elsewhere in this file.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, it is
///    cleaned up with `leave_room_for_session`. The database guard is dropped
///    around that call and re-acquired afterwards.
/// 2. The user joins the channel (and its room) in the database.
/// 3. If a LiveKit client is configured, credentials are minted. Guests get a
///    guest token with `can_publish: false`; every other role gets a room
///    token with `can_publish: true`.
/// 4. The `JoinRoomResponse` is sent. Any membership change is propagated
///    with `notify_membership_updated`, and `room_updated` is called for the
///    room.
/// 5. After the block's database and connection-pool guards are dropped,
///    `channel_updated` and `update_user_contacts` run.
///
/// # Errors
/// Returns an error if any of the following fails: a database call,
/// `leave_room_for_session`, sending the response, or `update_user_contacts`.
/// It also errors if the database did not return a channel for the joined
/// room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it (presumably `leave_room_for_session` needs it too).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // A token error passes through `trace_err` and makes this `None`, so
        // the join still succeeds, just without LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards taken inside the block above have been dropped by this point;
    // the connection pool is re-acquired only for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3178
3179/// Start editing the channel notes
3180async fn join_channel_buffer(
3181    request: proto::JoinChannelBuffer,
3182    response: Response<proto::JoinChannelBuffer>,
3183    session: Session,
3184) -> Result<()> {
3185    let db = session.db().await;
3186    let channel_id = ChannelId::from_proto(request.channel_id);
3187
3188    let open_response = db
3189        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3190        .await?;
3191
3192    let collaborators = open_response.collaborators.clone();
3193    response.send(open_response)?;
3194
3195    let update = UpdateChannelBufferCollaborators {
3196        channel_id: channel_id.to_proto(),
3197        collaborators: collaborators.clone(),
3198    };
3199    channel_buffer_updated(
3200        session.connection_id,
3201        collaborators
3202            .iter()
3203            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3204        &update,
3205        &session.peer,
3206    );
3207
3208    Ok(())
3209}
3210
3211/// Edit the channel notes
3212async fn update_channel_buffer(
3213    request: proto::UpdateChannelBuffer,
3214    session: Session,
3215) -> Result<()> {
3216    let db = session.db().await;
3217    let channel_id = ChannelId::from_proto(request.channel_id);
3218
3219    let (collaborators, epoch, version) = db
3220        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3221        .await?;
3222
3223    channel_buffer_updated(
3224        session.connection_id,
3225        collaborators.clone(),
3226        &proto::UpdateChannelBuffer {
3227            channel_id: channel_id.to_proto(),
3228            operations: request.operations,
3229        },
3230        &session.peer,
3231    );
3232
3233    let pool = &*session.connection_pool().await;
3234
3235    let non_collaborators =
3236        pool.channel_connection_ids(channel_id)
3237            .filter_map(|(connection_id, _)| {
3238                if collaborators.contains(&connection_id) {
3239                    None
3240                } else {
3241                    Some(connection_id)
3242                }
3243            });
3244
3245    broadcast(None, non_collaborators, |peer_id| {
3246        session.peer.send(
3247            peer_id,
3248            proto::UpdateChannels {
3249                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3250                    channel_id: channel_id.to_proto(),
3251                    epoch: epoch as u64,
3252                    version: version.clone(),
3253                }],
3254                ..Default::default()
3255            },
3256        )
3257    });
3258
3259    Ok(())
3260}
3261
3262/// Rejoin the channel notes after a connection blip
3263async fn rejoin_channel_buffers(
3264    request: proto::RejoinChannelBuffers,
3265    response: Response<proto::RejoinChannelBuffers>,
3266    session: Session,
3267) -> Result<()> {
3268    let db = session.db().await;
3269    let buffers = db
3270        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3271        .await?;
3272
3273    for rejoined_buffer in &buffers {
3274        let collaborators_to_notify = rejoined_buffer
3275            .buffer
3276            .collaborators
3277            .iter()
3278            .filter_map(|c| Some(c.peer_id?.into()));
3279        channel_buffer_updated(
3280            session.connection_id,
3281            collaborators_to_notify,
3282            &proto::UpdateChannelBufferCollaborators {
3283                channel_id: rejoined_buffer.buffer.channel_id,
3284                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3285            },
3286            &session.peer,
3287        );
3288    }
3289
3290    response.send(proto::RejoinChannelBuffersResponse {
3291        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3292    })?;
3293
3294    Ok(())
3295}
3296
3297/// Stop editing the channel notes
3298async fn leave_channel_buffer(
3299    request: proto::LeaveChannelBuffer,
3300    response: Response<proto::LeaveChannelBuffer>,
3301    session: Session,
3302) -> Result<()> {
3303    let db = session.db().await;
3304    let channel_id = ChannelId::from_proto(request.channel_id);
3305
3306    let left_buffer = db
3307        .leave_channel_buffer(channel_id, session.connection_id)
3308        .await?;
3309
3310    response.send(Ack {})?;
3311
3312    channel_buffer_updated(
3313        session.connection_id,
3314        left_buffer.connections,
3315        &proto::UpdateChannelBufferCollaborators {
3316            channel_id: channel_id.to_proto(),
3317            collaborators: left_buffer.collaborators,
3318        },
3319        &session.peer,
3320    );
3321
3322    Ok(())
3323}
3324
3325fn channel_buffer_updated<T: EnvelopedMessage>(
3326    sender_id: ConnectionId,
3327    collaborators: impl IntoIterator<Item = ConnectionId>,
3328    message: &T,
3329    peer: &Peer,
3330) {
3331    broadcast(Some(sender_id), collaborators, |peer_id| {
3332        peer.send(peer_id, message.clone())
3333    });
3334}
3335
3336fn send_notifications(
3337    connection_pool: &ConnectionPool,
3338    peer: &Peer,
3339    notifications: db::NotificationBatch,
3340) {
3341    for (user_id, notification) in notifications {
3342        for connection_id in connection_pool.user_connection_ids(user_id) {
3343            if let Err(error) = peer.send(
3344                connection_id,
3345                proto::AddNotification {
3346                    notification: Some(notification.clone()),
3347                },
3348            ) {
3349                tracing::error!(
3350                    "failed to send notification to {:?} {}",
3351                    connection_id,
3352                    error
3353                );
3354            }
3355        }
3356    }
3357}
3358
3359/// Send a message to the channel
3360async fn send_channel_message(
3361    request: proto::SendChannelMessage,
3362    response: Response<proto::SendChannelMessage>,
3363    session: Session,
3364) -> Result<()> {
3365    // Validate the message body.
3366    let body = request.body.trim().to_string();
3367    if body.len() > MAX_MESSAGE_LEN {
3368        return Err(anyhow!("message is too long"))?;
3369    }
3370    if body.is_empty() {
3371        return Err(anyhow!("message can't be blank"))?;
3372    }
3373
3374    // TODO: adjust mentions if body is trimmed
3375
3376    let timestamp = OffsetDateTime::now_utc();
3377    let nonce = request
3378        .nonce
3379        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3380
3381    let channel_id = ChannelId::from_proto(request.channel_id);
3382    let CreatedChannelMessage {
3383        message_id,
3384        participant_connection_ids,
3385        notifications,
3386    } = session
3387        .db()
3388        .await
3389        .create_channel_message(
3390            channel_id,
3391            session.user_id(),
3392            &body,
3393            &request.mentions,
3394            timestamp,
3395            nonce.clone().into(),
3396            request.reply_to_message_id.map(MessageId::from_proto),
3397        )
3398        .await?;
3399
3400    let message = proto::ChannelMessage {
3401        sender_id: session.user_id().to_proto(),
3402        id: message_id.to_proto(),
3403        body,
3404        mentions: request.mentions,
3405        timestamp: timestamp.unix_timestamp() as u64,
3406        nonce: Some(nonce),
3407        reply_to_message_id: request.reply_to_message_id,
3408        edited_at: None,
3409    };
3410    broadcast(
3411        Some(session.connection_id),
3412        participant_connection_ids.clone(),
3413        |connection| {
3414            session.peer.send(
3415                connection,
3416                proto::ChannelMessageSent {
3417                    channel_id: channel_id.to_proto(),
3418                    message: Some(message.clone()),
3419                },
3420            )
3421        },
3422    );
3423    response.send(proto::SendChannelMessageResponse {
3424        message: Some(message),
3425    })?;
3426
3427    let pool = &*session.connection_pool().await;
3428    let non_participants =
3429        pool.channel_connection_ids(channel_id)
3430            .filter_map(|(connection_id, _)| {
3431                if participant_connection_ids.contains(&connection_id) {
3432                    None
3433                } else {
3434                    Some(connection_id)
3435                }
3436            });
3437    broadcast(None, non_participants, |peer_id| {
3438        session.peer.send(
3439            peer_id,
3440            proto::UpdateChannels {
3441                latest_channel_message_ids: vec![proto::ChannelMessageId {
3442                    channel_id: channel_id.to_proto(),
3443                    message_id: message_id.to_proto(),
3444                }],
3445                ..Default::default()
3446            },
3447        )
3448    });
3449    send_notifications(pool, &session.peer, notifications);
3450
3451    Ok(())
3452}
3453
3454/// Delete a channel message
3455async fn remove_channel_message(
3456    request: proto::RemoveChannelMessage,
3457    response: Response<proto::RemoveChannelMessage>,
3458    session: Session,
3459) -> Result<()> {
3460    let channel_id = ChannelId::from_proto(request.channel_id);
3461    let message_id = MessageId::from_proto(request.message_id);
3462    let (connection_ids, existing_notification_ids) = session
3463        .db()
3464        .await
3465        .remove_channel_message(channel_id, message_id, session.user_id())
3466        .await?;
3467
3468    broadcast(
3469        Some(session.connection_id),
3470        connection_ids,
3471        move |connection| {
3472            session.peer.send(connection, request.clone())?;
3473
3474            for notification_id in &existing_notification_ids {
3475                session.peer.send(
3476                    connection,
3477                    proto::DeleteNotification {
3478                        notification_id: (*notification_id).to_proto(),
3479                    },
3480                )?;
3481            }
3482
3483            Ok(())
3484        },
3485    );
3486    response.send(proto::Ack {})?;
3487    Ok(())
3488}
3489
3490async fn update_channel_message(
3491    request: proto::UpdateChannelMessage,
3492    response: Response<proto::UpdateChannelMessage>,
3493    session: Session,
3494) -> Result<()> {
3495    let channel_id = ChannelId::from_proto(request.channel_id);
3496    let message_id = MessageId::from_proto(request.message_id);
3497    let updated_at = OffsetDateTime::now_utc();
3498    let UpdatedChannelMessage {
3499        message_id,
3500        participant_connection_ids,
3501        notifications,
3502        reply_to_message_id,
3503        timestamp,
3504        deleted_mention_notification_ids,
3505        updated_mention_notifications,
3506    } = session
3507        .db()
3508        .await
3509        .update_channel_message(
3510            channel_id,
3511            message_id,
3512            session.user_id(),
3513            request.body.as_str(),
3514            &request.mentions,
3515            updated_at,
3516        )
3517        .await?;
3518
3519    let nonce = request
3520        .nonce
3521        .clone()
3522        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3523
3524    let message = proto::ChannelMessage {
3525        sender_id: session.user_id().to_proto(),
3526        id: message_id.to_proto(),
3527        body: request.body.clone(),
3528        mentions: request.mentions.clone(),
3529        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3530        nonce: Some(nonce),
3531        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3532        edited_at: Some(updated_at.unix_timestamp() as u64),
3533    };
3534
3535    response.send(proto::Ack {})?;
3536
3537    let pool = &*session.connection_pool().await;
3538    broadcast(
3539        Some(session.connection_id),
3540        participant_connection_ids,
3541        |connection| {
3542            session.peer.send(
3543                connection,
3544                proto::ChannelMessageUpdate {
3545                    channel_id: channel_id.to_proto(),
3546                    message: Some(message.clone()),
3547                },
3548            )?;
3549
3550            for notification_id in &deleted_mention_notification_ids {
3551                session.peer.send(
3552                    connection,
3553                    proto::DeleteNotification {
3554                        notification_id: (*notification_id).to_proto(),
3555                    },
3556                )?;
3557            }
3558
3559            for notification in &updated_mention_notifications {
3560                session.peer.send(
3561                    connection,
3562                    proto::UpdateNotification {
3563                        notification: Some(notification.clone()),
3564                    },
3565                )?;
3566            }
3567
3568            Ok(())
3569        },
3570    );
3571
3572    send_notifications(pool, &session.peer, notifications);
3573
3574    Ok(())
3575}
3576
3577/// Mark a channel message as read
3578async fn acknowledge_channel_message(
3579    request: proto::AckChannelMessage,
3580    session: Session,
3581) -> Result<()> {
3582    let channel_id = ChannelId::from_proto(request.channel_id);
3583    let message_id = MessageId::from_proto(request.message_id);
3584    let notifications = session
3585        .db()
3586        .await
3587        .observe_channel_message(channel_id, session.user_id(), message_id)
3588        .await?;
3589    send_notifications(
3590        &*session.connection_pool().await,
3591        &session.peer,
3592        notifications,
3593    );
3594    Ok(())
3595}
3596
3597/// Mark a buffer version as synced
3598async fn acknowledge_buffer_version(
3599    request: proto::AckBufferOperation,
3600    session: Session,
3601) -> Result<()> {
3602    let buffer_id = BufferId::from_proto(request.buffer_id);
3603    session
3604        .db()
3605        .await
3606        .observe_buffer_version(
3607            buffer_id,
3608            session.user_id(),
3609            request.epoch as i32,
3610            &request.version,
3611        )
3612        .await?;
3613    Ok(())
3614}
3615
3616async fn count_language_model_tokens(
3617    request: proto::CountLanguageModelTokens,
3618    response: Response<proto::CountLanguageModelTokens>,
3619    session: Session,
3620    config: &Config,
3621) -> Result<()> {
3622    authorize_access_to_legacy_llm_endpoints(&session).await?;
3623
3624    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3625        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3626        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3627    };
3628
3629    session
3630        .app_state
3631        .rate_limiter
3632        .check(&*rate_limit, session.user_id())
3633        .await?;
3634
3635    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3636        Some(proto::LanguageModelProvider::Google) => {
3637            let api_key = config
3638                .google_ai_api_key
3639                .as_ref()
3640                .context("no Google AI API key configured on the server")?;
3641            google_ai::count_tokens(
3642                session.http_client.as_ref(),
3643                google_ai::API_URL,
3644                api_key,
3645                serde_json::from_str(&request.request)?,
3646            )
3647            .await?
3648        }
3649        _ => return Err(anyhow!("unsupported provider"))?,
3650    };
3651
3652    response.send(proto::CountLanguageModelTokensResponse {
3653        token_count: result.total_tokens as u32,
3654    })?;
3655
3656    Ok(())
3657}
3658
3659struct ZedProCountLanguageModelTokensRateLimit;
3660
3661impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3662    fn capacity(&self) -> usize {
3663        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3664            .ok()
3665            .and_then(|v| v.parse().ok())
3666            .unwrap_or(600) // Picked arbitrarily
3667    }
3668
3669    fn refill_duration(&self) -> chrono::Duration {
3670        chrono::Duration::hours(1)
3671    }
3672
3673    fn db_name(&self) -> &'static str {
3674        "zed-pro:count-language-model-tokens"
3675    }
3676}
3677
3678struct FreeCountLanguageModelTokensRateLimit;
3679
3680impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3681    fn capacity(&self) -> usize {
3682        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3683            .ok()
3684            .and_then(|v| v.parse().ok())
3685            .unwrap_or(600 / 10) // Picked arbitrarily
3686    }
3687
3688    fn refill_duration(&self) -> chrono::Duration {
3689        chrono::Duration::hours(1)
3690    }
3691
3692    fn db_name(&self) -> &'static str {
3693        "free:count-language-model-tokens"
3694    }
3695}
3696
3697struct ZedProComputeEmbeddingsRateLimit;
3698
3699impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3700    fn capacity(&self) -> usize {
3701        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3702            .ok()
3703            .and_then(|v| v.parse().ok())
3704            .unwrap_or(5000) // Picked arbitrarily
3705    }
3706
3707    fn refill_duration(&self) -> chrono::Duration {
3708        chrono::Duration::hours(1)
3709    }
3710
3711    fn db_name(&self) -> &'static str {
3712        "zed-pro:compute-embeddings"
3713    }
3714}
3715
3716struct FreeComputeEmbeddingsRateLimit;
3717
3718impl RateLimit for FreeComputeEmbeddingsRateLimit {
3719    fn capacity(&self) -> usize {
3720        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3721            .ok()
3722            .and_then(|v| v.parse().ok())
3723            .unwrap_or(5000 / 10) // Picked arbitrarily
3724    }
3725
3726    fn refill_duration(&self) -> chrono::Duration {
3727        chrono::Duration::hours(1)
3728    }
3729
3730    fn db_name(&self) -> &'static str {
3731        "free:compute-embeddings"
3732    }
3733}
3734
3735async fn compute_embeddings(
3736    request: proto::ComputeEmbeddings,
3737    response: Response<proto::ComputeEmbeddings>,
3738    session: Session,
3739    api_key: Option<Arc<str>>,
3740) -> Result<()> {
3741    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3742    authorize_access_to_legacy_llm_endpoints(&session).await?;
3743
3744    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3745        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3746        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3747    };
3748
3749    session
3750        .app_state
3751        .rate_limiter
3752        .check(&*rate_limit, session.user_id())
3753        .await?;
3754
3755    let embeddings = match request.model.as_str() {
3756        "openai/text-embedding-3-small" => {
3757            open_ai::embed(
3758                session.http_client.as_ref(),
3759                OPEN_AI_API_URL,
3760                &api_key,
3761                OpenAiEmbeddingModel::TextEmbedding3Small,
3762                request.texts.iter().map(|text| text.as_str()),
3763            )
3764            .await?
3765        }
3766        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3767    };
3768
3769    let embeddings = request
3770        .texts
3771        .iter()
3772        .map(|text| {
3773            let mut hasher = sha2::Sha256::new();
3774            hasher.update(text.as_bytes());
3775            let result = hasher.finalize();
3776            result.to_vec()
3777        })
3778        .zip(
3779            embeddings
3780                .data
3781                .into_iter()
3782                .map(|embedding| embedding.embedding),
3783        )
3784        .collect::<HashMap<_, _>>();
3785
3786    let db = session.db().await;
3787    db.save_embeddings(&request.model, &embeddings)
3788        .await
3789        .context("failed to save embeddings")
3790        .trace_err();
3791
3792    response.send(proto::ComputeEmbeddingsResponse {
3793        embeddings: embeddings
3794            .into_iter()
3795            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3796            .collect(),
3797    })?;
3798    Ok(())
3799}
3800
3801async fn get_cached_embeddings(
3802    request: proto::GetCachedEmbeddings,
3803    response: Response<proto::GetCachedEmbeddings>,
3804    session: Session,
3805) -> Result<()> {
3806    authorize_access_to_legacy_llm_endpoints(&session).await?;
3807
3808    let db = session.db().await;
3809    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3810
3811    response.send(proto::GetCachedEmbeddingsResponse {
3812        embeddings: embeddings
3813            .into_iter()
3814            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3815            .collect(),
3816    })?;
3817    Ok(())
3818}
3819
3820/// This is leftover from before the LLM service.
3821///
3822/// The endpoints protected by this check will be moved there eventually.
3823async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3824    if session.is_staff() {
3825        Ok(())
3826    } else {
3827        Err(anyhow!("permission denied"))?
3828    }
3829}
3830
3831/// Get a Supermaven API key for the user
3832async fn get_supermaven_api_key(
3833    _request: proto::GetSupermavenApiKey,
3834    response: Response<proto::GetSupermavenApiKey>,
3835    session: Session,
3836) -> Result<()> {
3837    let user_id: String = session.user_id().to_string();
3838    if !session.is_staff() {
3839        return Err(anyhow!("supermaven not enabled for this account"))?;
3840    }
3841
3842    let email = session
3843        .email()
3844        .ok_or_else(|| anyhow!("user must have an email"))?;
3845
3846    let supermaven_admin_api = session
3847        .supermaven_client
3848        .as_ref()
3849        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3850
3851    let result = supermaven_admin_api
3852        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3853        .await?;
3854
3855    response.send(proto::GetSupermavenApiKeyResponse {
3856        api_key: result.api_key,
3857    })?;
3858
3859    Ok(())
3860}
3861
3862/// Start receiving chat updates for a channel
3863async fn join_channel_chat(
3864    request: proto::JoinChannelChat,
3865    response: Response<proto::JoinChannelChat>,
3866    session: Session,
3867) -> Result<()> {
3868    let channel_id = ChannelId::from_proto(request.channel_id);
3869
3870    let db = session.db().await;
3871    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3872        .await?;
3873    let messages = db
3874        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3875        .await?;
3876    response.send(proto::JoinChannelChatResponse {
3877        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3878        messages,
3879    })?;
3880    Ok(())
3881}
3882
3883/// Stop receiving chat updates for a channel
3884async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3885    let channel_id = ChannelId::from_proto(request.channel_id);
3886    session
3887        .db()
3888        .await
3889        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3890        .await?;
3891    Ok(())
3892}
3893
3894/// Retrieve the chat history for a channel
3895async fn get_channel_messages(
3896    request: proto::GetChannelMessages,
3897    response: Response<proto::GetChannelMessages>,
3898    session: Session,
3899) -> Result<()> {
3900    let channel_id = ChannelId::from_proto(request.channel_id);
3901    let messages = session
3902        .db()
3903        .await
3904        .get_channel_messages(
3905            channel_id,
3906            session.user_id(),
3907            MESSAGE_COUNT_PER_PAGE,
3908            Some(MessageId::from_proto(request.before_message_id)),
3909        )
3910        .await?;
3911    response.send(proto::GetChannelMessagesResponse {
3912        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3913        messages,
3914    })?;
3915    Ok(())
3916}
3917
3918/// Retrieve specific chat messages
3919async fn get_channel_messages_by_id(
3920    request: proto::GetChannelMessagesById,
3921    response: Response<proto::GetChannelMessagesById>,
3922    session: Session,
3923) -> Result<()> {
3924    let message_ids = request
3925        .message_ids
3926        .iter()
3927        .map(|id| MessageId::from_proto(*id))
3928        .collect::<Vec<_>>();
3929    let messages = session
3930        .db()
3931        .await
3932        .get_channel_messages_by_id(session.user_id(), &message_ids)
3933        .await?;
3934    response.send(proto::GetChannelMessagesResponse {
3935        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3936        messages,
3937    })?;
3938    Ok(())
3939}
3940
3941/// Retrieve the current users notifications
3942async fn get_notifications(
3943    request: proto::GetNotifications,
3944    response: Response<proto::GetNotifications>,
3945    session: Session,
3946) -> Result<()> {
3947    let notifications = session
3948        .db()
3949        .await
3950        .get_notifications(
3951            session.user_id(),
3952            NOTIFICATION_COUNT_PER_PAGE,
3953            request.before_id.map(db::NotificationId::from_proto),
3954        )
3955        .await?;
3956    response.send(proto::GetNotificationsResponse {
3957        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3958        notifications,
3959    })?;
3960    Ok(())
3961}
3962
3963/// Mark notifications as read
3964async fn mark_notification_as_read(
3965    request: proto::MarkNotificationRead,
3966    response: Response<proto::MarkNotificationRead>,
3967    session: Session,
3968) -> Result<()> {
3969    let database = &session.db().await;
3970    let notifications = database
3971        .mark_notification_as_read_by_id(
3972            session.user_id(),
3973            NotificationId::from_proto(request.notification_id),
3974        )
3975        .await?;
3976    send_notifications(
3977        &*session.connection_pool().await,
3978        &session.peer,
3979        notifications,
3980    );
3981    response.send(proto::Ack {})?;
3982    Ok(())
3983}
3984
3985/// Get the current users information
3986async fn get_private_user_info(
3987    _request: proto::GetPrivateUserInfo,
3988    response: Response<proto::GetPrivateUserInfo>,
3989    session: Session,
3990) -> Result<()> {
3991    let db = session.db().await;
3992
3993    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3994    let user = db
3995        .get_user_by_id(session.user_id())
3996        .await?
3997        .ok_or_else(|| anyhow!("user not found"))?;
3998    let flags = db.get_user_flags(session.user_id()).await?;
3999
4000    response.send(proto::GetPrivateUserInfoResponse {
4001        metrics_id,
4002        staff: user.admin,
4003        flags,
4004        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4005    })?;
4006    Ok(())
4007}
4008
4009/// Accept the terms of service (tos) on behalf of the current user
4010async fn accept_terms_of_service(
4011    _request: proto::AcceptTermsOfService,
4012    response: Response<proto::AcceptTermsOfService>,
4013    session: Session,
4014) -> Result<()> {
4015    let db = session.db().await;
4016
4017    let accepted_tos_at = Utc::now();
4018    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4019        .await?;
4020
4021    response.send(proto::AcceptTermsOfServiceResponse {
4022        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4023    })?;
4024    Ok(())
4025}
4026
/// The minimum account age an account must have in order to use the LLM service.
///
/// Checked by `get_llm_api_token`. Age is measured from the earlier of the user's
/// `created_at` and `github_user_created_at`. The check is skipped for users with an LLM
/// subscription or the `bypass-account-age-check` feature flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4029
4030async fn get_llm_api_token(
4031    _request: proto::GetLlmToken,
4032    response: Response<proto::GetLlmToken>,
4033    session: Session,
4034) -> Result<()> {
4035    let db = session.db().await;
4036
4037    let flags = db.get_user_flags(session.user_id()).await?;
4038    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4039    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4040    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4041
4042    if !session.is_staff() && !has_language_models_feature_flag {
4043        Err(anyhow!("permission denied"))?
4044    }
4045
4046    let user_id = session.user_id();
4047    let user = db
4048        .get_user_by_id(user_id)
4049        .await?
4050        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4051
4052    if user.accepted_tos_at.is_none() {
4053        Err(anyhow!("terms of service not accepted"))?
4054    }
4055
4056    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4057
4058    let bypass_account_age_check =
4059        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4060    if !bypass_account_age_check {
4061        let mut account_created_at = user.created_at;
4062        if let Some(github_created_at) = user.github_user_created_at {
4063            account_created_at = account_created_at.min(github_created_at);
4064        }
4065        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4066            Err(anyhow!("account too young"))?
4067        }
4068    }
4069
4070    let billing_preferences = db.get_billing_preferences(user.id).await?;
4071
4072    let token = LlmTokenClaims::create(
4073        &user,
4074        session.is_staff(),
4075        billing_preferences,
4076        has_llm_closed_beta_feature_flag,
4077        has_predict_edits_feature_flag,
4078        has_llm_subscription,
4079        session.current_plan(&db).await?,
4080        session.system_id.clone(),
4081        &session.app_state.config,
4082    )?;
4083    response.send(proto::GetLlmTokenResponse { token })?;
4084    Ok(())
4085}
4086
4087fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4088    let message = match message {
4089        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4090        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4091        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4092        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4093        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4094            code: frame.code.into(),
4095            reason: frame.reason,
4096        })),
4097        // We should never receive a frame while reading the message, according
4098        // to the `tungstenite` maintainers:
4099        //
4100        // > It cannot occur when you read messages from the WebSocket, but it
4101        // > can be used when you want to send the raw frames (e.g. you want to
4102        // > send the frames to the WebSocket without composing the full message first).
4103        // >
4104        // > — https://github.com/snapview/tungstenite-rs/issues/268
4105        TungsteniteMessage::Frame(_) => {
4106            bail!("received an unexpected frame while reading the message")
4107        }
4108    };
4109
4110    Ok(message)
4111}
4112
4113fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4114    match message {
4115        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4116        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4117        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4118        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4119        AxumMessage::Close(frame) => {
4120            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4121                code: frame.code.into(),
4122                reason: frame.reason,
4123            }))
4124        }
4125    }
4126}
4127
4128fn notify_membership_updated(
4129    connection_pool: &mut ConnectionPool,
4130    result: MembershipUpdated,
4131    user_id: UserId,
4132    peer: &Peer,
4133) {
4134    for membership in &result.new_channels.channel_memberships {
4135        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4136    }
4137    for channel_id in &result.removed_channels {
4138        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4139    }
4140
4141    let user_channels_update = proto::UpdateUserChannels {
4142        channel_memberships: result
4143            .new_channels
4144            .channel_memberships
4145            .iter()
4146            .map(|cm| proto::ChannelMembership {
4147                channel_id: cm.channel_id.to_proto(),
4148                role: cm.role.into(),
4149            })
4150            .collect(),
4151        ..Default::default()
4152    };
4153
4154    let mut update = build_channels_update(result.new_channels);
4155    update.delete_channels = result
4156        .removed_channels
4157        .into_iter()
4158        .map(|id| id.to_proto())
4159        .collect();
4160    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4161
4162    for connection_id in connection_pool.user_connection_ids(user_id) {
4163        peer.send(connection_id, user_channels_update.clone())
4164            .trace_err();
4165        peer.send(connection_id, update.clone()).trace_err();
4166    }
4167}
4168
4169fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4170    proto::UpdateUserChannels {
4171        channel_memberships: channels
4172            .channel_memberships
4173            .iter()
4174            .map(|m| proto::ChannelMembership {
4175                channel_id: m.channel_id.to_proto(),
4176                role: m.role.into(),
4177            })
4178            .collect(),
4179        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4180        observed_channel_message_id: channels.observed_channel_messages.clone(),
4181    }
4182}
4183
4184fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4185    let mut update = proto::UpdateChannels::default();
4186
4187    for channel in channels.channels {
4188        update.channels.push(channel.to_proto());
4189    }
4190
4191    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4192    update.latest_channel_message_ids = channels.latest_channel_messages;
4193
4194    for (channel_id, participants) in channels.channel_participants {
4195        update
4196            .channel_participants
4197            .push(proto::ChannelParticipants {
4198                channel_id: channel_id.to_proto(),
4199                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4200            });
4201    }
4202
4203    for channel in channels.invited_channels {
4204        update.channel_invitations.push(channel.to_proto());
4205    }
4206
4207    update
4208}
4209
4210fn build_initial_contacts_update(
4211    contacts: Vec<db::Contact>,
4212    pool: &ConnectionPool,
4213) -> proto::UpdateContacts {
4214    let mut update = proto::UpdateContacts::default();
4215
4216    for contact in contacts {
4217        match contact {
4218            db::Contact::Accepted { user_id, busy } => {
4219                update.contacts.push(contact_for_user(user_id, busy, pool));
4220            }
4221            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4222            db::Contact::Incoming { user_id } => {
4223                update
4224                    .incoming_requests
4225                    .push(proto::IncomingContactRequest {
4226                        requester_id: user_id.to_proto(),
4227                    })
4228            }
4229        }
4230    }
4231
4232    update
4233}
4234
4235fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4236    proto::Contact {
4237        user_id: user_id.to_proto(),
4238        online: pool.is_user_online(user_id),
4239        busy,
4240    }
4241}
4242
4243fn room_updated(room: &proto::Room, peer: &Peer) {
4244    broadcast(
4245        None,
4246        room.participants
4247            .iter()
4248            .filter_map(|participant| Some(participant.peer_id?.into())),
4249        |peer_id| {
4250            peer.send(
4251                peer_id,
4252                proto::RoomUpdated {
4253                    room: Some(room.clone()),
4254                },
4255            )
4256        },
4257    );
4258}
4259
4260fn channel_updated(
4261    channel: &db::channel::Model,
4262    room: &proto::Room,
4263    peer: &Peer,
4264    pool: &ConnectionPool,
4265) {
4266    let participants = room
4267        .participants
4268        .iter()
4269        .map(|p| p.user_id)
4270        .collect::<Vec<_>>();
4271
4272    broadcast(
4273        None,
4274        pool.channel_connection_ids(channel.root_id())
4275            .filter_map(|(channel_id, role)| {
4276                role.can_see_channel(channel.visibility)
4277                    .then_some(channel_id)
4278            }),
4279        |peer_id| {
4280            peer.send(
4281                peer_id,
4282                proto::UpdateChannels {
4283                    channel_participants: vec![proto::ChannelParticipants {
4284                        channel_id: channel.id.to_proto(),
4285                        participant_user_ids: participants.clone(),
4286                    }],
4287                    ..Default::default()
4288                },
4289            )
4290        },
4291    );
4292}
4293
4294async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4295    let db = session.db().await;
4296
4297    let contacts = db.get_contacts(user_id).await?;
4298    let busy = db.is_user_busy(user_id).await?;
4299
4300    let pool = session.connection_pool().await;
4301    let updated_contact = contact_for_user(user_id, busy, &pool);
4302    for contact in contacts {
4303        if let db::Contact::Accepted {
4304            user_id: contact_user_id,
4305            ..
4306        } = contact
4307        {
4308            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4309                session
4310                    .peer
4311                    .send(
4312                        contact_conn_id,
4313                        proto::UpdateContacts {
4314                            contacts: vec![updated_contact.clone()],
4315                            remove_contacts: Default::default(),
4316                            incoming_requests: Default::default(),
4317                            remove_incoming_requests: Default::default(),
4318                            outgoing_requests: Default::default(),
4319                            remove_outgoing_requests: Default::default(),
4320                        },
4321                    )
4322                    .trace_err();
4323            }
4324        }
4325    }
4326    Ok(())
4327}
4328
/// Removes `connection_id` from the room it is in (if any) and notifies everyone affected.
///
/// Steps, in order:
/// 1. Ask the database to leave the room; if it reports no room, return `Ok(())`.
/// 2. Notify collaborators of each project left as a result (`project_left`).
/// 3. Broadcast the updated room to its remaining participants, and the updated
///    participant list to the room's channel, if there is one.
/// 4. Send `CallCanceled` to every connection of each user whose pending call was canceled.
/// 5. Refresh contact status for this session's user and those canceled users.
/// 6. If a LiveKit client is configured, remove this user from the LiveKit room
///    and delete the LiveKit room when the database reported the room as deleted.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized locals: each is assigned exactly once inside the
    // `Some` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The value returned by `session.db().await` is a temporary in the `if let`
    // scrutinee, so it stays alive until the end of this `if let` statement.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Take the LiveKit room name before `left_room.room` itself is moved out below.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the `pool` binding is dropped before `update_user_contacts`
    // below, which calls `session.connection_pool().await` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): the `?` here returns on the first failure, which skips the
    // remaining contact updates and the LiveKit cleanup below — confirm intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4402
4403async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4404    let left_channel_buffers = session
4405        .db()
4406        .await
4407        .leave_channel_buffers(session.connection_id)
4408        .await?;
4409
4410    for left_buffer in left_channel_buffers {
4411        channel_buffer_updated(
4412            session.connection_id,
4413            left_buffer.connections,
4414            &proto::UpdateChannelBufferCollaborators {
4415                channel_id: left_buffer.channel_id.to_proto(),
4416                collaborators: left_buffer.collaborators,
4417            },
4418            &session.peer,
4419        );
4420    }
4421
4422    Ok(())
4423}
4424
4425fn project_left(project: &db::LeftProject, session: &Session) {
4426    for connection_id in &project.connection_ids {
4427        if project.should_unshare {
4428            session
4429                .peer
4430                .send(
4431                    *connection_id,
4432                    proto::UnshareProject {
4433                        project_id: project.id.to_proto(),
4434                    },
4435                )
4436                .trace_err();
4437        } else {
4438            session
4439                .peer
4440                .send(
4441                    *connection_id,
4442                    proto::RemoveProjectCollaborator {
4443                        project_id: project.id.to_proto(),
4444                        peer_id: Some(session.connection_id.into()),
4445                    },
4446                )
4447                .trace_err();
4448        }
4449    }
4450}
4451
4452pub trait ResultExt {
4453    type Ok;
4454
4455    fn trace_err(self) -> Option<Self::Ok>;
4456}
4457
4458impl<T, E> ResultExt for Result<T, E>
4459where
4460    E: std::fmt::Debug,
4461{
4462    type Ok = T;
4463
4464    #[track_caller]
4465    fn trace_err(self) -> Option<T> {
4466        match self {
4467            Ok(value) => Some(value),
4468            Err(error) => {
4469                tracing::error!("{:?}", error);
4470                None
4471            }
4472        }
4473    }
4474}