rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state handed to message handlers.
///
/// Derives `Clone`; the shared parts are held behind `Arc`s, so clones share the same
/// database handle, peer, connection pool, and app state.
#[derive(Clone)]
struct Session {
    /// Who this connection is authenticated as (possibly an admin impersonating a user).
    principal: Principal,
    /// Id of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async (tokio) mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous `parking_lot` mutex; acquire it via
    /// `Session::connection_pool`, which wraps the lock in a non-`Send` guard.
    /// NOTE(review): presumably the same pool as `Server::connection_pool` — confirm.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client, if any.
    /// NOTE(review): presumably `None` when Supermaven is not configured — confirm.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Not read in the code visible here; `allow(unused)` silences the dead-code lint.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// NOTE(review): judging by the name, a client-reported system identifier — confirm.
    system_id: Option<String>,
    // Underscore-prefixed: stored on the session but not read through this field.
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// RPC server state: the peer, the connection pool, and the table of message
/// handlers registered in `Server::new`.
pub struct Server {
    /// This server's id; behind a mutex so that test builds can replace it in `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// `true` once `teardown` has run. `handle_connection` checks it and returns early
    /// for new connections while it is set.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard over the shared [`ConnectionPool`].
///
/// `_not_send` is a `PhantomData<Rc<()>>`, which makes the guard `!Send`. A `Send`
/// future therefore cannot hold it across an `.await`. This is presumably to keep the
/// synchronous `parking_lot` lock from staying held while a task is suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server's peer and connection pool.
///
/// The pool is held through its lock guard and serialized through `serialize_deref`,
/// so the pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 333            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 341            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 342            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 344            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 345            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 346            .add_message_handler(create_buffer_for_peer)
 347            .add_request_handler(update_buffer)
 348            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 353            .add_request_handler(get_users)
 354            .add_request_handler(fuzzy_search_users)
 355            .add_request_handler(request_contact)
 356            .add_request_handler(remove_contact)
 357            .add_request_handler(respond_to_contact_request)
 358            .add_message_handler(subscribe_to_channels)
 359            .add_request_handler(create_channel)
 360            .add_request_handler(delete_channel)
 361            .add_request_handler(invite_channel_member)
 362            .add_request_handler(remove_channel_member)
 363            .add_request_handler(set_channel_member_role)
 364            .add_request_handler(set_channel_visibility)
 365            .add_request_handler(rename_channel)
 366            .add_request_handler(join_channel_buffer)
 367            .add_request_handler(leave_channel_buffer)
 368            .add_message_handler(update_channel_buffer)
 369            .add_request_handler(rejoin_channel_buffers)
 370            .add_request_handler(get_channel_members)
 371            .add_request_handler(respond_to_channel_invite)
 372            .add_request_handler(join_channel)
 373            .add_request_handler(join_channel_chat)
 374            .add_message_handler(leave_channel_chat)
 375            .add_request_handler(send_channel_message)
 376            .add_request_handler(remove_channel_message)
 377            .add_request_handler(update_channel_message)
 378            .add_request_handler(get_channel_messages)
 379            .add_request_handler(get_channel_messages_by_id)
 380            .add_request_handler(get_notifications)
 381            .add_request_handler(mark_notification_as_read)
 382            .add_request_handler(move_channel)
 383            .add_request_handler(follow)
 384            .add_message_handler(unfollow)
 385            .add_message_handler(update_followers)
 386            .add_request_handler(get_private_user_info)
 387            .add_request_handler(get_llm_api_token)
 388            .add_request_handler(accept_terms_of_service)
 389            .add_message_handler(acknowledge_channel_message)
 390            .add_message_handler(acknowledge_buffer_version)
 391            .add_request_handler(get_supermaven_api_key)
 392            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 395            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 398            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 399            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 400            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 401            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 403            .add_message_handler(update_context)
 404            .add_request_handler({
 405                let app_state = app_state.clone();
 406                move |request, response, session| {
 407                    let app_state = app_state.clone();
 408                    async move {
 409                        count_language_model_tokens(request, response, session, &app_state.config)
 410                            .await
 411                    }
 412                }
 413            })
 414            .add_request_handler(get_cached_embeddings)
 415            .add_request_handler({
 416                let app_state = app_state.clone();
 417                move |request, response, session| {
 418                    compute_embeddings(
 419                        request,
 420                        response,
 421                        session,
 422                        app_state.config.openai_api_key.clone(),
 423                    )
 424                }
 425            });
 426
 427        Arc::new(server)
 428    }
 429
    /// Spawns a detached background task that cleans up resources the database reports
    /// as stale relative to this server, then returns `Ok(())` immediately.
    ///
    /// NOTE(review): "stale" presumably means owned by servers in the same environment
    /// that are no longer running — confirm against `stale_server_resource_ids`.
    ///
    /// The task does the following, in order:
    ///
    /// 1. Sleeps for `CLEANUP_TIMEOUT`.
    /// 2. For each stale channel buffer, clears stale collaborators and sends the
    ///    refreshed collaborator list to the remaining connections.
    /// 3. For each stale room:
    ///    - clears stale participants and broadcasts the room (and channel) update;
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes refreshed contact status for every affected user to their accepted
    ///      contacts;
    ///    - deletes the LiveKit room if no participants remain.
    /// 4. Deletes stale server records.
    ///
    /// Failures of individual database and network calls are handled with
    /// `.trace_err()` (turned into `Option`, presumably logged) and do not abort the
    /// task.
    pub async fn start(&self) -> Result<()> {
        // Clone everything the detached task needs; it cannot borrow from `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nobody is
                        // notified and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled may have changed busy status.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the connection-pool lock is released before the
                        // awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry and push it to
                        // every connection of each accepted contact. The pool is
                        // locked only after the database awaits complete.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a LiveKit client is
                        // configured and the refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 575
 576    pub fn teardown(&self) {
 577        self.peer.teardown();
 578        self.connection_pool.lock().reset();
 579        let _ = self.teardown.send(true);
 580    }
 581
 582    #[cfg(test)]
 583    pub fn reset(&self, id: ServerId) {
 584        self.teardown();
 585        *self.id.lock() = id;
 586        self.peer.reset(id.0 as u32);
 587        let _ = self.teardown.send(false);
 588    }
 589
 590    #[cfg(test)]
 591    pub fn id(&self) -> ServerId {
 592        *self.id.lock()
 593    }
 594
 595    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 596    where
 597        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 598        Fut: 'static + Send + Future<Output = Result<()>>,
 599        M: EnvelopedMessage,
 600    {
 601        let prev_handler = self.handlers.insert(
 602            TypeId::of::<M>(),
 603            Box::new(move |envelope, session| {
 604                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 605                let received_at = envelope.received_at;
 606                tracing::info!("message received");
 607                let start_time = Instant::now();
 608                let future = (handler)(*envelope, session);
 609                async move {
 610                    let result = future.await;
 611                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 612                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 613                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 614                    let payload_type = M::NAME;
 615
 616                    match result {
 617                        Err(error) => {
 618                            tracing::error!(
 619                                ?error,
 620                                total_duration_ms,
 621                                processing_duration_ms,
 622                                queue_duration_ms,
 623                                payload_type,
 624                                "error handling message"
 625                            )
 626                        }
 627                        Ok(()) => tracing::info!(
 628                            total_duration_ms,
 629                            processing_duration_ms,
 630                            queue_duration_ms,
 631                            "finished handling message"
 632                        ),
 633                    }
 634                }
 635                .boxed()
 636            }),
 637        );
 638        if prev_handler.is_some() {
 639            panic!("registered a handler for the same message twice");
 640        }
 641        self
 642    }
 643
 644    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 645    where
 646        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 647        Fut: 'static + Send + Future<Output = Result<()>>,
 648        M: EnvelopedMessage,
 649    {
 650        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 651        self
 652    }
 653
 654    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 655    where
 656        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 657        Fut: Send + Future<Output = Result<()>>,
 658        M: RequestMessage,
 659    {
 660        let handler = Arc::new(handler);
 661        self.add_handler(move |envelope, session| {
 662            let receipt = envelope.receipt();
 663            let handler = handler.clone();
 664            async move {
 665                let peer = session.peer.clone();
 666                let responded = Arc::new(AtomicBool::default());
 667                let response = Response {
 668                    peer: peer.clone(),
 669                    responded: responded.clone(),
 670                    receipt,
 671                };
 672                match (handler)(envelope.payload, response, session).await {
 673                    Ok(()) => {
 674                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 675                            Ok(())
 676                        } else {
 677                            Err(anyhow!("handler did not send a response"))?
 678                        }
 679                    }
 680                    Err(error) => {
 681                        let proto_err = match &error {
 682                            Error::Internal(err) => err.to_proto(),
 683                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 684                        };
 685                        peer.respond_with_error(receipt, proto_err)?;
 686                        Err(error)
 687                    }
 688                }
 689            }
 690        })
 691    }
 692
 693    #[allow(clippy::too_many_arguments)]
 694    pub fn handle_connection(
 695        self: &Arc<Self>,
 696        connection: Connection,
 697        address: String,
 698        principal: Principal,
 699        zed_version: ZedVersion,
 700        geoip_country_code: Option<String>,
 701        system_id: Option<String>,
 702        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 703        executor: Executor,
 704    ) -> impl Future<Output = ()> {
 705        let this = self.clone();
 706        let span = info_span!("handle connection", %address,
 707            connection_id=field::Empty,
 708            user_id=field::Empty,
 709            login=field::Empty,
 710            impersonator=field::Empty,
 711            geoip_country_code=field::Empty
 712        );
 713        principal.update_span(&span);
 714        if let Some(country_code) = geoip_country_code.as_ref() {
 715            span.record("geoip_country_code", country_code);
 716        }
 717
 718        let mut teardown = self.teardown.subscribe();
 719        async move {
 720            if *teardown.borrow() {
 721                tracing::error!("server is tearing down");
 722                return
 723            }
 724            let (connection_id, handle_io, mut incoming_rx) = this
 725                .peer
 726                .add_connection(connection, {
 727                    let executor = executor.clone();
 728                    move |duration| executor.sleep(duration)
 729                });
 730            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 731
 732            tracing::info!("connection opened");
 733
 734            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 735            let http_client = match ReqwestClient::user_agent(&user_agent) {
 736                Ok(http_client) => Arc::new(http_client),
 737                Err(error) => {
 738                    tracing::error!(?error, "failed to create HTTP client");
 739                    return;
 740                }
 741            };
 742
 743            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 744                    supermaven_admin_api_key.to_string(),
 745                    http_client.clone(),
 746                )));
 747
 748            let session = Session {
 749                principal: principal.clone(),
 750                connection_id,
 751                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 752                peer: this.peer.clone(),
 753                connection_pool: this.connection_pool.clone(),
 754                app_state: this.app_state.clone(),
 755                http_client,
 756                geoip_country_code,
 757                system_id,
 758                _executor: executor.clone(),
 759                supermaven_client,
 760            };
 761
 762            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 763                tracing::error!(?error, "failed to send initial client update");
 764                return;
 765            }
 766
 767            let handle_io = handle_io.fuse();
 768            futures::pin_mut!(handle_io);
 769
 770            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 771            // This prevents deadlocks when e.g., client A performs a request to client B and
 772            // client B performs a request to client A. If both clients stop processing further
 773            // messages until their respective request completes, they won't have a chance to
 774            // respond to the other client's request and cause a deadlock.
 775            //
 776            // This arrangement ensures we will attempt to process earlier messages first, but fall
 777            // back to processing messages arrived later in the spirit of making progress.
 778            let mut foreground_message_handlers = FuturesUnordered::new();
 779            let concurrent_handlers = Arc::new(Semaphore::new(256));
 780            loop {
 781                let next_message = async {
 782                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 783                    let message = incoming_rx.next().await;
 784                    (permit, message)
 785                }.fuse();
 786                futures::pin_mut!(next_message);
 787                futures::select_biased! {
 788                    _ = teardown.changed().fuse() => return,
 789                    result = handle_io => {
 790                        if let Err(error) = result {
 791                            tracing::error!(?error, "error handling I/O");
 792                        }
 793                        break;
 794                    }
 795                    _ = foreground_message_handlers.next() => {}
 796                    next_message = next_message => {
 797                        let (permit, message) = next_message;
 798                        if let Some(message) = message {
 799                            let type_name = message.payload_type_name();
 800                            // note: we copy all the fields from the parent span so we can query them in the logs.
 801                            // (https://github.com/tokio-rs/tracing/issues/2670).
 802                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 803                                user_id=field::Empty,
 804                                login=field::Empty,
 805                                impersonator=field::Empty,
 806                            );
 807                            principal.update_span(&span);
 808                            let span_enter = span.enter();
 809                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 810                                let is_background = message.is_background();
 811                                let handle_message = (handler)(message, session.clone());
 812                                drop(span_enter);
 813
 814                                let handle_message = async move {
 815                                    handle_message.await;
 816                                    drop(permit);
 817                                }.instrument(span);
 818                                if is_background {
 819                                    executor.spawn_detached(handle_message);
 820                                } else {
 821                                    foreground_message_handlers.push(handle_message);
 822                                }
 823                            } else {
 824                                tracing::error!("no message handler");
 825                            }
 826                        } else {
 827                            tracing::info!("connection closed");
 828                            break;
 829                        }
 830                    }
 831                }
 832            }
 833
 834            drop(foreground_message_handlers);
 835            tracing::info!("signing out");
 836            if let Err(error) = connection_lost(session, teardown, executor).await {
 837                tracing::error!(?error, "error signing out");
 838            }
 839
 840        }.instrument(span)
 841    }
 842
 843    async fn send_initial_client_update(
 844        &self,
 845        connection_id: ConnectionId,
 846        principal: &Principal,
 847        zed_version: ZedVersion,
 848        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 849        session: &Session,
 850    ) -> Result<()> {
 851        self.peer.send(
 852            connection_id,
 853            proto::Hello {
 854                peer_id: Some(connection_id.into()),
 855            },
 856        )?;
 857        tracing::info!("sent hello message");
 858        if let Some(send_connection_id) = send_connection_id.take() {
 859            let _ = send_connection_id.send(connection_id);
 860        }
 861
 862        match principal {
 863            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 864                if !user.connected_once {
 865                    self.peer.send(connection_id, proto::ShowContacts {})?;
 866                    self.app_state
 867                        .db
 868                        .set_user_connected_once(user.id, true)
 869                        .await?;
 870                }
 871
 872                update_user_plan(user.id, session).await?;
 873
 874                let contacts = self.app_state.db.get_contacts(user.id).await?;
 875
 876                {
 877                    let mut pool = self.connection_pool.lock();
 878                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 879                    self.peer.send(
 880                        connection_id,
 881                        build_initial_contacts_update(contacts, &pool),
 882                    )?;
 883                }
 884
 885                if should_auto_subscribe_to_channels(zed_version) {
 886                    subscribe_user_to_channels(user.id, session).await?;
 887                }
 888
 889                if let Some(incoming_call) =
 890                    self.app_state.db.incoming_call_for_user(user.id).await?
 891                {
 892                    self.peer.send(connection_id, incoming_call)?;
 893                }
 894
 895                update_user_contacts(user.id, session).await?;
 896            }
 897        }
 898
 899        Ok(())
 900    }
 901
 902    pub async fn invite_code_redeemed(
 903        self: &Arc<Self>,
 904        inviter_id: UserId,
 905        invitee_id: UserId,
 906    ) -> Result<()> {
 907        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 908            if let Some(code) = &user.invite_code {
 909                let pool = self.connection_pool.lock();
 910                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 911                for connection_id in pool.user_connection_ids(inviter_id) {
 912                    self.peer.send(
 913                        connection_id,
 914                        proto::UpdateContacts {
 915                            contacts: vec![invitee_contact.clone()],
 916                            ..Default::default()
 917                        },
 918                    )?;
 919                    self.peer.send(
 920                        connection_id,
 921                        proto::UpdateInviteInfo {
 922                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 923                            count: user.invite_count as u32,
 924                        },
 925                    )?;
 926                }
 927            }
 928        }
 929        Ok(())
 930    }
 931
 932    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 933        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 934            if let Some(invite_code) = &user.invite_code {
 935                let pool = self.connection_pool.lock();
 936                for connection_id in pool.user_connection_ids(user_id) {
 937                    self.peer.send(
 938                        connection_id,
 939                        proto::UpdateInviteInfo {
 940                            url: format!(
 941                                "{}{}",
 942                                self.app_state.config.invite_link_prefix, invite_code
 943                            ),
 944                            count: user.invite_count as u32,
 945                        },
 946                    )?;
 947                }
 948            }
 949        }
 950        Ok(())
 951    }
 952
 953    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 954        let pool = self.connection_pool.lock();
 955        for connection_id in pool.user_connection_ids(user_id) {
 956            self.peer
 957                .send(connection_id, proto::RefreshLlmToken {})
 958                .trace_err();
 959        }
 960    }
 961
 962    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 963        ServerSnapshot {
 964            connection_pool: ConnectionPoolGuard {
 965                guard: self.connection_pool.lock(),
 966                _not_send: PhantomData,
 967            },
 968            peer: &self.peer,
 969        }
 970    }
 971}
 972
 973impl<'a> Deref for ConnectionPoolGuard<'a> {
 974    type Target = ConnectionPool;
 975
 976    fn deref(&self) -> &Self::Target {
 977        &self.guard
 978    }
 979}
 980
 981impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 982    fn deref_mut(&mut self) -> &mut Self::Target {
 983        &mut self.guard
 984    }
 985}
 986
/// In test builds only, runs `check_invariants` on the pool every time a guard is dropped,
/// so an inconsistency introduced while the guard was held is checked at release time.
/// In non-test builds this drop does nothing beyond releasing the inner guard.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `check_invariants` is presumably only available under `cfg(test)`; hence the
        // attribute on the statement rather than a runtime `cfg!` check — verify.
        #[cfg(test)]
        self.check_invariants();
    }
}
 993
 994fn broadcast<F>(
 995    sender_id: Option<ConnectionId>,
 996    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 997    mut f: F,
 998) where
 999    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1000{
1001    for receiver_id in receiver_ids {
1002        if Some(receiver_id) != sender_id {
1003            if let Err(error) = f(receiver_id) {
1004                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1005            }
1006        }
1007    }
1008}
1009
1010pub struct ProtocolVersion(u32);
1011
1012impl Header for ProtocolVersion {
1013    fn name() -> &'static HeaderName {
1014        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1015        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1016    }
1017
1018    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1019    where
1020        Self: Sized,
1021        I: Iterator<Item = &'i axum::http::HeaderValue>,
1022    {
1023        let version = values
1024            .next()
1025            .ok_or_else(axum::headers::Error::invalid)?
1026            .to_str()
1027            .map_err(|_| axum::headers::Error::invalid())?
1028            .parse()
1029            .map_err(|_| axum::headers::Error::invalid())?;
1030        Ok(Self(version))
1031    }
1032
1033    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1034        values.extend([self.0.to_string().parse().unwrap()]);
1035    }
1036}
1037
1038pub struct AppVersionHeader(SemanticVersion);
1039impl Header for AppVersionHeader {
1040    fn name() -> &'static HeaderName {
1041        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1042        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1043    }
1044
1045    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1046    where
1047        Self: Sized,
1048        I: Iterator<Item = &'i axum::http::HeaderValue>,
1049    {
1050        let version = values
1051            .next()
1052            .ok_or_else(axum::headers::Error::invalid)?
1053            .to_str()
1054            .map_err(|_| axum::headers::Error::invalid())?
1055            .parse()
1056            .map_err(|_| axum::headers::Error::invalid())?;
1057        Ok(Self(version))
1058    }
1059
1060    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1061        values.extend([self.0.to_string().parse().unwrap()]);
1062    }
1063}
1064
1065pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1066    Router::new()
1067        .route("/rpc", get(handle_websocket_request))
1068        .layer(
1069            ServiceBuilder::new()
1070                .layer(Extension(server.app_state.clone()))
1071                .layer(middleware::from_fn(auth::validate_header)),
1072        )
1073        .route("/metrics", get(handle_metrics))
1074        .layer(Extension(server))
1075}
1076
1077#[allow(clippy::too_many_arguments)]
1078pub async fn handle_websocket_request(
1079    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1080    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1081    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1082    Extension(server): Extension<Arc<Server>>,
1083    Extension(principal): Extension<Principal>,
1084    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1085    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1086    ws: WebSocketUpgrade,
1087) -> axum::response::Response {
1088    if protocol_version != rpc::PROTOCOL_VERSION {
1089        return (
1090            StatusCode::UPGRADE_REQUIRED,
1091            "client must be upgraded".to_string(),
1092        )
1093            .into_response();
1094    }
1095
1096    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1097        return (
1098            StatusCode::UPGRADE_REQUIRED,
1099            "no version header found".to_string(),
1100        )
1101            .into_response();
1102    };
1103
1104    if !version.can_collaborate() {
1105        return (
1106            StatusCode::UPGRADE_REQUIRED,
1107            "client must be upgraded".to_string(),
1108        )
1109            .into_response();
1110    }
1111
1112    let socket_address = socket_address.to_string();
1113    ws.on_upgrade(move |socket| {
1114        let socket = socket
1115            .map_ok(to_tungstenite_message)
1116            .err_into()
1117            .with(|message| async move { to_axum_message(message) });
1118        let connection = Connection::new(Box::pin(socket));
1119        async move {
1120            server
1121                .handle_connection(
1122                    connection,
1123                    socket_address,
1124                    principal,
1125                    version,
1126                    country_code_header.map(|header| header.to_string()),
1127                    system_id_header.map(|header| header.to_string()),
1128                    None,
1129                    Executor::Production,
1130                )
1131                .await;
1132        }
1133    })
1134}
1135
1136pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1137    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1138    let connections_metric = CONNECTIONS_METRIC
1139        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1140
1141    let connections = server
1142        .connection_pool
1143        .lock()
1144        .connections()
1145        .filter(|connection| !connection.admin)
1146        .count();
1147    connections_metric.set(connections as _);
1148
1149    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1150    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1151        register_int_gauge!(
1152            "shared_projects",
1153            "number of open projects with one or more guests"
1154        )
1155        .unwrap()
1156    });
1157
1158    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1159    shared_projects_metric.set(shared_projects as _);
1160
1161    let encoder = prometheus::TextEncoder::new();
1162    let metric_families = prometheus::gather();
1163    let encoded_metrics = encoder
1164        .encode_to_string(&metric_families)
1165        .map_err(|err| anyhow!("{}", err))?;
1166    Ok(encoded_metrics)
1167}
1168
/// Signs out a connection whose message loop has ended.
///
/// Runs in two phases:
/// 1. Immediately: disconnects the peer, removes the connection from the pool (an error
///    there is returned), and calls the database's `connection_lost` (errors only logged).
/// 2. After waiting `RECONNECT_TIMEOUT` (presumably a grace period for the client to
///    reconnect — TODO confirm), it calls `leave_room_for_session` and
///    `leave_channel_buffers_for_session` (errors logged). If the pool no longer reports the
///    user online, it then calls `decline_call` and broadcasts any resulting room update.
///    Finally it refreshes the user's contacts.
///
/// If `teardown` changes before the timeout elapses, phase 2 is skipped entirely.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: the timeout arm is polled before the teardown arm on each wake-up.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1215
1216/// Acknowledges a ping from a client, used to keep the connection alive.
1217async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1218    response.send(proto::Ack {})?;
1219    Ok(())
1220}
1221
1222/// Creates a new room for calling (outside of channels)
1223async fn create_room(
1224    _request: proto::CreateRoom,
1225    response: Response<proto::CreateRoom>,
1226    session: Session,
1227) -> Result<()> {
1228    let livekit_room = nanoid::nanoid!(30);
1229
1230    let live_kit_connection_info = util::maybe!(async {
1231        let live_kit = session.app_state.livekit_client.as_ref();
1232        let live_kit = live_kit?;
1233        let user_id = session.user_id().to_string();
1234
1235        let token = live_kit
1236            .room_token(&livekit_room, &user_id.to_string())
1237            .trace_err()?;
1238
1239        Some(proto::LiveKitConnectionInfo {
1240            server_url: live_kit.url().into(),
1241            token,
1242            can_publish: true,
1243        })
1244    })
1245    .await;
1246
1247    let room = session
1248        .db()
1249        .await
1250        .create_room(session.user_id(), session.connection_id, &livekit_room)
1251        .await?;
1252
1253    response.send(proto::CreateRoomResponse {
1254        room: Some(room.clone()),
1255        live_kit_connection_info,
1256    })?;
1257
1258    update_user_contacts(session.user_id(), &session).await?;
1259    Ok(())
1260}
1261
1262/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1263async fn join_room(
1264    request: proto::JoinRoom,
1265    response: Response<proto::JoinRoom>,
1266    session: Session,
1267) -> Result<()> {
1268    let room_id = RoomId::from_proto(request.id);
1269
1270    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1271
1272    if let Some(channel_id) = channel_id {
1273        return join_channel_internal(channel_id, Box::new(response), session).await;
1274    }
1275
1276    let joined_room = {
1277        let room = session
1278            .db()
1279            .await
1280            .join_room(room_id, session.user_id(), session.connection_id)
1281            .await?;
1282        room_updated(&room.room, &session.peer);
1283        room.into_inner()
1284    };
1285
1286    for connection_id in session
1287        .connection_pool()
1288        .await
1289        .user_connection_ids(session.user_id())
1290    {
1291        session
1292            .peer
1293            .send(
1294                connection_id,
1295                proto::CallCanceled {
1296                    room_id: room_id.to_proto(),
1297                },
1298            )
1299            .trace_err();
1300    }
1301
1302    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1303    {
1304        live_kit
1305            .room_token(
1306                &joined_room.room.livekit_room,
1307                &session.user_id().to_string(),
1308            )
1309            .trace_err()
1310            .map(|token| proto::LiveKitConnectionInfo {
1311                server_url: live_kit.url().into(),
1312                token,
1313                can_publish: true,
1314            })
1315    } else {
1316        None
1317    };
1318
1319    response.send(proto::JoinRoomResponse {
1320        room: Some(joined_room.room),
1321        channel_id: None,
1322        live_kit_connection_info,
1323    })?;
1324
1325    update_user_contacts(session.user_id(), &session).await?;
1326    Ok(())
1327}
1328
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-attaches the session's current connection to the room in the database, then:
/// - responds with the room plus the reshared and rejoined projects;
/// - broadcasts the room update;
/// - for each reshared project, tells collaborators the rejoining peer's new id and forwards
///   the project's worktrees to them;
/// - streams rejoined project state to this connection via `notify_rejoined_projects`.
///
/// If the room belongs to a channel, the channel update is broadcast afterwards. Finally the
/// user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled in from the rejoin result inside the scope below, so that `rejoined_room`
    // (presumably a DB transaction guard, given `into_inner()` — verify) is dropped before the
    // awaits that follow the scope.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Map the rejoining peer's old connection id to its new one for each collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared project's worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1420
1421fn notify_rejoined_projects(
1422    rejoined_projects: &mut Vec<RejoinedProject>,
1423    session: &Session,
1424) -> Result<()> {
1425    for project in rejoined_projects.iter() {
1426        for collaborator in &project.collaborators {
1427            session
1428                .peer
1429                .send(
1430                    collaborator.connection_id,
1431                    proto::UpdateProjectCollaborator {
1432                        project_id: project.id.to_proto(),
1433                        old_peer_id: Some(project.old_connection_id.into()),
1434                        new_peer_id: Some(session.connection_id.into()),
1435                    },
1436                )
1437                .trace_err();
1438        }
1439    }
1440
1441    for project in rejoined_projects {
1442        for worktree in mem::take(&mut project.worktrees) {
1443            // Stream this worktree's entries.
1444            let message = proto::UpdateWorktree {
1445                project_id: project.id.to_proto(),
1446                worktree_id: worktree.id,
1447                abs_path: worktree.abs_path.clone(),
1448                root_name: worktree.root_name,
1449                updated_entries: worktree.updated_entries,
1450                removed_entries: worktree.removed_entries,
1451                scan_id: worktree.scan_id,
1452                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1453                updated_repositories: worktree.updated_repositories,
1454                removed_repositories: worktree.removed_repositories,
1455            };
1456            for update in proto::split_worktree_update(message) {
1457                session.peer.send(session.connection_id, update.clone())?;
1458            }
1459
1460            // Stream this worktree's diagnostics.
1461            for summary in worktree.diagnostic_summaries {
1462                session.peer.send(
1463                    session.connection_id,
1464                    proto::UpdateDiagnosticSummary {
1465                        project_id: project.id.to_proto(),
1466                        worktree_id: worktree.id,
1467                        summary: Some(summary),
1468                    },
1469                )?;
1470            }
1471
1472            for settings_file in worktree.settings_files {
1473                session.peer.send(
1474                    session.connection_id,
1475                    proto::UpdateWorktreeSettings {
1476                        project_id: project.id.to_proto(),
1477                        worktree_id: worktree.id,
1478                        path: settings_file.path,
1479                        content: Some(settings_file.content),
1480                        kind: Some(settings_file.kind.to_proto().into()),
1481                    },
1482                )?;
1483            }
1484        }
1485
1486        for language_server in &project.language_servers {
1487            session.peer.send(
1488                session.connection_id,
1489                proto::UpdateLanguageServer {
1490                    project_id: project.id.to_proto(),
1491                    language_server_id: language_server.id,
1492                    variant: Some(
1493                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1494                            proto::LspDiskBasedDiagnosticsUpdated {},
1495                        ),
1496                    ),
1497                },
1498            )?;
1499        }
1500    }
1501    Ok(())
1502}
1503
1504/// leave room disconnects from the room.
1505async fn leave_room(
1506    _: proto::LeaveRoom,
1507    response: Response<proto::LeaveRoom>,
1508    session: Session,
1509) -> Result<()> {
1510    leave_room_for_session(&session, session.connection_id).await?;
1511    response.send(proto::Ack {})?;
1512    Ok(())
1513}
1514
1515/// Updates the permissions of someone else in the room.
1516async fn set_room_participant_role(
1517    request: proto::SetRoomParticipantRole,
1518    response: Response<proto::SetRoomParticipantRole>,
1519    session: Session,
1520) -> Result<()> {
1521    let user_id = UserId::from_proto(request.user_id);
1522    let role = ChannelRole::from(request.role());
1523
1524    let (livekit_room, can_publish) = {
1525        let room = session
1526            .db()
1527            .await
1528            .set_room_participant_role(
1529                session.user_id(),
1530                RoomId::from_proto(request.room_id),
1531                user_id,
1532                role,
1533            )
1534            .await?;
1535
1536        let livekit_room = room.livekit_room.clone();
1537        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1538        room_updated(&room, &session.peer);
1539        (livekit_room, can_publish)
1540    };
1541
1542    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1543        live_kit
1544            .update_participant(
1545                livekit_room.clone(),
1546                request.user_id.to_string(),
1547                livekit_server::proto::ParticipantPermission {
1548                    can_subscribe: true,
1549                    can_publish,
1550                    can_publish_data: can_publish,
1551                    hidden: false,
1552                    recorder: false,
1553                },
1554            )
1555            .await
1556            .trace_err();
1557    }
1558
1559    response.send(proto::Ack {})?;
1560    Ok(())
1561}
1562
1563/// Call someone else into the current room
1564async fn call(
1565    request: proto::Call,
1566    response: Response<proto::Call>,
1567    session: Session,
1568) -> Result<()> {
1569    let room_id = RoomId::from_proto(request.room_id);
1570    let calling_user_id = session.user_id();
1571    let calling_connection_id = session.connection_id;
1572    let called_user_id = UserId::from_proto(request.called_user_id);
1573    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1574    if !session
1575        .db()
1576        .await
1577        .has_contact(calling_user_id, called_user_id)
1578        .await?
1579    {
1580        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1581    }
1582
1583    let incoming_call = {
1584        let (room, incoming_call) = &mut *session
1585            .db()
1586            .await
1587            .call(
1588                room_id,
1589                calling_user_id,
1590                calling_connection_id,
1591                called_user_id,
1592                initial_project_id,
1593            )
1594            .await?;
1595        room_updated(room, &session.peer);
1596        mem::take(incoming_call)
1597    };
1598    update_user_contacts(called_user_id, &session).await?;
1599
1600    let mut calls = session
1601        .connection_pool()
1602        .await
1603        .user_connection_ids(called_user_id)
1604        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1605        .collect::<FuturesUnordered<_>>();
1606
1607    while let Some(call_response) = calls.next().await {
1608        match call_response.as_ref() {
1609            Ok(_) => {
1610                response.send(proto::Ack {})?;
1611                return Ok(());
1612            }
1613            Err(_) => {
1614                call_response.trace_err();
1615            }
1616        }
1617    }
1618
1619    {
1620        let room = session
1621            .db()
1622            .await
1623            .call_failed(room_id, called_user_id)
1624            .await?;
1625        room_updated(&room, &session.peer);
1626    }
1627    update_user_contacts(called_user_id, &session).await?;
1628
1629    Err(anyhow!("failed to ring user"))?
1630}
1631
1632/// Cancel an outgoing call.
1633async fn cancel_call(
1634    request: proto::CancelCall,
1635    response: Response<proto::CancelCall>,
1636    session: Session,
1637) -> Result<()> {
1638    let called_user_id = UserId::from_proto(request.called_user_id);
1639    let room_id = RoomId::from_proto(request.room_id);
1640    {
1641        let room = session
1642            .db()
1643            .await
1644            .cancel_call(room_id, session.connection_id, called_user_id)
1645            .await?;
1646        room_updated(&room, &session.peer);
1647    }
1648
1649    for connection_id in session
1650        .connection_pool()
1651        .await
1652        .user_connection_ids(called_user_id)
1653    {
1654        session
1655            .peer
1656            .send(
1657                connection_id,
1658                proto::CallCanceled {
1659                    room_id: room_id.to_proto(),
1660                },
1661            )
1662            .trace_err();
1663    }
1664    response.send(proto::Ack {})?;
1665
1666    update_user_contacts(called_user_id, &session).await?;
1667    Ok(())
1668}
1669
1670/// Decline an incoming call.
1671async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1672    let room_id = RoomId::from_proto(message.room_id);
1673    {
1674        let room = session
1675            .db()
1676            .await
1677            .decline_call(Some(room_id), session.user_id())
1678            .await?
1679            .ok_or_else(|| anyhow!("failed to decline call"))?;
1680        room_updated(&room, &session.peer);
1681    }
1682
1683    for connection_id in session
1684        .connection_pool()
1685        .await
1686        .user_connection_ids(session.user_id())
1687    {
1688        session
1689            .peer
1690            .send(
1691                connection_id,
1692                proto::CallCanceled {
1693                    room_id: room_id.to_proto(),
1694                },
1695            )
1696            .trace_err();
1697    }
1698    update_user_contacts(session.user_id(), &session).await?;
1699    Ok(())
1700}
1701
1702/// Updates other participants in the room with your current location.
1703async fn update_participant_location(
1704    request: proto::UpdateParticipantLocation,
1705    response: Response<proto::UpdateParticipantLocation>,
1706    session: Session,
1707) -> Result<()> {
1708    let room_id = RoomId::from_proto(request.room_id);
1709    let location = request
1710        .location
1711        .ok_or_else(|| anyhow!("invalid location"))?;
1712
1713    let db = session.db().await;
1714    let room = db
1715        .update_room_participant_location(room_id, session.connection_id, location)
1716        .await?;
1717
1718    room_updated(&room, &session.peer);
1719    response.send(proto::Ack {})?;
1720    Ok(())
1721}
1722
1723/// Share a project into the room.
1724async fn share_project(
1725    request: proto::ShareProject,
1726    response: Response<proto::ShareProject>,
1727    session: Session,
1728) -> Result<()> {
1729    let (project_id, room) = &*session
1730        .db()
1731        .await
1732        .share_project(
1733            RoomId::from_proto(request.room_id),
1734            session.connection_id,
1735            &request.worktrees,
1736            request.is_ssh_project,
1737        )
1738        .await?;
1739    response.send(proto::ShareProjectResponse {
1740        project_id: project_id.to_proto(),
1741    })?;
1742    room_updated(room, &session.peer);
1743
1744    Ok(())
1745}
1746
1747/// Unshare a project from the room.
1748async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1749    let project_id = ProjectId::from_proto(message.project_id);
1750    unshare_project_internal(project_id, session.connection_id, &session).await
1751}
1752
1753async fn unshare_project_internal(
1754    project_id: ProjectId,
1755    connection_id: ConnectionId,
1756    session: &Session,
1757) -> Result<()> {
1758    let delete = {
1759        let room_guard = session
1760            .db()
1761            .await
1762            .unshare_project(project_id, connection_id)
1763            .await?;
1764
1765        let (delete, room, guest_connection_ids) = &*room_guard;
1766
1767        let message = proto::UnshareProject {
1768            project_id: project_id.to_proto(),
1769        };
1770
1771        broadcast(
1772            Some(connection_id),
1773            guest_connection_ids.iter().copied(),
1774            |conn_id| session.peer.send(conn_id, message.clone()),
1775        );
1776        if let Some(room) = room {
1777            room_updated(room, &session.peer);
1778        }
1779
1780        *delete
1781    };
1782
1783    if delete {
1784        let db = session.db().await;
1785        db.delete_project(project_id).await?;
1786    }
1787
1788    Ok(())
1789}
1790
1791/// Join someone elses shared project.
1792async fn join_project(
1793    request: proto::JoinProject,
1794    response: Response<proto::JoinProject>,
1795    session: Session,
1796) -> Result<()> {
1797    let project_id = ProjectId::from_proto(request.project_id);
1798
1799    tracing::info!(%project_id, "join project");
1800
1801    let db = session.db().await;
1802    let (project, replica_id) = &mut *db
1803        .join_project(project_id, session.connection_id, session.user_id())
1804        .await?;
1805    drop(db);
1806    tracing::info!(%project_id, "join remote project");
1807    join_project_internal(response, session, project, replica_id)
1808}
1809
1810trait JoinProjectInternalResponse {
1811    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1812}
1813impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1814    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1815        Response::<proto::JoinProject>::send(self, result)
1816    }
1817}
1818
1819fn join_project_internal(
1820    response: impl JoinProjectInternalResponse,
1821    session: Session,
1822    project: &mut Project,
1823    replica_id: &ReplicaId,
1824) -> Result<()> {
1825    let collaborators = project
1826        .collaborators
1827        .iter()
1828        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1829        .map(|collaborator| collaborator.to_proto())
1830        .collect::<Vec<_>>();
1831    let project_id = project.id;
1832    let guest_user_id = session.user_id();
1833
1834    let worktrees = project
1835        .worktrees
1836        .iter()
1837        .map(|(id, worktree)| proto::WorktreeMetadata {
1838            id: *id,
1839            root_name: worktree.root_name.clone(),
1840            visible: worktree.visible,
1841            abs_path: worktree.abs_path.clone(),
1842        })
1843        .collect::<Vec<_>>();
1844
1845    let add_project_collaborator = proto::AddProjectCollaborator {
1846        project_id: project_id.to_proto(),
1847        collaborator: Some(proto::Collaborator {
1848            peer_id: Some(session.connection_id.into()),
1849            replica_id: replica_id.0 as u32,
1850            user_id: guest_user_id.to_proto(),
1851            is_host: false,
1852        }),
1853    };
1854
1855    for collaborator in &collaborators {
1856        session
1857            .peer
1858            .send(
1859                collaborator.peer_id.unwrap().into(),
1860                add_project_collaborator.clone(),
1861            )
1862            .trace_err();
1863    }
1864
1865    // First, we send the metadata associated with each worktree.
1866    response.send(proto::JoinProjectResponse {
1867        project_id: project.id.0 as u64,
1868        worktrees: worktrees.clone(),
1869        replica_id: replica_id.0 as u32,
1870        collaborators: collaborators.clone(),
1871        language_servers: project.language_servers.clone(),
1872        role: project.role.into(),
1873    })?;
1874
1875    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1876        // Stream this worktree's entries.
1877        let message = proto::UpdateWorktree {
1878            project_id: project_id.to_proto(),
1879            worktree_id,
1880            abs_path: worktree.abs_path.clone(),
1881            root_name: worktree.root_name,
1882            updated_entries: worktree.entries,
1883            removed_entries: Default::default(),
1884            scan_id: worktree.scan_id,
1885            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1886            updated_repositories: worktree.repository_entries.into_values().collect(),
1887            removed_repositories: Default::default(),
1888        };
1889        for update in proto::split_worktree_update(message) {
1890            session.peer.send(session.connection_id, update.clone())?;
1891        }
1892
1893        // Stream this worktree's diagnostics.
1894        for summary in worktree.diagnostic_summaries {
1895            session.peer.send(
1896                session.connection_id,
1897                proto::UpdateDiagnosticSummary {
1898                    project_id: project_id.to_proto(),
1899                    worktree_id: worktree.id,
1900                    summary: Some(summary),
1901                },
1902            )?;
1903        }
1904
1905        for settings_file in worktree.settings_files {
1906            session.peer.send(
1907                session.connection_id,
1908                proto::UpdateWorktreeSettings {
1909                    project_id: project_id.to_proto(),
1910                    worktree_id: worktree.id,
1911                    path: settings_file.path,
1912                    content: Some(settings_file.content),
1913                    kind: Some(settings_file.kind.to_proto() as i32),
1914                },
1915            )?;
1916        }
1917    }
1918
1919    for language_server in &project.language_servers {
1920        session.peer.send(
1921            session.connection_id,
1922            proto::UpdateLanguageServer {
1923                project_id: project_id.to_proto(),
1924                language_server_id: language_server.id,
1925                variant: Some(
1926                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1927                        proto::LspDiskBasedDiagnosticsUpdated {},
1928                    ),
1929                ),
1930            },
1931        )?;
1932    }
1933
1934    Ok(())
1935}
1936
1937/// Leave someone elses shared project.
1938async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1939    let sender_id = session.connection_id;
1940    let project_id = ProjectId::from_proto(request.project_id);
1941    let db = session.db().await;
1942
1943    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1944    tracing::info!(
1945        %project_id,
1946        "leave project"
1947    );
1948
1949    project_left(project, &session);
1950    if let Some(room) = room {
1951        room_updated(room, &session.peer);
1952    }
1953
1954    Ok(())
1955}
1956
1957/// Updates other participants with changes to the project
1958async fn update_project(
1959    request: proto::UpdateProject,
1960    response: Response<proto::UpdateProject>,
1961    session: Session,
1962) -> Result<()> {
1963    let project_id = ProjectId::from_proto(request.project_id);
1964    let (room, guest_connection_ids) = &*session
1965        .db()
1966        .await
1967        .update_project(project_id, session.connection_id, &request.worktrees)
1968        .await?;
1969    broadcast(
1970        Some(session.connection_id),
1971        guest_connection_ids.iter().copied(),
1972        |connection_id| {
1973            session
1974                .peer
1975                .forward_send(session.connection_id, connection_id, request.clone())
1976        },
1977    );
1978    if let Some(room) = room {
1979        room_updated(room, &session.peer);
1980    }
1981    response.send(proto::Ack {})?;
1982
1983    Ok(())
1984}
1985
1986/// Updates other participants with changes to the worktree
1987async fn update_worktree(
1988    request: proto::UpdateWorktree,
1989    response: Response<proto::UpdateWorktree>,
1990    session: Session,
1991) -> Result<()> {
1992    let guest_connection_ids = session
1993        .db()
1994        .await
1995        .update_worktree(&request, session.connection_id)
1996        .await?;
1997
1998    broadcast(
1999        Some(session.connection_id),
2000        guest_connection_ids.iter().copied(),
2001        |connection_id| {
2002            session
2003                .peer
2004                .forward_send(session.connection_id, connection_id, request.clone())
2005        },
2006    );
2007    response.send(proto::Ack {})?;
2008    Ok(())
2009}
2010
2011/// Updates other participants with changes to the diagnostics
2012async fn update_diagnostic_summary(
2013    message: proto::UpdateDiagnosticSummary,
2014    session: Session,
2015) -> Result<()> {
2016    let guest_connection_ids = session
2017        .db()
2018        .await
2019        .update_diagnostic_summary(&message, session.connection_id)
2020        .await?;
2021
2022    broadcast(
2023        Some(session.connection_id),
2024        guest_connection_ids.iter().copied(),
2025        |connection_id| {
2026            session
2027                .peer
2028                .forward_send(session.connection_id, connection_id, message.clone())
2029        },
2030    );
2031
2032    Ok(())
2033}
2034
2035/// Updates other participants with changes to the worktree settings
2036async fn update_worktree_settings(
2037    message: proto::UpdateWorktreeSettings,
2038    session: Session,
2039) -> Result<()> {
2040    let guest_connection_ids = session
2041        .db()
2042        .await
2043        .update_worktree_settings(&message, session.connection_id)
2044        .await?;
2045
2046    broadcast(
2047        Some(session.connection_id),
2048        guest_connection_ids.iter().copied(),
2049        |connection_id| {
2050            session
2051                .peer
2052                .forward_send(session.connection_id, connection_id, message.clone())
2053        },
2054    );
2055
2056    Ok(())
2057}
2058
2059/// Notify other participants that a  language server has started.
2060async fn start_language_server(
2061    request: proto::StartLanguageServer,
2062    session: Session,
2063) -> Result<()> {
2064    let guest_connection_ids = session
2065        .db()
2066        .await
2067        .start_language_server(&request, session.connection_id)
2068        .await?;
2069
2070    broadcast(
2071        Some(session.connection_id),
2072        guest_connection_ids.iter().copied(),
2073        |connection_id| {
2074            session
2075                .peer
2076                .forward_send(session.connection_id, connection_id, request.clone())
2077        },
2078    );
2079    Ok(())
2080}
2081
2082/// Notify other participants that a language server has changed.
2083async fn update_language_server(
2084    request: proto::UpdateLanguageServer,
2085    session: Session,
2086) -> Result<()> {
2087    let project_id = ProjectId::from_proto(request.project_id);
2088    let project_connection_ids = session
2089        .db()
2090        .await
2091        .project_connection_ids(project_id, session.connection_id, true)
2092        .await?;
2093    broadcast(
2094        Some(session.connection_id),
2095        project_connection_ids.iter().copied(),
2096        |connection_id| {
2097            session
2098                .peer
2099                .forward_send(session.connection_id, connection_id, request.clone())
2100        },
2101    );
2102    Ok(())
2103}
2104
2105/// forward a project request to the host. These requests should be read only
2106/// as guests are allowed to send them.
2107async fn forward_read_only_project_request<T>(
2108    request: T,
2109    response: Response<T>,
2110    session: Session,
2111) -> Result<()>
2112where
2113    T: EntityMessage + RequestMessage,
2114{
2115    let project_id = ProjectId::from_proto(request.remote_entity_id());
2116    let host_connection_id = session
2117        .db()
2118        .await
2119        .host_for_read_only_project_request(project_id, session.connection_id)
2120        .await?;
2121    let payload = session
2122        .peer
2123        .forward_request(session.connection_id, host_connection_id, request)
2124        .await?;
2125    response.send(payload)?;
2126    Ok(())
2127}
2128
2129async fn forward_find_search_candidates_request(
2130    request: proto::FindSearchCandidates,
2131    response: Response<proto::FindSearchCandidates>,
2132    session: Session,
2133) -> Result<()> {
2134    let project_id = ProjectId::from_proto(request.remote_entity_id());
2135    let host_connection_id = session
2136        .db()
2137        .await
2138        .host_for_read_only_project_request(project_id, session.connection_id)
2139        .await?;
2140    let payload = session
2141        .peer
2142        .forward_request(session.connection_id, host_connection_id, request)
2143        .await?;
2144    response.send(payload)?;
2145    Ok(())
2146}
2147
2148/// forward a project request to the host. These requests are disallowed
2149/// for guests.
2150async fn forward_mutating_project_request<T>(
2151    request: T,
2152    response: Response<T>,
2153    session: Session,
2154) -> Result<()>
2155where
2156    T: EntityMessage + RequestMessage,
2157{
2158    let project_id = ProjectId::from_proto(request.remote_entity_id());
2159
2160    let host_connection_id = session
2161        .db()
2162        .await
2163        .host_for_mutating_project_request(project_id, session.connection_id)
2164        .await?;
2165    let payload = session
2166        .peer
2167        .forward_request(session.connection_id, host_connection_id, request)
2168        .await?;
2169    response.send(payload)?;
2170    Ok(())
2171}
2172
2173/// Notify other participants that a new buffer has been created
2174async fn create_buffer_for_peer(
2175    request: proto::CreateBufferForPeer,
2176    session: Session,
2177) -> Result<()> {
2178    session
2179        .db()
2180        .await
2181        .check_user_is_project_host(
2182            ProjectId::from_proto(request.project_id),
2183            session.connection_id,
2184        )
2185        .await?;
2186    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2187    session
2188        .peer
2189        .forward_send(session.connection_id, peer_id.into(), request)?;
2190    Ok(())
2191}
2192
2193/// Notify other participants that a buffer has been updated. This is
2194/// allowed for guests as long as the update is limited to selections.
2195async fn update_buffer(
2196    request: proto::UpdateBuffer,
2197    response: Response<proto::UpdateBuffer>,
2198    session: Session,
2199) -> Result<()> {
2200    let project_id = ProjectId::from_proto(request.project_id);
2201    let mut capability = Capability::ReadOnly;
2202
2203    for op in request.operations.iter() {
2204        match op.variant {
2205            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2206            Some(_) => capability = Capability::ReadWrite,
2207        }
2208    }
2209
2210    let host = {
2211        let guard = session
2212            .db()
2213            .await
2214            .connections_for_buffer_update(project_id, session.connection_id, capability)
2215            .await?;
2216
2217        let (host, guests) = &*guard;
2218
2219        broadcast(
2220            Some(session.connection_id),
2221            guests.clone(),
2222            |connection_id| {
2223                session
2224                    .peer
2225                    .forward_send(session.connection_id, connection_id, request.clone())
2226            },
2227        );
2228
2229        *host
2230    };
2231
2232    if host != session.connection_id {
2233        session
2234            .peer
2235            .forward_request(session.connection_id, host, request.clone())
2236            .await?;
2237    }
2238
2239    response.send(proto::Ack {})?;
2240    Ok(())
2241}
2242
2243async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2244    let project_id = ProjectId::from_proto(message.project_id);
2245
2246    let operation = message.operation.as_ref().context("invalid operation")?;
2247    let capability = match operation.variant.as_ref() {
2248        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2249            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2250                match buffer_op.variant {
2251                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2252                        Capability::ReadOnly
2253                    }
2254                    _ => Capability::ReadWrite,
2255                }
2256            } else {
2257                Capability::ReadWrite
2258            }
2259        }
2260        Some(_) => Capability::ReadWrite,
2261        None => Capability::ReadOnly,
2262    };
2263
2264    let guard = session
2265        .db()
2266        .await
2267        .connections_for_buffer_update(project_id, session.connection_id, capability)
2268        .await?;
2269
2270    let (host, guests) = &*guard;
2271
2272    broadcast(
2273        Some(session.connection_id),
2274        guests.iter().chain([host]).copied(),
2275        |connection_id| {
2276            session
2277                .peer
2278                .forward_send(session.connection_id, connection_id, message.clone())
2279        },
2280    );
2281
2282    Ok(())
2283}
2284
2285/// Notify other participants that a project has been updated.
2286async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2287    request: T,
2288    session: Session,
2289) -> Result<()> {
2290    let project_id = ProjectId::from_proto(request.remote_entity_id());
2291    let project_connection_ids = session
2292        .db()
2293        .await
2294        .project_connection_ids(project_id, session.connection_id, false)
2295        .await?;
2296
2297    broadcast(
2298        Some(session.connection_id),
2299        project_connection_ids.iter().copied(),
2300        |connection_id| {
2301            session
2302                .peer
2303                .forward_send(session.connection_id, connection_id, request.clone())
2304        },
2305    );
2306    Ok(())
2307}
2308
2309/// Start following another user in a call.
2310async fn follow(
2311    request: proto::Follow,
2312    response: Response<proto::Follow>,
2313    session: Session,
2314) -> Result<()> {
2315    let room_id = RoomId::from_proto(request.room_id);
2316    let project_id = request.project_id.map(ProjectId::from_proto);
2317    let leader_id = request
2318        .leader_id
2319        .ok_or_else(|| anyhow!("invalid leader id"))?
2320        .into();
2321    let follower_id = session.connection_id;
2322
2323    session
2324        .db()
2325        .await
2326        .check_room_participants(room_id, leader_id, session.connection_id)
2327        .await?;
2328
2329    let response_payload = session
2330        .peer
2331        .forward_request(session.connection_id, leader_id, request)
2332        .await?;
2333    response.send(response_payload)?;
2334
2335    if let Some(project_id) = project_id {
2336        let room = session
2337            .db()
2338            .await
2339            .follow(room_id, project_id, leader_id, follower_id)
2340            .await?;
2341        room_updated(&room, &session.peer);
2342    }
2343
2344    Ok(())
2345}
2346
2347/// Stop following another user in a call.
2348async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2349    let room_id = RoomId::from_proto(request.room_id);
2350    let project_id = request.project_id.map(ProjectId::from_proto);
2351    let leader_id = request
2352        .leader_id
2353        .ok_or_else(|| anyhow!("invalid leader id"))?
2354        .into();
2355    let follower_id = session.connection_id;
2356
2357    session
2358        .db()
2359        .await
2360        .check_room_participants(room_id, leader_id, session.connection_id)
2361        .await?;
2362
2363    session
2364        .peer
2365        .forward_send(session.connection_id, leader_id, request)?;
2366
2367    if let Some(project_id) = project_id {
2368        let room = session
2369            .db()
2370            .await
2371            .unfollow(room_id, project_id, leader_id, follower_id)
2372            .await?;
2373        room_updated(&room, &session.peer);
2374    }
2375
2376    Ok(())
2377}
2378
2379/// Notify everyone following you of your current location.
2380async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2381    let room_id = RoomId::from_proto(request.room_id);
2382    let database = session.db.lock().await;
2383
2384    let connection_ids = if let Some(project_id) = request.project_id {
2385        let project_id = ProjectId::from_proto(project_id);
2386        database
2387            .project_connection_ids(project_id, session.connection_id, true)
2388            .await?
2389    } else {
2390        database
2391            .room_connection_ids(room_id, session.connection_id)
2392            .await?
2393    };
2394
2395    // For now, don't send view update messages back to that view's current leader.
2396    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2397        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2398        _ => None,
2399    });
2400
2401    for connection_id in connection_ids.iter().cloned() {
2402        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2403            session
2404                .peer
2405                .forward_send(session.connection_id, connection_id, request.clone())?;
2406        }
2407    }
2408    Ok(())
2409}
2410
2411/// Get public data about users.
2412async fn get_users(
2413    request: proto::GetUsers,
2414    response: Response<proto::GetUsers>,
2415    session: Session,
2416) -> Result<()> {
2417    let user_ids = request
2418        .user_ids
2419        .into_iter()
2420        .map(UserId::from_proto)
2421        .collect();
2422    let users = session
2423        .db()
2424        .await
2425        .get_users_by_ids(user_ids)
2426        .await?
2427        .into_iter()
2428        .map(|user| proto::User {
2429            id: user.id.to_proto(),
2430            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2431            github_login: user.github_login,
2432            email: user.email_address,
2433            name: user.name,
2434        })
2435        .collect();
2436    response.send(proto::UsersResponse { users })?;
2437    Ok(())
2438}
2439
2440/// Search for users (to invite) buy Github login
2441async fn fuzzy_search_users(
2442    request: proto::FuzzySearchUsers,
2443    response: Response<proto::FuzzySearchUsers>,
2444    session: Session,
2445) -> Result<()> {
2446    let query = request.query;
2447    let users = match query.len() {
2448        0 => vec![],
2449        1 | 2 => session
2450            .db()
2451            .await
2452            .get_user_by_github_login(&query)
2453            .await?
2454            .into_iter()
2455            .collect(),
2456        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2457    };
2458    let users = users
2459        .into_iter()
2460        .filter(|user| user.id != session.user_id())
2461        .map(|user| proto::User {
2462            id: user.id.to_proto(),
2463            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2464            github_login: user.github_login,
2465            name: user.name,
2466            email: user.email_address,
2467        })
2468        .collect();
2469    response.send(proto::UsersResponse { users })?;
2470    Ok(())
2471}
2472
2473/// Send a contact request to another user.
2474async fn request_contact(
2475    request: proto::RequestContact,
2476    response: Response<proto::RequestContact>,
2477    session: Session,
2478) -> Result<()> {
2479    let requester_id = session.user_id();
2480    let responder_id = UserId::from_proto(request.responder_id);
2481    if requester_id == responder_id {
2482        return Err(anyhow!("cannot add yourself as a contact"))?;
2483    }
2484
2485    let notifications = session
2486        .db()
2487        .await
2488        .send_contact_request(requester_id, responder_id)
2489        .await?;
2490
2491    // Update outgoing contact requests of requester
2492    let mut update = proto::UpdateContacts::default();
2493    update.outgoing_requests.push(responder_id.to_proto());
2494    for connection_id in session
2495        .connection_pool()
2496        .await
2497        .user_connection_ids(requester_id)
2498    {
2499        session.peer.send(connection_id, update.clone())?;
2500    }
2501
2502    // Update incoming contact requests of responder
2503    let mut update = proto::UpdateContacts::default();
2504    update
2505        .incoming_requests
2506        .push(proto::IncomingContactRequest {
2507            requester_id: requester_id.to_proto(),
2508        });
2509    let connection_pool = session.connection_pool().await;
2510    for connection_id in connection_pool.user_connection_ids(responder_id) {
2511        session.peer.send(connection_id, update.clone())?;
2512    }
2513
2514    send_notifications(&connection_pool, &session.peer, notifications);
2515
2516    response.send(proto::Ack {})?;
2517    Ok(())
2518}
2519
2520/// Accept or decline a contact request
2521async fn respond_to_contact_request(
2522    request: proto::RespondToContactRequest,
2523    response: Response<proto::RespondToContactRequest>,
2524    session: Session,
2525) -> Result<()> {
2526    let responder_id = session.user_id();
2527    let requester_id = UserId::from_proto(request.requester_id);
2528    let db = session.db().await;
2529    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2530        db.dismiss_contact_notification(responder_id, requester_id)
2531            .await?;
2532    } else {
2533        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2534
2535        let notifications = db
2536            .respond_to_contact_request(responder_id, requester_id, accept)
2537            .await?;
2538        let requester_busy = db.is_user_busy(requester_id).await?;
2539        let responder_busy = db.is_user_busy(responder_id).await?;
2540
2541        let pool = session.connection_pool().await;
2542        // Update responder with new contact
2543        let mut update = proto::UpdateContacts::default();
2544        if accept {
2545            update
2546                .contacts
2547                .push(contact_for_user(requester_id, requester_busy, &pool));
2548        }
2549        update
2550            .remove_incoming_requests
2551            .push(requester_id.to_proto());
2552        for connection_id in pool.user_connection_ids(responder_id) {
2553            session.peer.send(connection_id, update.clone())?;
2554        }
2555
2556        // Update requester with new contact
2557        let mut update = proto::UpdateContacts::default();
2558        if accept {
2559            update
2560                .contacts
2561                .push(contact_for_user(responder_id, responder_busy, &pool));
2562        }
2563        update
2564            .remove_outgoing_requests
2565            .push(responder_id.to_proto());
2566
2567        for connection_id in pool.user_connection_ids(requester_id) {
2568            session.peer.send(connection_id, update.clone())?;
2569        }
2570
2571        send_notifications(&pool, &session.peer, notifications);
2572    }
2573
2574    response.send(proto::Ack {})?;
2575    Ok(())
2576}
2577
2578/// Remove a contact.
2579async fn remove_contact(
2580    request: proto::RemoveContact,
2581    response: Response<proto::RemoveContact>,
2582    session: Session,
2583) -> Result<()> {
2584    let requester_id = session.user_id();
2585    let responder_id = UserId::from_proto(request.user_id);
2586    let db = session.db().await;
2587    let (contact_accepted, deleted_notification_id) =
2588        db.remove_contact(requester_id, responder_id).await?;
2589
2590    let pool = session.connection_pool().await;
2591    // Update outgoing contact requests of requester
2592    let mut update = proto::UpdateContacts::default();
2593    if contact_accepted {
2594        update.remove_contacts.push(responder_id.to_proto());
2595    } else {
2596        update
2597            .remove_outgoing_requests
2598            .push(responder_id.to_proto());
2599    }
2600    for connection_id in pool.user_connection_ids(requester_id) {
2601        session.peer.send(connection_id, update.clone())?;
2602    }
2603
2604    // Update incoming contact requests of responder
2605    let mut update = proto::UpdateContacts::default();
2606    if contact_accepted {
2607        update.remove_contacts.push(requester_id.to_proto());
2608    } else {
2609        update
2610            .remove_incoming_requests
2611            .push(requester_id.to_proto());
2612    }
2613    for connection_id in pool.user_connection_ids(responder_id) {
2614        session.peer.send(connection_id, update.clone())?;
2615        if let Some(notification_id) = deleted_notification_id {
2616            session.peer.send(
2617                connection_id,
2618                proto::DeleteNotification {
2619                    notification_id: notification_id.to_proto(),
2620                },
2621            )?;
2622        }
2623    }
2624
2625    response.send(proto::Ack {})?;
2626    Ok(())
2627}
2628
2629fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2630    version.0.minor() < 139
2631}
2632
2633async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2634    let plan = session.current_plan(&session.db().await).await?;
2635
2636    session
2637        .peer
2638        .send(
2639            session.connection_id,
2640            proto::UpdateUserPlan { plan: plan.into() },
2641        )
2642        .trace_err();
2643
2644    Ok(())
2645}
2646
2647async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2648    subscribe_user_to_channels(session.user_id(), &session).await?;
2649    Ok(())
2650}
2651
2652async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2653    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2654    let mut pool = session.connection_pool().await;
2655    for membership in &channels_for_user.channel_memberships {
2656        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2657    }
2658    session.peer.send(
2659        session.connection_id,
2660        build_update_user_channels(&channels_for_user),
2661    )?;
2662    session.peer.send(
2663        session.connection_id,
2664        build_channels_update(channels_for_user),
2665    )?;
2666    Ok(())
2667}
2668
2669/// Creates a new channel.
2670async fn create_channel(
2671    request: proto::CreateChannel,
2672    response: Response<proto::CreateChannel>,
2673    session: Session,
2674) -> Result<()> {
2675    let db = session.db().await;
2676
2677    let parent_id = request.parent_id.map(ChannelId::from_proto);
2678    let (channel, membership) = db
2679        .create_channel(&request.name, parent_id, session.user_id())
2680        .await?;
2681
2682    let root_id = channel.root_id();
2683    let channel = Channel::from_model(channel);
2684
2685    response.send(proto::CreateChannelResponse {
2686        channel: Some(channel.to_proto()),
2687        parent_id: request.parent_id,
2688    })?;
2689
2690    let mut connection_pool = session.connection_pool().await;
2691    if let Some(membership) = membership {
2692        connection_pool.subscribe_to_channel(
2693            membership.user_id,
2694            membership.channel_id,
2695            membership.role,
2696        );
2697        let update = proto::UpdateUserChannels {
2698            channel_memberships: vec![proto::ChannelMembership {
2699                channel_id: membership.channel_id.to_proto(),
2700                role: membership.role.into(),
2701            }],
2702            ..Default::default()
2703        };
2704        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2705            session.peer.send(connection_id, update.clone())?;
2706        }
2707    }
2708
2709    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2710        if !role.can_see_channel(channel.visibility) {
2711            continue;
2712        }
2713
2714        let update = proto::UpdateChannels {
2715            channels: vec![channel.to_proto()],
2716            ..Default::default()
2717        };
2718        session.peer.send(connection_id, update.clone())?;
2719    }
2720
2721    Ok(())
2722}
2723
2724/// Delete a channel
2725async fn delete_channel(
2726    request: proto::DeleteChannel,
2727    response: Response<proto::DeleteChannel>,
2728    session: Session,
2729) -> Result<()> {
2730    let db = session.db().await;
2731
2732    let channel_id = request.channel_id;
2733    let (root_channel, removed_channels) = db
2734        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2735        .await?;
2736    response.send(proto::Ack {})?;
2737
2738    // Notify members of removed channels
2739    let mut update = proto::UpdateChannels::default();
2740    update
2741        .delete_channels
2742        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2743
2744    let connection_pool = session.connection_pool().await;
2745    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2746        session.peer.send(connection_id, update.clone())?;
2747    }
2748
2749    Ok(())
2750}
2751
2752/// Invite someone to join a channel.
2753async fn invite_channel_member(
2754    request: proto::InviteChannelMember,
2755    response: Response<proto::InviteChannelMember>,
2756    session: Session,
2757) -> Result<()> {
2758    let db = session.db().await;
2759    let channel_id = ChannelId::from_proto(request.channel_id);
2760    let invitee_id = UserId::from_proto(request.user_id);
2761    let InviteMemberResult {
2762        channel,
2763        notifications,
2764    } = db
2765        .invite_channel_member(
2766            channel_id,
2767            invitee_id,
2768            session.user_id(),
2769            request.role().into(),
2770        )
2771        .await?;
2772
2773    let update = proto::UpdateChannels {
2774        channel_invitations: vec![channel.to_proto()],
2775        ..Default::default()
2776    };
2777
2778    let connection_pool = session.connection_pool().await;
2779    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2780        session.peer.send(connection_id, update.clone())?;
2781    }
2782
2783    send_notifications(&connection_pool, &session.peer, notifications);
2784
2785    response.send(proto::Ack {})?;
2786    Ok(())
2787}
2788
2789/// remove someone from a channel
2790async fn remove_channel_member(
2791    request: proto::RemoveChannelMember,
2792    response: Response<proto::RemoveChannelMember>,
2793    session: Session,
2794) -> Result<()> {
2795    let db = session.db().await;
2796    let channel_id = ChannelId::from_proto(request.channel_id);
2797    let member_id = UserId::from_proto(request.user_id);
2798
2799    let RemoveChannelMemberResult {
2800        membership_update,
2801        notification_id,
2802    } = db
2803        .remove_channel_member(channel_id, member_id, session.user_id())
2804        .await?;
2805
2806    let mut connection_pool = session.connection_pool().await;
2807    notify_membership_updated(
2808        &mut connection_pool,
2809        membership_update,
2810        member_id,
2811        &session.peer,
2812    );
2813    for connection_id in connection_pool.user_connection_ids(member_id) {
2814        if let Some(notification_id) = notification_id {
2815            session
2816                .peer
2817                .send(
2818                    connection_id,
2819                    proto::DeleteNotification {
2820                        notification_id: notification_id.to_proto(),
2821                    },
2822                )
2823                .trace_err();
2824        }
2825    }
2826
2827    response.send(proto::Ack {})?;
2828    Ok(())
2829}
2830
2831/// Toggle the channel between public and private.
2832/// Care is taken to maintain the invariant that public channels only descend from public channels,
2833/// (though members-only channels can appear at any point in the hierarchy).
2834async fn set_channel_visibility(
2835    request: proto::SetChannelVisibility,
2836    response: Response<proto::SetChannelVisibility>,
2837    session: Session,
2838) -> Result<()> {
2839    let db = session.db().await;
2840    let channel_id = ChannelId::from_proto(request.channel_id);
2841    let visibility = request.visibility().into();
2842
2843    let channel_model = db
2844        .set_channel_visibility(channel_id, visibility, session.user_id())
2845        .await?;
2846    let root_id = channel_model.root_id();
2847    let channel = Channel::from_model(channel_model);
2848
2849    let mut connection_pool = session.connection_pool().await;
2850    for (user_id, role) in connection_pool
2851        .channel_user_ids(root_id)
2852        .collect::<Vec<_>>()
2853        .into_iter()
2854    {
2855        let update = if role.can_see_channel(channel.visibility) {
2856            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2857            proto::UpdateChannels {
2858                channels: vec![channel.to_proto()],
2859                ..Default::default()
2860            }
2861        } else {
2862            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2863            proto::UpdateChannels {
2864                delete_channels: vec![channel.id.to_proto()],
2865                ..Default::default()
2866            }
2867        };
2868
2869        for connection_id in connection_pool.user_connection_ids(user_id) {
2870            session.peer.send(connection_id, update.clone())?;
2871        }
2872    }
2873
2874    response.send(proto::Ack {})?;
2875    Ok(())
2876}
2877
2878/// Alter the role for a user in the channel.
2879async fn set_channel_member_role(
2880    request: proto::SetChannelMemberRole,
2881    response: Response<proto::SetChannelMemberRole>,
2882    session: Session,
2883) -> Result<()> {
2884    let db = session.db().await;
2885    let channel_id = ChannelId::from_proto(request.channel_id);
2886    let member_id = UserId::from_proto(request.user_id);
2887    let result = db
2888        .set_channel_member_role(
2889            channel_id,
2890            session.user_id(),
2891            member_id,
2892            request.role().into(),
2893        )
2894        .await?;
2895
2896    match result {
2897        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2898            let mut connection_pool = session.connection_pool().await;
2899            notify_membership_updated(
2900                &mut connection_pool,
2901                membership_update,
2902                member_id,
2903                &session.peer,
2904            )
2905        }
2906        db::SetMemberRoleResult::InviteUpdated(channel) => {
2907            let update = proto::UpdateChannels {
2908                channel_invitations: vec![channel.to_proto()],
2909                ..Default::default()
2910            };
2911
2912            for connection_id in session
2913                .connection_pool()
2914                .await
2915                .user_connection_ids(member_id)
2916            {
2917                session.peer.send(connection_id, update.clone())?;
2918            }
2919        }
2920    }
2921
2922    response.send(proto::Ack {})?;
2923    Ok(())
2924}
2925
2926/// Change the name of a channel
2927async fn rename_channel(
2928    request: proto::RenameChannel,
2929    response: Response<proto::RenameChannel>,
2930    session: Session,
2931) -> Result<()> {
2932    let db = session.db().await;
2933    let channel_id = ChannelId::from_proto(request.channel_id);
2934    let channel_model = db
2935        .rename_channel(channel_id, session.user_id(), &request.name)
2936        .await?;
2937    let root_id = channel_model.root_id();
2938    let channel = Channel::from_model(channel_model);
2939
2940    response.send(proto::RenameChannelResponse {
2941        channel: Some(channel.to_proto()),
2942    })?;
2943
2944    let connection_pool = session.connection_pool().await;
2945    let update = proto::UpdateChannels {
2946        channels: vec![channel.to_proto()],
2947        ..Default::default()
2948    };
2949    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2950        if role.can_see_channel(channel.visibility) {
2951            session.peer.send(connection_id, update.clone())?;
2952        }
2953    }
2954
2955    Ok(())
2956}
2957
2958/// Move a channel to a new parent.
2959async fn move_channel(
2960    request: proto::MoveChannel,
2961    response: Response<proto::MoveChannel>,
2962    session: Session,
2963) -> Result<()> {
2964    let channel_id = ChannelId::from_proto(request.channel_id);
2965    let to = ChannelId::from_proto(request.to);
2966
2967    let (root_id, channels) = session
2968        .db()
2969        .await
2970        .move_channel(channel_id, to, session.user_id())
2971        .await?;
2972
2973    let connection_pool = session.connection_pool().await;
2974    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2975        let channels = channels
2976            .iter()
2977            .filter_map(|channel| {
2978                if role.can_see_channel(channel.visibility) {
2979                    Some(channel.to_proto())
2980                } else {
2981                    None
2982                }
2983            })
2984            .collect::<Vec<_>>();
2985        if channels.is_empty() {
2986            continue;
2987        }
2988
2989        let update = proto::UpdateChannels {
2990            channels,
2991            ..Default::default()
2992        };
2993
2994        session.peer.send(connection_id, update.clone())?;
2995    }
2996
2997    response.send(Ack {})?;
2998    Ok(())
2999}
3000
3001/// Get the list of channel members
3002async fn get_channel_members(
3003    request: proto::GetChannelMembers,
3004    response: Response<proto::GetChannelMembers>,
3005    session: Session,
3006) -> Result<()> {
3007    let db = session.db().await;
3008    let channel_id = ChannelId::from_proto(request.channel_id);
3009    let limit = if request.limit == 0 {
3010        u16::MAX as u64
3011    } else {
3012        request.limit
3013    };
3014    let (members, users) = db
3015        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3016        .await?;
3017    response.send(proto::GetChannelMembersResponse { members, users })?;
3018    Ok(())
3019}
3020
3021/// Accept or decline a channel invitation.
3022async fn respond_to_channel_invite(
3023    request: proto::RespondToChannelInvite,
3024    response: Response<proto::RespondToChannelInvite>,
3025    session: Session,
3026) -> Result<()> {
3027    let db = session.db().await;
3028    let channel_id = ChannelId::from_proto(request.channel_id);
3029    let RespondToChannelInvite {
3030        membership_update,
3031        notifications,
3032    } = db
3033        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3034        .await?;
3035
3036    let mut connection_pool = session.connection_pool().await;
3037    if let Some(membership_update) = membership_update {
3038        notify_membership_updated(
3039            &mut connection_pool,
3040            membership_update,
3041            session.user_id(),
3042            &session.peer,
3043        );
3044    } else {
3045        let update = proto::UpdateChannels {
3046            remove_channel_invitations: vec![channel_id.to_proto()],
3047            ..Default::default()
3048        };
3049
3050        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3051            session.peer.send(connection_id, update.clone())?;
3052        }
3053    };
3054
3055    send_notifications(&connection_pool, &session.peer, notifications);
3056
3057    response.send(proto::Ack {})?;
3058
3059    Ok(())
3060}
3061
3062/// Join the channels' room
3063async fn join_channel(
3064    request: proto::JoinChannel,
3065    response: Response<proto::JoinChannel>,
3066    session: Session,
3067) -> Result<()> {
3068    let channel_id = ChannelId::from_proto(request.channel_id);
3069    join_channel_internal(channel_id, Box::new(response), session).await
3070}
3071
/// Lets `join_channel_internal` reply to either of the two requests that join a channel
/// room (`JoinChannel` and `JoinRoom`); both replies are a `JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the underlying request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified call; presumably written this way so it plainly targets
        // `Response`'s own `send` rather than this trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified for the same reason as the `JoinChannel` impl.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3085
/// Joins the room of `channel_id` on behalf of `session`'s user and replies via `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, it is cleaned up with
///    `leave_room_for_session`.
/// 2. The channel's room is joined and a `JoinRoomResponse` is sent. It carries LiveKit
///    connection info only when a LiveKit client is configured and token creation succeeds.
/// 3. Follow-up updates go out via `notify_membership_updated` (only when the membership
///    changed), `room_updated`, `channel_updated` and `update_user_contacts`.
///
/// # Errors
/// Fails if any of these fail: a database call, the stale-connection cleanup, sending the
/// response, or `update_user_contacts`. Also fails if the database did not return the
/// joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so that the database and connection-pool guards taken inside are released
    // before the connection pool is locked again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If Zed quits without leaving the room, and the user re-opens Zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before the cleanup and re-acquire it afterwards;
            // presumably `leave_room_for_session` needs the database itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation fails (the
        // error is logged by `trace_err` and the `?` exits the closure with `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every other role gets a
                    // full room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3181
3182/// Start editing the channel notes
3183async fn join_channel_buffer(
3184    request: proto::JoinChannelBuffer,
3185    response: Response<proto::JoinChannelBuffer>,
3186    session: Session,
3187) -> Result<()> {
3188    let db = session.db().await;
3189    let channel_id = ChannelId::from_proto(request.channel_id);
3190
3191    let open_response = db
3192        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3193        .await?;
3194
3195    let collaborators = open_response.collaborators.clone();
3196    response.send(open_response)?;
3197
3198    let update = UpdateChannelBufferCollaborators {
3199        channel_id: channel_id.to_proto(),
3200        collaborators: collaborators.clone(),
3201    };
3202    channel_buffer_updated(
3203        session.connection_id,
3204        collaborators
3205            .iter()
3206            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3207        &update,
3208        &session.peer,
3209    );
3210
3211    Ok(())
3212}
3213
3214/// Edit the channel notes
3215async fn update_channel_buffer(
3216    request: proto::UpdateChannelBuffer,
3217    session: Session,
3218) -> Result<()> {
3219    let db = session.db().await;
3220    let channel_id = ChannelId::from_proto(request.channel_id);
3221
3222    let (collaborators, epoch, version) = db
3223        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3224        .await?;
3225
3226    channel_buffer_updated(
3227        session.connection_id,
3228        collaborators.clone(),
3229        &proto::UpdateChannelBuffer {
3230            channel_id: channel_id.to_proto(),
3231            operations: request.operations,
3232        },
3233        &session.peer,
3234    );
3235
3236    let pool = &*session.connection_pool().await;
3237
3238    let non_collaborators =
3239        pool.channel_connection_ids(channel_id)
3240            .filter_map(|(connection_id, _)| {
3241                if collaborators.contains(&connection_id) {
3242                    None
3243                } else {
3244                    Some(connection_id)
3245                }
3246            });
3247
3248    broadcast(None, non_collaborators, |peer_id| {
3249        session.peer.send(
3250            peer_id,
3251            proto::UpdateChannels {
3252                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3253                    channel_id: channel_id.to_proto(),
3254                    epoch: epoch as u64,
3255                    version: version.clone(),
3256                }],
3257                ..Default::default()
3258            },
3259        )
3260    });
3261
3262    Ok(())
3263}
3264
3265/// Rejoin the channel notes after a connection blip
3266async fn rejoin_channel_buffers(
3267    request: proto::RejoinChannelBuffers,
3268    response: Response<proto::RejoinChannelBuffers>,
3269    session: Session,
3270) -> Result<()> {
3271    let db = session.db().await;
3272    let buffers = db
3273        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3274        .await?;
3275
3276    for rejoined_buffer in &buffers {
3277        let collaborators_to_notify = rejoined_buffer
3278            .buffer
3279            .collaborators
3280            .iter()
3281            .filter_map(|c| Some(c.peer_id?.into()));
3282        channel_buffer_updated(
3283            session.connection_id,
3284            collaborators_to_notify,
3285            &proto::UpdateChannelBufferCollaborators {
3286                channel_id: rejoined_buffer.buffer.channel_id,
3287                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3288            },
3289            &session.peer,
3290        );
3291    }
3292
3293    response.send(proto::RejoinChannelBuffersResponse {
3294        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3295    })?;
3296
3297    Ok(())
3298}
3299
3300/// Stop editing the channel notes
3301async fn leave_channel_buffer(
3302    request: proto::LeaveChannelBuffer,
3303    response: Response<proto::LeaveChannelBuffer>,
3304    session: Session,
3305) -> Result<()> {
3306    let db = session.db().await;
3307    let channel_id = ChannelId::from_proto(request.channel_id);
3308
3309    let left_buffer = db
3310        .leave_channel_buffer(channel_id, session.connection_id)
3311        .await?;
3312
3313    response.send(Ack {})?;
3314
3315    channel_buffer_updated(
3316        session.connection_id,
3317        left_buffer.connections,
3318        &proto::UpdateChannelBufferCollaborators {
3319            channel_id: channel_id.to_proto(),
3320            collaborators: left_buffer.collaborators,
3321        },
3322        &session.peer,
3323    );
3324
3325    Ok(())
3326}
3327
3328fn channel_buffer_updated<T: EnvelopedMessage>(
3329    sender_id: ConnectionId,
3330    collaborators: impl IntoIterator<Item = ConnectionId>,
3331    message: &T,
3332    peer: &Peer,
3333) {
3334    broadcast(Some(sender_id), collaborators, |peer_id| {
3335        peer.send(peer_id, message.clone())
3336    });
3337}
3338
3339fn send_notifications(
3340    connection_pool: &ConnectionPool,
3341    peer: &Peer,
3342    notifications: db::NotificationBatch,
3343) {
3344    for (user_id, notification) in notifications {
3345        for connection_id in connection_pool.user_connection_ids(user_id) {
3346            if let Err(error) = peer.send(
3347                connection_id,
3348                proto::AddNotification {
3349                    notification: Some(notification.clone()),
3350                },
3351            ) {
3352                tracing::error!(
3353                    "failed to send notification to {:?} {}",
3354                    connection_id,
3355                    error
3356                );
3357            }
3358        }
3359    }
3360}
3361
/// Send a message to the channel.
///
/// The handler runs these steps in order:
/// 1. Validates the trimmed body (length and non-blank) and the nonce.
/// 2. Stores the message.
/// 3. Relays a `ChannelMessageSent` to the chat participants returned by the
///    database.
/// 4. Replies to the sender with the stored message.
/// 5. Sends channel members who are connected but not in the chat an
///    `UpdateChannels` carrying the new latest message id.
/// 6. Delivers any notifications produced by the insert.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // Surrounding whitespace is stripped first, so both checks (and storage) apply
    // to the trimmed text.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged below while
    // `body` may have lost leading whitespace. If mentions carry offsets into the
    // original body they will be misaligned. Confirm against the mention type.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Relay to the chat participants. The sender's connection is passed to
    // `broadcast` as the origin (presumably excluded there); the sender gets the
    // message through `response` right after.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // The value returned by `connection_pool()` is kept alive (via `&*`) until the
    // function returns, because `pool` is reused by `send_notifications` below.
    let pool = &*session.connection_pool().await;
    // Connections subscribed to the channel that are not chat participants only
    // learn the id of the newest message.
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3456
3457/// Delete a channel message
3458async fn remove_channel_message(
3459    request: proto::RemoveChannelMessage,
3460    response: Response<proto::RemoveChannelMessage>,
3461    session: Session,
3462) -> Result<()> {
3463    let channel_id = ChannelId::from_proto(request.channel_id);
3464    let message_id = MessageId::from_proto(request.message_id);
3465    let (connection_ids, existing_notification_ids) = session
3466        .db()
3467        .await
3468        .remove_channel_message(channel_id, message_id, session.user_id())
3469        .await?;
3470
3471    broadcast(
3472        Some(session.connection_id),
3473        connection_ids,
3474        move |connection| {
3475            session.peer.send(connection, request.clone())?;
3476
3477            for notification_id in &existing_notification_ids {
3478                session.peer.send(
3479                    connection,
3480                    proto::DeleteNotification {
3481                        notification_id: (*notification_id).to_proto(),
3482                    },
3483                )?;
3484            }
3485
3486            Ok(())
3487        },
3488    );
3489    response.send(proto::Ack {})?;
3490    Ok(())
3491}
3492
3493async fn update_channel_message(
3494    request: proto::UpdateChannelMessage,
3495    response: Response<proto::UpdateChannelMessage>,
3496    session: Session,
3497) -> Result<()> {
3498    let channel_id = ChannelId::from_proto(request.channel_id);
3499    let message_id = MessageId::from_proto(request.message_id);
3500    let updated_at = OffsetDateTime::now_utc();
3501    let UpdatedChannelMessage {
3502        message_id,
3503        participant_connection_ids,
3504        notifications,
3505        reply_to_message_id,
3506        timestamp,
3507        deleted_mention_notification_ids,
3508        updated_mention_notifications,
3509    } = session
3510        .db()
3511        .await
3512        .update_channel_message(
3513            channel_id,
3514            message_id,
3515            session.user_id(),
3516            request.body.as_str(),
3517            &request.mentions,
3518            updated_at,
3519        )
3520        .await?;
3521
3522    let nonce = request
3523        .nonce
3524        .clone()
3525        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3526
3527    let message = proto::ChannelMessage {
3528        sender_id: session.user_id().to_proto(),
3529        id: message_id.to_proto(),
3530        body: request.body.clone(),
3531        mentions: request.mentions.clone(),
3532        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3533        nonce: Some(nonce),
3534        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3535        edited_at: Some(updated_at.unix_timestamp() as u64),
3536    };
3537
3538    response.send(proto::Ack {})?;
3539
3540    let pool = &*session.connection_pool().await;
3541    broadcast(
3542        Some(session.connection_id),
3543        participant_connection_ids,
3544        |connection| {
3545            session.peer.send(
3546                connection,
3547                proto::ChannelMessageUpdate {
3548                    channel_id: channel_id.to_proto(),
3549                    message: Some(message.clone()),
3550                },
3551            )?;
3552
3553            for notification_id in &deleted_mention_notification_ids {
3554                session.peer.send(
3555                    connection,
3556                    proto::DeleteNotification {
3557                        notification_id: (*notification_id).to_proto(),
3558                    },
3559                )?;
3560            }
3561
3562            for notification in &updated_mention_notifications {
3563                session.peer.send(
3564                    connection,
3565                    proto::UpdateNotification {
3566                        notification: Some(notification.clone()),
3567                    },
3568                )?;
3569            }
3570
3571            Ok(())
3572        },
3573    );
3574
3575    send_notifications(pool, &session.peer, notifications);
3576
3577    Ok(())
3578}
3579
3580/// Mark a channel message as read
3581async fn acknowledge_channel_message(
3582    request: proto::AckChannelMessage,
3583    session: Session,
3584) -> Result<()> {
3585    let channel_id = ChannelId::from_proto(request.channel_id);
3586    let message_id = MessageId::from_proto(request.message_id);
3587    let notifications = session
3588        .db()
3589        .await
3590        .observe_channel_message(channel_id, session.user_id(), message_id)
3591        .await?;
3592    send_notifications(
3593        &*session.connection_pool().await,
3594        &session.peer,
3595        notifications,
3596    );
3597    Ok(())
3598}
3599
3600/// Mark a buffer version as synced
3601async fn acknowledge_buffer_version(
3602    request: proto::AckBufferOperation,
3603    session: Session,
3604) -> Result<()> {
3605    let buffer_id = BufferId::from_proto(request.buffer_id);
3606    session
3607        .db()
3608        .await
3609        .observe_buffer_version(
3610            buffer_id,
3611            session.user_id(),
3612            request.epoch as i32,
3613            &request.version,
3614        )
3615        .await?;
3616    Ok(())
3617}
3618
3619async fn count_language_model_tokens(
3620    request: proto::CountLanguageModelTokens,
3621    response: Response<proto::CountLanguageModelTokens>,
3622    session: Session,
3623    config: &Config,
3624) -> Result<()> {
3625    authorize_access_to_legacy_llm_endpoints(&session).await?;
3626
3627    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3628        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3629        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3630    };
3631
3632    session
3633        .app_state
3634        .rate_limiter
3635        .check(&*rate_limit, session.user_id())
3636        .await?;
3637
3638    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3639        Some(proto::LanguageModelProvider::Google) => {
3640            let api_key = config
3641                .google_ai_api_key
3642                .as_ref()
3643                .context("no Google AI API key configured on the server")?;
3644            google_ai::count_tokens(
3645                session.http_client.as_ref(),
3646                google_ai::API_URL,
3647                api_key,
3648                serde_json::from_str(&request.request)?,
3649            )
3650            .await?
3651        }
3652        _ => return Err(anyhow!("unsupported provider"))?,
3653    };
3654
3655    response.send(proto::CountLanguageModelTokensResponse {
3656        token_count: result.total_tokens as u32,
3657    })?;
3658
3659    Ok(())
3660}
3661
3662struct ZedProCountLanguageModelTokensRateLimit;
3663
3664impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3665    fn capacity(&self) -> usize {
3666        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3667            .ok()
3668            .and_then(|v| v.parse().ok())
3669            .unwrap_or(600) // Picked arbitrarily
3670    }
3671
3672    fn refill_duration(&self) -> chrono::Duration {
3673        chrono::Duration::hours(1)
3674    }
3675
3676    fn db_name(&self) -> &'static str {
3677        "zed-pro:count-language-model-tokens"
3678    }
3679}
3680
3681struct FreeCountLanguageModelTokensRateLimit;
3682
3683impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3684    fn capacity(&self) -> usize {
3685        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3686            .ok()
3687            .and_then(|v| v.parse().ok())
3688            .unwrap_or(600 / 10) // Picked arbitrarily
3689    }
3690
3691    fn refill_duration(&self) -> chrono::Duration {
3692        chrono::Duration::hours(1)
3693    }
3694
3695    fn db_name(&self) -> &'static str {
3696        "free:count-language-model-tokens"
3697    }
3698}
3699
3700struct ZedProComputeEmbeddingsRateLimit;
3701
3702impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3703    fn capacity(&self) -> usize {
3704        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3705            .ok()
3706            .and_then(|v| v.parse().ok())
3707            .unwrap_or(5000) // Picked arbitrarily
3708    }
3709
3710    fn refill_duration(&self) -> chrono::Duration {
3711        chrono::Duration::hours(1)
3712    }
3713
3714    fn db_name(&self) -> &'static str {
3715        "zed-pro:compute-embeddings"
3716    }
3717}
3718
3719struct FreeComputeEmbeddingsRateLimit;
3720
3721impl RateLimit for FreeComputeEmbeddingsRateLimit {
3722    fn capacity(&self) -> usize {
3723        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3724            .ok()
3725            .and_then(|v| v.parse().ok())
3726            .unwrap_or(5000 / 10) // Picked arbitrarily
3727    }
3728
3729    fn refill_duration(&self) -> chrono::Duration {
3730        chrono::Duration::hours(1)
3731    }
3732
3733    fn db_name(&self) -> &'static str {
3734        "free:compute-embeddings"
3735    }
3736}
3737
3738async fn compute_embeddings(
3739    request: proto::ComputeEmbeddings,
3740    response: Response<proto::ComputeEmbeddings>,
3741    session: Session,
3742    api_key: Option<Arc<str>>,
3743) -> Result<()> {
3744    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3745    authorize_access_to_legacy_llm_endpoints(&session).await?;
3746
3747    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3748        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3749        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3750    };
3751
3752    session
3753        .app_state
3754        .rate_limiter
3755        .check(&*rate_limit, session.user_id())
3756        .await?;
3757
3758    let embeddings = match request.model.as_str() {
3759        "openai/text-embedding-3-small" => {
3760            open_ai::embed(
3761                session.http_client.as_ref(),
3762                OPEN_AI_API_URL,
3763                &api_key,
3764                OpenAiEmbeddingModel::TextEmbedding3Small,
3765                request.texts.iter().map(|text| text.as_str()),
3766            )
3767            .await?
3768        }
3769        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3770    };
3771
3772    let embeddings = request
3773        .texts
3774        .iter()
3775        .map(|text| {
3776            let mut hasher = sha2::Sha256::new();
3777            hasher.update(text.as_bytes());
3778            let result = hasher.finalize();
3779            result.to_vec()
3780        })
3781        .zip(
3782            embeddings
3783                .data
3784                .into_iter()
3785                .map(|embedding| embedding.embedding),
3786        )
3787        .collect::<HashMap<_, _>>();
3788
3789    let db = session.db().await;
3790    db.save_embeddings(&request.model, &embeddings)
3791        .await
3792        .context("failed to save embeddings")
3793        .trace_err();
3794
3795    response.send(proto::ComputeEmbeddingsResponse {
3796        embeddings: embeddings
3797            .into_iter()
3798            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3799            .collect(),
3800    })?;
3801    Ok(())
3802}
3803
3804async fn get_cached_embeddings(
3805    request: proto::GetCachedEmbeddings,
3806    response: Response<proto::GetCachedEmbeddings>,
3807    session: Session,
3808) -> Result<()> {
3809    authorize_access_to_legacy_llm_endpoints(&session).await?;
3810
3811    let db = session.db().await;
3812    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3813
3814    response.send(proto::GetCachedEmbeddingsResponse {
3815        embeddings: embeddings
3816            .into_iter()
3817            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3818            .collect(),
3819    })?;
3820    Ok(())
3821}
3822
3823/// This is leftover from before the LLM service.
3824///
3825/// The endpoints protected by this check will be moved there eventually.
3826async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3827    if session.is_staff() {
3828        Ok(())
3829    } else {
3830        Err(anyhow!("permission denied"))?
3831    }
3832}
3833
3834/// Get a Supermaven API key for the user
3835async fn get_supermaven_api_key(
3836    _request: proto::GetSupermavenApiKey,
3837    response: Response<proto::GetSupermavenApiKey>,
3838    session: Session,
3839) -> Result<()> {
3840    let user_id: String = session.user_id().to_string();
3841    if !session.is_staff() {
3842        return Err(anyhow!("supermaven not enabled for this account"))?;
3843    }
3844
3845    let email = session
3846        .email()
3847        .ok_or_else(|| anyhow!("user must have an email"))?;
3848
3849    let supermaven_admin_api = session
3850        .supermaven_client
3851        .as_ref()
3852        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3853
3854    let result = supermaven_admin_api
3855        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3856        .await?;
3857
3858    response.send(proto::GetSupermavenApiKeyResponse {
3859        api_key: result.api_key,
3860    })?;
3861
3862    Ok(())
3863}
3864
3865/// Start receiving chat updates for a channel
3866async fn join_channel_chat(
3867    request: proto::JoinChannelChat,
3868    response: Response<proto::JoinChannelChat>,
3869    session: Session,
3870) -> Result<()> {
3871    let channel_id = ChannelId::from_proto(request.channel_id);
3872
3873    let db = session.db().await;
3874    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3875        .await?;
3876    let messages = db
3877        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3878        .await?;
3879    response.send(proto::JoinChannelChatResponse {
3880        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3881        messages,
3882    })?;
3883    Ok(())
3884}
3885
3886/// Stop receiving chat updates for a channel
3887async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3888    let channel_id = ChannelId::from_proto(request.channel_id);
3889    session
3890        .db()
3891        .await
3892        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3893        .await?;
3894    Ok(())
3895}
3896
3897/// Retrieve the chat history for a channel
3898async fn get_channel_messages(
3899    request: proto::GetChannelMessages,
3900    response: Response<proto::GetChannelMessages>,
3901    session: Session,
3902) -> Result<()> {
3903    let channel_id = ChannelId::from_proto(request.channel_id);
3904    let messages = session
3905        .db()
3906        .await
3907        .get_channel_messages(
3908            channel_id,
3909            session.user_id(),
3910            MESSAGE_COUNT_PER_PAGE,
3911            Some(MessageId::from_proto(request.before_message_id)),
3912        )
3913        .await?;
3914    response.send(proto::GetChannelMessagesResponse {
3915        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3916        messages,
3917    })?;
3918    Ok(())
3919}
3920
3921/// Retrieve specific chat messages
3922async fn get_channel_messages_by_id(
3923    request: proto::GetChannelMessagesById,
3924    response: Response<proto::GetChannelMessagesById>,
3925    session: Session,
3926) -> Result<()> {
3927    let message_ids = request
3928        .message_ids
3929        .iter()
3930        .map(|id| MessageId::from_proto(*id))
3931        .collect::<Vec<_>>();
3932    let messages = session
3933        .db()
3934        .await
3935        .get_channel_messages_by_id(session.user_id(), &message_ids)
3936        .await?;
3937    response.send(proto::GetChannelMessagesResponse {
3938        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3939        messages,
3940    })?;
3941    Ok(())
3942}
3943
3944/// Retrieve the current users notifications
3945async fn get_notifications(
3946    request: proto::GetNotifications,
3947    response: Response<proto::GetNotifications>,
3948    session: Session,
3949) -> Result<()> {
3950    let notifications = session
3951        .db()
3952        .await
3953        .get_notifications(
3954            session.user_id(),
3955            NOTIFICATION_COUNT_PER_PAGE,
3956            request.before_id.map(db::NotificationId::from_proto),
3957        )
3958        .await?;
3959    response.send(proto::GetNotificationsResponse {
3960        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3961        notifications,
3962    })?;
3963    Ok(())
3964}
3965
3966/// Mark notifications as read
3967async fn mark_notification_as_read(
3968    request: proto::MarkNotificationRead,
3969    response: Response<proto::MarkNotificationRead>,
3970    session: Session,
3971) -> Result<()> {
3972    let database = &session.db().await;
3973    let notifications = database
3974        .mark_notification_as_read_by_id(
3975            session.user_id(),
3976            NotificationId::from_proto(request.notification_id),
3977        )
3978        .await?;
3979    send_notifications(
3980        &*session.connection_pool().await,
3981        &session.peer,
3982        notifications,
3983    );
3984    response.send(proto::Ack {})?;
3985    Ok(())
3986}
3987
3988/// Get the current users information
3989async fn get_private_user_info(
3990    _request: proto::GetPrivateUserInfo,
3991    response: Response<proto::GetPrivateUserInfo>,
3992    session: Session,
3993) -> Result<()> {
3994    let db = session.db().await;
3995
3996    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3997    let user = db
3998        .get_user_by_id(session.user_id())
3999        .await?
4000        .ok_or_else(|| anyhow!("user not found"))?;
4001    let flags = db.get_user_flags(session.user_id()).await?;
4002
4003    response.send(proto::GetPrivateUserInfoResponse {
4004        metrics_id,
4005        staff: user.admin,
4006        flags,
4007        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4008    })?;
4009    Ok(())
4010}
4011
4012/// Accept the terms of service (tos) on behalf of the current user
4013async fn accept_terms_of_service(
4014    _request: proto::AcceptTermsOfService,
4015    response: Response<proto::AcceptTermsOfService>,
4016    session: Session,
4017) -> Result<()> {
4018    let db = session.db().await;
4019
4020    let accepted_tos_at = Utc::now();
4021    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4022        .await?;
4023
4024    response.send(proto::AcceptTermsOfServiceResponse {
4025        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4026    })?;
4027    Ok(())
4028}
4029
/// The minimum account age an account must have in order to use the LLM service.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);

/// Issue an LLM service token for the current user.
///
/// Gates, checked in this order:
/// 1. The session is staff, or the user has the `language-models` feature flag.
/// 2. The user has accepted the terms of service.
/// 3. Unless the user has an LLM subscription or the `bypass-account-age-check`
///    flag, the account is at least `MIN_ACCOUNT_AGE_FOR_LLM_USE` old. Age is
///    measured from the earlier of the Zed account and GitHub account creation times.
///
/// On success the token built by `LlmTokenClaims::create` is sent back. Its inputs
/// are the user, staff status, billing preferences, the closed-beta,
/// predict-edits and subscription flags, the current plan, and the session's
/// system id.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");

    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

    // Subscribers and explicitly flagged users skip the minimum-age requirement.
    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        // Use the older of the two creation times, so a new Zed account linked to
        // an established GitHub account is not rejected.
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_predict_edits_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4089
4090fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4091    let message = match message {
4092        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4093        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4094        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4095        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4096        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4097            code: frame.code.into(),
4098            reason: frame.reason,
4099        })),
4100        // We should never receive a frame while reading the message, according
4101        // to the `tungstenite` maintainers:
4102        //
4103        // > It cannot occur when you read messages from the WebSocket, but it
4104        // > can be used when you want to send the raw frames (e.g. you want to
4105        // > send the frames to the WebSocket without composing the full message first).
4106        // >
4107        // > — https://github.com/snapview/tungstenite-rs/issues/268
4108        TungsteniteMessage::Frame(_) => {
4109            bail!("received an unexpected frame while reading the message")
4110        }
4111    };
4112
4113    Ok(message)
4114}
4115
4116fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4117    match message {
4118        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4119        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4120        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4121        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4122        AxumMessage::Close(frame) => {
4123            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4124                code: frame.code.into(),
4125                reason: frame.reason,
4126            }))
4127        }
4128    }
4129}
4130
4131fn notify_membership_updated(
4132    connection_pool: &mut ConnectionPool,
4133    result: MembershipUpdated,
4134    user_id: UserId,
4135    peer: &Peer,
4136) {
4137    for membership in &result.new_channels.channel_memberships {
4138        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4139    }
4140    for channel_id in &result.removed_channels {
4141        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4142    }
4143
4144    let user_channels_update = proto::UpdateUserChannels {
4145        channel_memberships: result
4146            .new_channels
4147            .channel_memberships
4148            .iter()
4149            .map(|cm| proto::ChannelMembership {
4150                channel_id: cm.channel_id.to_proto(),
4151                role: cm.role.into(),
4152            })
4153            .collect(),
4154        ..Default::default()
4155    };
4156
4157    let mut update = build_channels_update(result.new_channels);
4158    update.delete_channels = result
4159        .removed_channels
4160        .into_iter()
4161        .map(|id| id.to_proto())
4162        .collect();
4163    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4164
4165    for connection_id in connection_pool.user_connection_ids(user_id) {
4166        peer.send(connection_id, user_channels_update.clone())
4167            .trace_err();
4168        peer.send(connection_id, update.clone()).trace_err();
4169    }
4170}
4171
4172fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4173    proto::UpdateUserChannels {
4174        channel_memberships: channels
4175            .channel_memberships
4176            .iter()
4177            .map(|m| proto::ChannelMembership {
4178                channel_id: m.channel_id.to_proto(),
4179                role: m.role.into(),
4180            })
4181            .collect(),
4182        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4183        observed_channel_message_id: channels.observed_channel_messages.clone(),
4184    }
4185}
4186
4187fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4188    let mut update = proto::UpdateChannels::default();
4189
4190    for channel in channels.channels {
4191        update.channels.push(channel.to_proto());
4192    }
4193
4194    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4195    update.latest_channel_message_ids = channels.latest_channel_messages;
4196
4197    for (channel_id, participants) in channels.channel_participants {
4198        update
4199            .channel_participants
4200            .push(proto::ChannelParticipants {
4201                channel_id: channel_id.to_proto(),
4202                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4203            });
4204    }
4205
4206    for channel in channels.invited_channels {
4207        update.channel_invitations.push(channel.to_proto());
4208    }
4209
4210    update
4211}
4212
4213fn build_initial_contacts_update(
4214    contacts: Vec<db::Contact>,
4215    pool: &ConnectionPool,
4216) -> proto::UpdateContacts {
4217    let mut update = proto::UpdateContacts::default();
4218
4219    for contact in contacts {
4220        match contact {
4221            db::Contact::Accepted { user_id, busy } => {
4222                update.contacts.push(contact_for_user(user_id, busy, pool));
4223            }
4224            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4225            db::Contact::Incoming { user_id } => {
4226                update
4227                    .incoming_requests
4228                    .push(proto::IncomingContactRequest {
4229                        requester_id: user_id.to_proto(),
4230                    })
4231            }
4232        }
4233    }
4234
4235    update
4236}
4237
4238fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4239    proto::Contact {
4240        user_id: user_id.to_proto(),
4241        online: pool.is_user_online(user_id),
4242        busy,
4243    }
4244}
4245
4246fn room_updated(room: &proto::Room, peer: &Peer) {
4247    broadcast(
4248        None,
4249        room.participants
4250            .iter()
4251            .filter_map(|participant| Some(participant.peer_id?.into())),
4252        |peer_id| {
4253            peer.send(
4254                peer_id,
4255                proto::RoomUpdated {
4256                    room: Some(room.clone()),
4257                },
4258            )
4259        },
4260    );
4261}
4262
4263fn channel_updated(
4264    channel: &db::channel::Model,
4265    room: &proto::Room,
4266    peer: &Peer,
4267    pool: &ConnectionPool,
4268) {
4269    let participants = room
4270        .participants
4271        .iter()
4272        .map(|p| p.user_id)
4273        .collect::<Vec<_>>();
4274
4275    broadcast(
4276        None,
4277        pool.channel_connection_ids(channel.root_id())
4278            .filter_map(|(channel_id, role)| {
4279                role.can_see_channel(channel.visibility)
4280                    .then_some(channel_id)
4281            }),
4282        |peer_id| {
4283            peer.send(
4284                peer_id,
4285                proto::UpdateChannels {
4286                    channel_participants: vec![proto::ChannelParticipants {
4287                        channel_id: channel.id.to_proto(),
4288                        participant_user_ids: participants.clone(),
4289                    }],
4290                    ..Default::default()
4291                },
4292            )
4293        },
4294    );
4295}
4296
4297async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4298    let db = session.db().await;
4299
4300    let contacts = db.get_contacts(user_id).await?;
4301    let busy = db.is_user_busy(user_id).await?;
4302
4303    let pool = session.connection_pool().await;
4304    let updated_contact = contact_for_user(user_id, busy, &pool);
4305    for contact in contacts {
4306        if let db::Contact::Accepted {
4307            user_id: contact_user_id,
4308            ..
4309        } = contact
4310        {
4311            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4312                session
4313                    .peer
4314                    .send(
4315                        contact_conn_id,
4316                        proto::UpdateContacts {
4317                            contacts: vec![updated_contact.clone()],
4318                            remove_contacts: Default::default(),
4319                            incoming_requests: Default::default(),
4320                            remove_incoming_requests: Default::default(),
4321                            outgoing_requests: Default::default(),
4322                            remove_outgoing_requests: Default::default(),
4323                        },
4324                    )
4325                    .trace_err();
4326            }
4327        }
4328    }
4329    Ok(())
4330}
4331
/// Removes `connection_id` from the room it is in and propagates the
/// consequences of leaving.
///
/// Returns `Ok(())` right away when `leave_room` reports that no room was left.
/// Otherwise it performs these steps in order:
/// - notifies the collaborators of each project in `left_room.left_projects`
///   (via `project_left`),
/// - broadcasts the updated room to its participants and, if the room belongs
///   to a channel, the new participant list to that channel's viewers,
/// - sends `CallCanceled` to every connection of each user listed in
///   `canceled_calls_to_user_ids`,
/// - refreshes contact status for this session's user and for each of those
///   users,
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room and also deletes that room when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. Failures to
/// send individual messages or to reach LiveKit are only logged (`trace_err`).
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry (online/busy status) must be re-broadcast.
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values remain
    // available after `left_room` goes out of scope.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is taken out of `left_room.room` before
        // the room itself is taken, so the `RoomUpdated` broadcast below carries
        // a room whose `livekit_room` field has been reset to its default.
        // Confirm that clients do not rely on that field in this message.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?`, which skips the
    // LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4405
4406async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4407    let left_channel_buffers = session
4408        .db()
4409        .await
4410        .leave_channel_buffers(session.connection_id)
4411        .await?;
4412
4413    for left_buffer in left_channel_buffers {
4414        channel_buffer_updated(
4415            session.connection_id,
4416            left_buffer.connections,
4417            &proto::UpdateChannelBufferCollaborators {
4418                channel_id: left_buffer.channel_id.to_proto(),
4419                collaborators: left_buffer.collaborators,
4420            },
4421            &session.peer,
4422        );
4423    }
4424
4425    Ok(())
4426}
4427
4428fn project_left(project: &db::LeftProject, session: &Session) {
4429    for connection_id in &project.connection_ids {
4430        if project.should_unshare {
4431            session
4432                .peer
4433                .send(
4434                    *connection_id,
4435                    proto::UnshareProject {
4436                        project_id: project.id.to_proto(),
4437                    },
4438                )
4439                .trace_err();
4440        } else {
4441            session
4442                .peer
4443                .send(
4444                    *connection_id,
4445                    proto::RemoveProjectCollaborator {
4446                        project_id: project.id.to_proto(),
4447                        peer_id: Some(session.connection_id.into()),
4448                    },
4449                )
4450                .trace_err();
4451        }
4452    }
4453}
4454
/// Extension for fallible values whose error should be reported and then
/// discarded rather than propagated.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Converts `self` into an `Option`, yielding `None` on failure.
    /// Implementations are expected to report the discarded error.
    fn trace_err(self) -> Option<<Self as ResultExt>::Ok>;
}
4460
4461impl<T, E> ResultExt for Result<T, E>
4462where
4463    E: std::fmt::Debug,
4464{
4465    type Ok = T;
4466
4467    #[track_caller]
4468    fn trace_err(self) -> Option<T> {
4469        match self {
4470            Ok(value) => Some(value),
4471            Err(error) => {
4472                tracing::error!("{:?}", error);
4473                None
4474            }
4475        }
4476    }
4477}