rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// Presumably how long a disconnected client is given to reconnect before its
/// state is discarded. NOTE(review): not referenced in this chunk; semantics
/// inferred from the name — confirm against its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. `Server::start`
// waits this long (a 5s margin beyond that) before cleaning up resources left
// behind by old servers, so those servers are gone by the time cleanup runs.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the limits below are not used in this chunk. They are presumably
// page sizes and a length cap for channel chat and notification queries. Confirm
// at their call sites, including whether MAX_MESSAGE_LEN counts bytes or chars.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
/// Type-erased handler stored in `Server::handlers`.
///
/// It receives a boxed, type-erased envelope plus the sending connection's
/// `Session`, and returns a boxed `'static` future. Instances are built by
/// `Server::add_handler`, which downcasts the envelope back to its concrete type.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
/// Handle through which a request handler replies to the request it is serving.
///
/// `send` consumes the handle, so a reply can be sent through it at most once.
/// `responded` is shared with `Server::add_request_handler`, which inspects it
/// after the handler finishes to detect handlers that returned without replying.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
/// The identity a connection acts on behalf of.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`. The admin is reported as the "impersonator" in
    /// tracing spans and in `Session`'s `Debug` output.
    Impersonated { user: User, admin: User },
}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// State for a single client connection, passed to every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    /// Shared database handle; handlers lock it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; handlers lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // The lint allowance suggests this field is stored but not read yet.
    // NOTE(review): confirm it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The collaboration RPC server. It owns the peer, the shared connection pool,
/// and the table of message handlers.
pub struct Server {
    /// Guarded by a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they process.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard over the shared [`ConnectionPool`], returned by
/// `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. A future that
    // holds the guard across an `.await` therefore cannot be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of a server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized as whatever it dereferences to (presumably the
    // `ConnectionPool`), rather than as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 311            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 314            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 319            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 320            .add_request_handler(
 321                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 322            )
 323            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 324            .add_request_handler(
 325                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 326            )
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 328            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 333            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 338            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 339            .add_request_handler(
 340                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 341            )
 342            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 343            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 345            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 346            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 347            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 348            .add_message_handler(create_buffer_for_peer)
 349            .add_request_handler(update_buffer)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 354            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 355            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 356            .add_request_handler(get_users)
 357            .add_request_handler(fuzzy_search_users)
 358            .add_request_handler(request_contact)
 359            .add_request_handler(remove_contact)
 360            .add_request_handler(respond_to_contact_request)
 361            .add_message_handler(subscribe_to_channels)
 362            .add_request_handler(create_channel)
 363            .add_request_handler(delete_channel)
 364            .add_request_handler(invite_channel_member)
 365            .add_request_handler(remove_channel_member)
 366            .add_request_handler(set_channel_member_role)
 367            .add_request_handler(set_channel_visibility)
 368            .add_request_handler(rename_channel)
 369            .add_request_handler(join_channel_buffer)
 370            .add_request_handler(leave_channel_buffer)
 371            .add_message_handler(update_channel_buffer)
 372            .add_request_handler(rejoin_channel_buffers)
 373            .add_request_handler(get_channel_members)
 374            .add_request_handler(respond_to_channel_invite)
 375            .add_request_handler(join_channel)
 376            .add_request_handler(join_channel_chat)
 377            .add_message_handler(leave_channel_chat)
 378            .add_request_handler(send_channel_message)
 379            .add_request_handler(remove_channel_message)
 380            .add_request_handler(update_channel_message)
 381            .add_request_handler(get_channel_messages)
 382            .add_request_handler(get_channel_messages_by_id)
 383            .add_request_handler(get_notifications)
 384            .add_request_handler(mark_notification_as_read)
 385            .add_request_handler(move_channel)
 386            .add_request_handler(follow)
 387            .add_message_handler(unfollow)
 388            .add_message_handler(update_followers)
 389            .add_request_handler(get_private_user_info)
 390            .add_request_handler(get_llm_api_token)
 391            .add_request_handler(accept_terms_of_service)
 392            .add_message_handler(acknowledge_channel_message)
 393            .add_message_handler(acknowledge_buffer_version)
 394            .add_request_handler(get_supermaven_api_key)
 395            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 396            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 397            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 398            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 399            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 400            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 401            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 402            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 404            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 405            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 406            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 407            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 408            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 409            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 410            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 411            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 412            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 413            .add_message_handler(update_context)
 414            .add_request_handler({
 415                let app_state = app_state.clone();
 416                move |request, response, session| {
 417                    let app_state = app_state.clone();
 418                    async move {
 419                        count_language_model_tokens(request, response, session, &app_state.config)
 420                            .await
 421                    }
 422                }
 423            })
 424            .add_request_handler(get_cached_embeddings)
 425            .add_request_handler({
 426                let app_state = app_state.clone();
 427                move |request, response, session| {
 428                    compute_embeddings(
 429                        request,
 430                        response,
 431                        session,
 432                        app_state.config.openai_api_key.clone(),
 433                    )
 434                }
 435            });
 436
 437        Arc::new(server)
 438    }
 439
    /// Schedules background cleanup of resources left behind by previous server
    /// instances, then returns immediately.
    ///
    /// The detached task waits [`CLEANUP_TIMEOUT`], then fetches the room and
    /// channel ids that `stale_server_resource_ids` reports for this environment
    /// and server id, and processes them as follows:
    /// - **Channels:** clears stale buffer collaborators and sends the refreshed
    ///   collaborator list to the remaining connections.
    /// - **Rooms:** clears stale participants and broadcasts the updated room (and
    ///   its channel, if any). It then sends `CallCanceled` to users whose calls
    ///   were canceled, and pushes refreshed contact status to the accepted contacts
    ///   of affected users. The LiveKit room is deleted once no participants remain.
    ///
    /// Finally the task calls `delete_stale_servers`.
    ///
    /// Failures inside the task are logged via `trace_err` and skip only the step
    /// that failed. This function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give servers that are being terminated time to disappear before
                // touching the resources they owned.
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining collaborator who is still editing.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false if clearing the room fails, which
                        // turns the notification and deletion steps below into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only tear down the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the database
                        // `.await`s that follow.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Proceed only if both lookups succeeded; the pool is locked
                            // only after both awaits have completed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the status update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 585
 586    pub fn teardown(&self) {
 587        self.peer.teardown();
 588        self.connection_pool.lock().reset();
 589        let _ = self.teardown.send(true);
 590    }
 591
 592    #[cfg(test)]
 593    pub fn reset(&self, id: ServerId) {
 594        self.teardown();
 595        *self.id.lock() = id;
 596        self.peer.reset(id.0 as u32);
 597        let _ = self.teardown.send(false);
 598    }
 599
 600    #[cfg(test)]
 601    pub fn id(&self) -> ServerId {
 602        *self.id.lock()
 603    }
 604
 605    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 606    where
 607        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 608        Fut: 'static + Send + Future<Output = Result<()>>,
 609        M: EnvelopedMessage,
 610    {
 611        let prev_handler = self.handlers.insert(
 612            TypeId::of::<M>(),
 613            Box::new(move |envelope, session| {
 614                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 615                let received_at = envelope.received_at;
 616                tracing::info!("message received");
 617                let start_time = Instant::now();
 618                let future = (handler)(*envelope, session);
 619                async move {
 620                    let result = future.await;
 621                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 622                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 623                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 624                    let payload_type = M::NAME;
 625
 626                    match result {
 627                        Err(error) => {
 628                            tracing::error!(
 629                                ?error,
 630                                total_duration_ms,
 631                                processing_duration_ms,
 632                                queue_duration_ms,
 633                                payload_type,
 634                                "error handling message"
 635                            )
 636                        }
 637                        Ok(()) => tracing::info!(
 638                            total_duration_ms,
 639                            processing_duration_ms,
 640                            queue_duration_ms,
 641                            "finished handling message"
 642                        ),
 643                    }
 644                }
 645                .boxed()
 646            }),
 647        );
 648        if prev_handler.is_some() {
 649            panic!("registered a handler for the same message twice");
 650        }
 651        self
 652    }
 653
 654    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 655    where
 656        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 657        Fut: 'static + Send + Future<Output = Result<()>>,
 658        M: EnvelopedMessage,
 659    {
 660        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 661        self
 662    }
 663
    /// Registers a handler for a request/response message type `M`.
    ///
    /// The handler receives the request payload, a [`Response`] to reply through, and the
    /// [`Session`]. The wrapper registered here enforces the request/response contract:
    /// - if the handler returns `Ok(())` but never replied, the wrapper yields an error
    ///   ("handler did not send a response");
    /// - if the handler returns an error, the error is sent back to the requesting peer via
    ///   `respond_with_error` and then returned, so `add_handler` logs it like any other
    ///   handler failure. `Error::Internal` values are converted with their own `to_proto()`;
    ///   every other error becomes `ErrorCode::Internal` carrying its `Display` text.
    ///
    /// Returns `&mut Self` (whatever `add_handler` returns) to allow chaining.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared across invocations: each incoming message clones the Arc into its own future.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with the `Response`; presumably set once a reply is sent
                // (NOTE(review): confirm against `Response::send`), and checked below after
                // the handler completes.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // A request must always be answered; treat silence as a bug.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Translate the handler's error into a protocol error for the client.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 702
 703    pub fn handle_connection(
 704        self: &Arc<Self>,
 705        connection: Connection,
 706        address: String,
 707        principal: Principal,
 708        zed_version: ZedVersion,
 709        geoip_country_code: Option<String>,
 710        system_id: Option<String>,
 711        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 712        executor: Executor,
 713    ) -> impl Future<Output = ()> {
 714        let this = self.clone();
 715        let span = info_span!("handle connection", %address,
 716            connection_id=field::Empty,
 717            user_id=field::Empty,
 718            login=field::Empty,
 719            impersonator=field::Empty,
 720            geoip_country_code=field::Empty
 721        );
 722        principal.update_span(&span);
 723        if let Some(country_code) = geoip_country_code.as_ref() {
 724            span.record("geoip_country_code", country_code);
 725        }
 726
 727        let mut teardown = self.teardown.subscribe();
 728        async move {
 729            if *teardown.borrow() {
 730                tracing::error!("server is tearing down");
 731                return
 732            }
 733            let (connection_id, handle_io, mut incoming_rx) = this
 734                .peer
 735                .add_connection(connection, {
 736                    let executor = executor.clone();
 737                    move |duration| executor.sleep(duration)
 738                });
 739            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 740
 741            tracing::info!("connection opened");
 742
 743            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 744            let http_client = match ReqwestClient::user_agent(&user_agent) {
 745                Ok(http_client) => Arc::new(http_client),
 746                Err(error) => {
 747                    tracing::error!(?error, "failed to create HTTP client");
 748                    return;
 749                }
 750            };
 751
 752            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 753                    supermaven_admin_api_key.to_string(),
 754                    http_client.clone(),
 755                )));
 756
 757            let session = Session {
 758                principal: principal.clone(),
 759                connection_id,
 760                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 761                peer: this.peer.clone(),
 762                connection_pool: this.connection_pool.clone(),
 763                app_state: this.app_state.clone(),
 764                http_client,
 765                geoip_country_code,
 766                system_id,
 767                _executor: executor.clone(),
 768                supermaven_client,
 769            };
 770
 771            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 772                tracing::error!(?error, "failed to send initial client update");
 773                return;
 774            }
 775
 776            let handle_io = handle_io.fuse();
 777            futures::pin_mut!(handle_io);
 778
 779            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 780            // This prevents deadlocks when e.g., client A performs a request to client B and
 781            // client B performs a request to client A. If both clients stop processing further
 782            // messages until their respective request completes, they won't have a chance to
 783            // respond to the other client's request and cause a deadlock.
 784            //
 785            // This arrangement ensures we will attempt to process earlier messages first, but fall
 786            // back to processing messages arrived later in the spirit of making progress.
 787            let mut foreground_message_handlers = FuturesUnordered::new();
 788            let concurrent_handlers = Arc::new(Semaphore::new(256));
 789            loop {
 790                let next_message = async {
 791                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 792                    let message = incoming_rx.next().await;
 793                    (permit, message)
 794                }.fuse();
 795                futures::pin_mut!(next_message);
 796                futures::select_biased! {
 797                    _ = teardown.changed().fuse() => return,
 798                    result = handle_io => {
 799                        if let Err(error) = result {
 800                            tracing::error!(?error, "error handling I/O");
 801                        }
 802                        break;
 803                    }
 804                    _ = foreground_message_handlers.next() => {}
 805                    next_message = next_message => {
 806                        let (permit, message) = next_message;
 807                        if let Some(message) = message {
 808                            let type_name = message.payload_type_name();
 809                            // note: we copy all the fields from the parent span so we can query them in the logs.
 810                            // (https://github.com/tokio-rs/tracing/issues/2670).
 811                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 812                                user_id=field::Empty,
 813                                login=field::Empty,
 814                                impersonator=field::Empty,
 815                            );
 816                            principal.update_span(&span);
 817                            let span_enter = span.enter();
 818                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 819                                let is_background = message.is_background();
 820                                let handle_message = (handler)(message, session.clone());
 821                                drop(span_enter);
 822
 823                                let handle_message = async move {
 824                                    handle_message.await;
 825                                    drop(permit);
 826                                }.instrument(span);
 827                                if is_background {
 828                                    executor.spawn_detached(handle_message);
 829                                } else {
 830                                    foreground_message_handlers.push(handle_message);
 831                                }
 832                            } else {
 833                                tracing::error!("no message handler");
 834                            }
 835                        } else {
 836                            tracing::info!("connection closed");
 837                            break;
 838                        }
 839                    }
 840                }
 841            }
 842
 843            drop(foreground_message_handlers);
 844            tracing::info!("signing out");
 845            if let Err(error) = connection_lost(session, teardown, executor).await {
 846                tracing::error!(?error, "error signing out");
 847            }
 848
 849        }.instrument(span)
 850    }
 851
 852    async fn send_initial_client_update(
 853        &self,
 854        connection_id: ConnectionId,
 855        principal: &Principal,
 856        zed_version: ZedVersion,
 857        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 858        session: &Session,
 859    ) -> Result<()> {
 860        self.peer.send(
 861            connection_id,
 862            proto::Hello {
 863                peer_id: Some(connection_id.into()),
 864            },
 865        )?;
 866        tracing::info!("sent hello message");
 867        if let Some(send_connection_id) = send_connection_id.take() {
 868            let _ = send_connection_id.send(connection_id);
 869        }
 870
 871        match principal {
 872            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 873                if !user.connected_once {
 874                    self.peer.send(connection_id, proto::ShowContacts {})?;
 875                    self.app_state
 876                        .db
 877                        .set_user_connected_once(user.id, true)
 878                        .await?;
 879                }
 880
 881                update_user_plan(user.id, session).await?;
 882
 883                let contacts = self.app_state.db.get_contacts(user.id).await?;
 884
 885                {
 886                    let mut pool = self.connection_pool.lock();
 887                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 888                    self.peer.send(
 889                        connection_id,
 890                        build_initial_contacts_update(contacts, &pool),
 891                    )?;
 892                }
 893
 894                if should_auto_subscribe_to_channels(zed_version) {
 895                    subscribe_user_to_channels(user.id, session).await?;
 896                }
 897
 898                if let Some(incoming_call) =
 899                    self.app_state.db.incoming_call_for_user(user.id).await?
 900                {
 901                    self.peer.send(connection_id, incoming_call)?;
 902                }
 903
 904                update_user_contacts(user.id, session).await?;
 905            }
 906        }
 907
 908        Ok(())
 909    }
 910
 911    pub async fn invite_code_redeemed(
 912        self: &Arc<Self>,
 913        inviter_id: UserId,
 914        invitee_id: UserId,
 915    ) -> Result<()> {
 916        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 917            if let Some(code) = &user.invite_code {
 918                let pool = self.connection_pool.lock();
 919                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 920                for connection_id in pool.user_connection_ids(inviter_id) {
 921                    self.peer.send(
 922                        connection_id,
 923                        proto::UpdateContacts {
 924                            contacts: vec![invitee_contact.clone()],
 925                            ..Default::default()
 926                        },
 927                    )?;
 928                    self.peer.send(
 929                        connection_id,
 930                        proto::UpdateInviteInfo {
 931                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 932                            count: user.invite_count as u32,
 933                        },
 934                    )?;
 935                }
 936            }
 937        }
 938        Ok(())
 939    }
 940
 941    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 942        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 943            if let Some(invite_code) = &user.invite_code {
 944                let pool = self.connection_pool.lock();
 945                for connection_id in pool.user_connection_ids(user_id) {
 946                    self.peer.send(
 947                        connection_id,
 948                        proto::UpdateInviteInfo {
 949                            url: format!(
 950                                "{}{}",
 951                                self.app_state.config.invite_link_prefix, invite_code
 952                            ),
 953                            count: user.invite_count as u32,
 954                        },
 955                    )?;
 956                }
 957            }
 958        }
 959        Ok(())
 960    }
 961
 962    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 963        let pool = self.connection_pool.lock();
 964        for connection_id in pool.user_connection_ids(user_id) {
 965            self.peer
 966                .send(connection_id, proto::RefreshLlmToken {})
 967                .trace_err();
 968        }
 969    }
 970
 971    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 972        ServerSnapshot {
 973            connection_pool: ConnectionPoolGuard {
 974                guard: self.connection_pool.lock(),
 975                _not_send: PhantomData,
 976            },
 977            peer: &self.peer,
 978        }
 979    }
 980}
 981
 982impl Deref for ConnectionPoolGuard<'_> {
 983    type Target = ConnectionPool;
 984
 985    fn deref(&self) -> &Self::Target {
 986        &self.guard
 987    }
 988}
 989
 990impl DerefMut for ConnectionPoolGuard<'_> {
 991    fn deref_mut(&mut self) -> &mut Self::Target {
 992        &mut self.guard
 993    }
 994}
 995
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, calls `check_invariants` every time a guard is released. The call is
    /// compiled out of non-test builds, where dropping the guard only releases the lock.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1002
1003fn broadcast<F>(
1004    sender_id: Option<ConnectionId>,
1005    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1006    mut f: F,
1007) where
1008    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1009{
1010    for receiver_id in receiver_ids {
1011        if Some(receiver_id) != sender_id {
1012            if let Err(error) = f(receiver_id) {
1013                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1014            }
1015        }
1016    }
1017}
1018
1019pub struct ProtocolVersion(u32);
1020
1021impl Header for ProtocolVersion {
1022    fn name() -> &'static HeaderName {
1023        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1024        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1025    }
1026
1027    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1028    where
1029        Self: Sized,
1030        I: Iterator<Item = &'i axum::http::HeaderValue>,
1031    {
1032        let version = values
1033            .next()
1034            .ok_or_else(axum::headers::Error::invalid)?
1035            .to_str()
1036            .map_err(|_| axum::headers::Error::invalid())?
1037            .parse()
1038            .map_err(|_| axum::headers::Error::invalid())?;
1039        Ok(Self(version))
1040    }
1041
1042    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1043        values.extend([self.0.to_string().parse().unwrap()]);
1044    }
1045}
1046
1047pub struct AppVersionHeader(SemanticVersion);
1048impl Header for AppVersionHeader {
1049    fn name() -> &'static HeaderName {
1050        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1051        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1052    }
1053
1054    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1055    where
1056        Self: Sized,
1057        I: Iterator<Item = &'i axum::http::HeaderValue>,
1058    {
1059        let version = values
1060            .next()
1061            .ok_or_else(axum::headers::Error::invalid)?
1062            .to_str()
1063            .map_err(|_| axum::headers::Error::invalid())?
1064            .parse()
1065            .map_err(|_| axum::headers::Error::invalid())?;
1066        Ok(Self(version))
1067    }
1068
1069    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1070        values.extend([self.0.to_string().parse().unwrap()]);
1071    }
1072}
1073
1074pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1075    Router::new()
1076        .route("/rpc", get(handle_websocket_request))
1077        .layer(
1078            ServiceBuilder::new()
1079                .layer(Extension(server.app_state.clone()))
1080                .layer(middleware::from_fn(auth::validate_header)),
1081        )
1082        .route("/metrics", get(handle_metrics))
1083        .layer(Extension(server))
1084}
1085
1086pub async fn handle_websocket_request(
1087    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1088    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1089    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1090    Extension(server): Extension<Arc<Server>>,
1091    Extension(principal): Extension<Principal>,
1092    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1093    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1094    ws: WebSocketUpgrade,
1095) -> axum::response::Response {
1096    if protocol_version != rpc::PROTOCOL_VERSION {
1097        return (
1098            StatusCode::UPGRADE_REQUIRED,
1099            "client must be upgraded".to_string(),
1100        )
1101            .into_response();
1102    }
1103
1104    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1105        return (
1106            StatusCode::UPGRADE_REQUIRED,
1107            "no version header found".to_string(),
1108        )
1109            .into_response();
1110    };
1111
1112    if !version.can_collaborate() {
1113        return (
1114            StatusCode::UPGRADE_REQUIRED,
1115            "client must be upgraded".to_string(),
1116        )
1117            .into_response();
1118    }
1119
1120    let socket_address = socket_address.to_string();
1121    ws.on_upgrade(move |socket| {
1122        let socket = socket
1123            .map_ok(to_tungstenite_message)
1124            .err_into()
1125            .with(|message| async move { to_axum_message(message) });
1126        let connection = Connection::new(Box::pin(socket));
1127        async move {
1128            server
1129                .handle_connection(
1130                    connection,
1131                    socket_address,
1132                    principal,
1133                    version,
1134                    country_code_header.map(|header| header.to_string()),
1135                    system_id_header.map(|header| header.to_string()),
1136                    None,
1137                    Executor::Production,
1138                )
1139                .await;
1140        }
1141    })
1142}
1143
1144pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1145    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1146    let connections_metric = CONNECTIONS_METRIC
1147        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1148
1149    let connections = server
1150        .connection_pool
1151        .lock()
1152        .connections()
1153        .filter(|connection| !connection.admin)
1154        .count();
1155    connections_metric.set(connections as _);
1156
1157    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1158    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1159        register_int_gauge!(
1160            "shared_projects",
1161            "number of open projects with one or more guests"
1162        )
1163        .unwrap()
1164    });
1165
1166    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1167    shared_projects_metric.set(shared_projects as _);
1168
1169    let encoder = prometheus::TextEncoder::new();
1170    let metric_families = prometheus::gather();
1171    let encoded_metrics = encoder
1172        .encode_to_string(&metric_families)
1173        .map_err(|err| anyhow!("{}", err))?;
1174    Ok(encoded_metrics)
1175}
1176
/// Cleans up after a client connection has ended.
///
/// Immediately: disconnects the peer and removes the connection from the pool. An error
/// from the pool removal is returned, skipping everything else. It then records the lost
/// connection in the database, where errors only go through `trace_err`.
///
/// Then it races `RECONNECT_TIMEOUT` against server teardown (biased toward the timer).
/// The delay presumably gives the client a chance to reconnect — confirm. When the timer
/// wins, it:
/// - leaves the room and channel buffers for this session;
/// - declines a pending call and broadcasts the room update, if the user has no other online
///   connection;
/// - refreshes the user's contacts.
/// When teardown is signalled first, this deferred cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Best-effort: failures here are passed to `trace_err` and do not abort cleanup.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server is shutting down: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1223
1224/// Acknowledges a ping from a client, used to keep the connection alive.
1225async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1226    response.send(proto::Ack {})?;
1227    Ok(())
1228}
1229
1230/// Creates a new room for calling (outside of channels)
1231async fn create_room(
1232    _request: proto::CreateRoom,
1233    response: Response<proto::CreateRoom>,
1234    session: Session,
1235) -> Result<()> {
1236    let livekit_room = nanoid::nanoid!(30);
1237
1238    let live_kit_connection_info = util::maybe!(async {
1239        let live_kit = session.app_state.livekit_client.as_ref();
1240        let live_kit = live_kit?;
1241        let user_id = session.user_id().to_string();
1242
1243        let token = live_kit
1244            .room_token(&livekit_room, &user_id.to_string())
1245            .trace_err()?;
1246
1247        Some(proto::LiveKitConnectionInfo {
1248            server_url: live_kit.url().into(),
1249            token,
1250            can_publish: true,
1251        })
1252    })
1253    .await;
1254
1255    let room = session
1256        .db()
1257        .await
1258        .create_room(session.user_id(), session.connection_id, &livekit_room)
1259        .await?;
1260
1261    response.send(proto::CreateRoomResponse {
1262        room: Some(room.clone()),
1263        live_kit_connection_info,
1264    })?;
1265
1266    update_user_contacts(session.user_id(), &session).await?;
1267    Ok(())
1268}
1269
1270/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1271async fn join_room(
1272    request: proto::JoinRoom,
1273    response: Response<proto::JoinRoom>,
1274    session: Session,
1275) -> Result<()> {
1276    let room_id = RoomId::from_proto(request.id);
1277
1278    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1279
1280    if let Some(channel_id) = channel_id {
1281        return join_channel_internal(channel_id, Box::new(response), session).await;
1282    }
1283
1284    let joined_room = {
1285        let room = session
1286            .db()
1287            .await
1288            .join_room(room_id, session.user_id(), session.connection_id)
1289            .await?;
1290        room_updated(&room.room, &session.peer);
1291        room.into_inner()
1292    };
1293
1294    for connection_id in session
1295        .connection_pool()
1296        .await
1297        .user_connection_ids(session.user_id())
1298    {
1299        session
1300            .peer
1301            .send(
1302                connection_id,
1303                proto::CallCanceled {
1304                    room_id: room_id.to_proto(),
1305                },
1306            )
1307            .trace_err();
1308    }
1309
1310    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1311    {
1312        live_kit
1313            .room_token(
1314                &joined_room.room.livekit_room,
1315                &session.user_id().to_string(),
1316            )
1317            .trace_err()
1318            .map(|token| proto::LiveKitConnectionInfo {
1319                server_url: live_kit.url().into(),
1320                token,
1321                can_publish: true,
1322            })
1323    } else {
1324        None
1325    };
1326
1327    response.send(proto::JoinRoomResponse {
1328        room: Some(joined_room.room),
1329        channel_id: None,
1330        live_kit_connection_info,
1331    })?;
1332
1333    update_user_contacts(session.user_id(), &session).await?;
1334    Ok(())
1335}
1336
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-registers this connection in the room via the database, then:
/// - replies with the room, the reshared projects (with their collaborators) and the
///   rejoined projects;
/// - calls `room_updated`;
/// - for each reshared project, tells its collaborators that `old_connection_id` was
///   replaced by this connection, and forwards `UpdateProject` (the project's worktrees)
///   to the collaborators other than this connection;
/// - brings this client up to date on rejoined projects (`notify_rejoined_projects`).
/// After releasing the rejoined-room value, it sends a channel update if the room belongs to
/// a channel, and refreshes contacts.
///
/// # Errors
/// Returns database errors, a failed response send, and errors from
/// `notify_rejoined_projects` or `update_user_contacts`. Collaborator notifications are
/// best-effort.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below so the value returned by `rejoin_room` is released
    // (via `into_inner`) before the later `.await`s.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator this client's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1428
1429fn notify_rejoined_projects(
1430    rejoined_projects: &mut Vec<RejoinedProject>,
1431    session: &Session,
1432) -> Result<()> {
1433    for project in rejoined_projects.iter() {
1434        for collaborator in &project.collaborators {
1435            session
1436                .peer
1437                .send(
1438                    collaborator.connection_id,
1439                    proto::UpdateProjectCollaborator {
1440                        project_id: project.id.to_proto(),
1441                        old_peer_id: Some(project.old_connection_id.into()),
1442                        new_peer_id: Some(session.connection_id.into()),
1443                    },
1444                )
1445                .trace_err();
1446        }
1447    }
1448
1449    for project in rejoined_projects {
1450        for worktree in mem::take(&mut project.worktrees) {
1451            // Stream this worktree's entries.
1452            let message = proto::UpdateWorktree {
1453                project_id: project.id.to_proto(),
1454                worktree_id: worktree.id,
1455                abs_path: worktree.abs_path.clone(),
1456                root_name: worktree.root_name,
1457                updated_entries: worktree.updated_entries,
1458                removed_entries: worktree.removed_entries,
1459                scan_id: worktree.scan_id,
1460                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1461                updated_repositories: worktree.updated_repositories,
1462                removed_repositories: worktree.removed_repositories,
1463            };
1464            for update in proto::split_worktree_update(message) {
1465                session.peer.send(session.connection_id, update.clone())?;
1466            }
1467
1468            // Stream this worktree's diagnostics.
1469            for summary in worktree.diagnostic_summaries {
1470                session.peer.send(
1471                    session.connection_id,
1472                    proto::UpdateDiagnosticSummary {
1473                        project_id: project.id.to_proto(),
1474                        worktree_id: worktree.id,
1475                        summary: Some(summary),
1476                    },
1477                )?;
1478            }
1479
1480            for settings_file in worktree.settings_files {
1481                session.peer.send(
1482                    session.connection_id,
1483                    proto::UpdateWorktreeSettings {
1484                        project_id: project.id.to_proto(),
1485                        worktree_id: worktree.id,
1486                        path: settings_file.path,
1487                        content: Some(settings_file.content),
1488                        kind: Some(settings_file.kind.to_proto().into()),
1489                    },
1490                )?;
1491            }
1492        }
1493
1494        for language_server in &project.language_servers {
1495            session.peer.send(
1496                session.connection_id,
1497                proto::UpdateLanguageServer {
1498                    project_id: project.id.to_proto(),
1499                    language_server_id: language_server.id,
1500                    variant: Some(
1501                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1502                            proto::LspDiskBasedDiagnosticsUpdated {},
1503                        ),
1504                    ),
1505                },
1506            )?;
1507        }
1508    }
1509    Ok(())
1510}
1511
1512/// leave room disconnects from the room.
1513async fn leave_room(
1514    _: proto::LeaveRoom,
1515    response: Response<proto::LeaveRoom>,
1516    session: Session,
1517) -> Result<()> {
1518    leave_room_for_session(&session, session.connection_id).await?;
1519    response.send(proto::Ack {})?;
1520    Ok(())
1521}
1522
1523/// Updates the permissions of someone else in the room.
1524async fn set_room_participant_role(
1525    request: proto::SetRoomParticipantRole,
1526    response: Response<proto::SetRoomParticipantRole>,
1527    session: Session,
1528) -> Result<()> {
1529    let user_id = UserId::from_proto(request.user_id);
1530    let role = ChannelRole::from(request.role());
1531
1532    let (livekit_room, can_publish) = {
1533        let room = session
1534            .db()
1535            .await
1536            .set_room_participant_role(
1537                session.user_id(),
1538                RoomId::from_proto(request.room_id),
1539                user_id,
1540                role,
1541            )
1542            .await?;
1543
1544        let livekit_room = room.livekit_room.clone();
1545        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1546        room_updated(&room, &session.peer);
1547        (livekit_room, can_publish)
1548    };
1549
1550    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1551        live_kit
1552            .update_participant(
1553                livekit_room.clone(),
1554                request.user_id.to_string(),
1555                livekit_api::proto::ParticipantPermission {
1556                    can_subscribe: true,
1557                    can_publish,
1558                    can_publish_data: can_publish,
1559                    hidden: false,
1560                    recorder: false,
1561                },
1562            )
1563            .await
1564            .trace_err();
1565    }
1566
1567    response.send(proto::Ack {})?;
1568    Ok(())
1569}
1570
1571/// Call someone else into the current room
1572async fn call(
1573    request: proto::Call,
1574    response: Response<proto::Call>,
1575    session: Session,
1576) -> Result<()> {
1577    let room_id = RoomId::from_proto(request.room_id);
1578    let calling_user_id = session.user_id();
1579    let calling_connection_id = session.connection_id;
1580    let called_user_id = UserId::from_proto(request.called_user_id);
1581    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1582    if !session
1583        .db()
1584        .await
1585        .has_contact(calling_user_id, called_user_id)
1586        .await?
1587    {
1588        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1589    }
1590
1591    let incoming_call = {
1592        let (room, incoming_call) = &mut *session
1593            .db()
1594            .await
1595            .call(
1596                room_id,
1597                calling_user_id,
1598                calling_connection_id,
1599                called_user_id,
1600                initial_project_id,
1601            )
1602            .await?;
1603        room_updated(room, &session.peer);
1604        mem::take(incoming_call)
1605    };
1606    update_user_contacts(called_user_id, &session).await?;
1607
1608    let mut calls = session
1609        .connection_pool()
1610        .await
1611        .user_connection_ids(called_user_id)
1612        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1613        .collect::<FuturesUnordered<_>>();
1614
1615    while let Some(call_response) = calls.next().await {
1616        match call_response.as_ref() {
1617            Ok(_) => {
1618                response.send(proto::Ack {})?;
1619                return Ok(());
1620            }
1621            Err(_) => {
1622                call_response.trace_err();
1623            }
1624        }
1625    }
1626
1627    {
1628        let room = session
1629            .db()
1630            .await
1631            .call_failed(room_id, called_user_id)
1632            .await?;
1633        room_updated(&room, &session.peer);
1634    }
1635    update_user_contacts(called_user_id, &session).await?;
1636
1637    Err(anyhow!("failed to ring user"))?
1638}
1639
1640/// Cancel an outgoing call.
1641async fn cancel_call(
1642    request: proto::CancelCall,
1643    response: Response<proto::CancelCall>,
1644    session: Session,
1645) -> Result<()> {
1646    let called_user_id = UserId::from_proto(request.called_user_id);
1647    let room_id = RoomId::from_proto(request.room_id);
1648    {
1649        let room = session
1650            .db()
1651            .await
1652            .cancel_call(room_id, session.connection_id, called_user_id)
1653            .await?;
1654        room_updated(&room, &session.peer);
1655    }
1656
1657    for connection_id in session
1658        .connection_pool()
1659        .await
1660        .user_connection_ids(called_user_id)
1661    {
1662        session
1663            .peer
1664            .send(
1665                connection_id,
1666                proto::CallCanceled {
1667                    room_id: room_id.to_proto(),
1668                },
1669            )
1670            .trace_err();
1671    }
1672    response.send(proto::Ack {})?;
1673
1674    update_user_contacts(called_user_id, &session).await?;
1675    Ok(())
1676}
1677
1678/// Decline an incoming call.
1679async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1680    let room_id = RoomId::from_proto(message.room_id);
1681    {
1682        let room = session
1683            .db()
1684            .await
1685            .decline_call(Some(room_id), session.user_id())
1686            .await?
1687            .ok_or_else(|| anyhow!("failed to decline call"))?;
1688        room_updated(&room, &session.peer);
1689    }
1690
1691    for connection_id in session
1692        .connection_pool()
1693        .await
1694        .user_connection_ids(session.user_id())
1695    {
1696        session
1697            .peer
1698            .send(
1699                connection_id,
1700                proto::CallCanceled {
1701                    room_id: room_id.to_proto(),
1702                },
1703            )
1704            .trace_err();
1705    }
1706    update_user_contacts(session.user_id(), &session).await?;
1707    Ok(())
1708}
1709
1710/// Updates other participants in the room with your current location.
1711async fn update_participant_location(
1712    request: proto::UpdateParticipantLocation,
1713    response: Response<proto::UpdateParticipantLocation>,
1714    session: Session,
1715) -> Result<()> {
1716    let room_id = RoomId::from_proto(request.room_id);
1717    let location = request
1718        .location
1719        .ok_or_else(|| anyhow!("invalid location"))?;
1720
1721    let db = session.db().await;
1722    let room = db
1723        .update_room_participant_location(room_id, session.connection_id, location)
1724        .await?;
1725
1726    room_updated(&room, &session.peer);
1727    response.send(proto::Ack {})?;
1728    Ok(())
1729}
1730
1731/// Share a project into the room.
1732async fn share_project(
1733    request: proto::ShareProject,
1734    response: Response<proto::ShareProject>,
1735    session: Session,
1736) -> Result<()> {
1737    let (project_id, room) = &*session
1738        .db()
1739        .await
1740        .share_project(
1741            RoomId::from_proto(request.room_id),
1742            session.connection_id,
1743            &request.worktrees,
1744            request.is_ssh_project,
1745        )
1746        .await?;
1747    response.send(proto::ShareProjectResponse {
1748        project_id: project_id.to_proto(),
1749    })?;
1750    room_updated(room, &session.peer);
1751
1752    Ok(())
1753}
1754
1755/// Unshare a project from the room.
1756async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1757    let project_id = ProjectId::from_proto(message.project_id);
1758    unshare_project_internal(project_id, session.connection_id, &session).await
1759}
1760
1761async fn unshare_project_internal(
1762    project_id: ProjectId,
1763    connection_id: ConnectionId,
1764    session: &Session,
1765) -> Result<()> {
1766    let delete = {
1767        let room_guard = session
1768            .db()
1769            .await
1770            .unshare_project(project_id, connection_id)
1771            .await?;
1772
1773        let (delete, room, guest_connection_ids) = &*room_guard;
1774
1775        let message = proto::UnshareProject {
1776            project_id: project_id.to_proto(),
1777        };
1778
1779        broadcast(
1780            Some(connection_id),
1781            guest_connection_ids.iter().copied(),
1782            |conn_id| session.peer.send(conn_id, message.clone()),
1783        );
1784        if let Some(room) = room {
1785            room_updated(room, &session.peer);
1786        }
1787
1788        *delete
1789    };
1790
1791    if delete {
1792        let db = session.db().await;
1793        db.delete_project(project_id).await?;
1794    }
1795
1796    Ok(())
1797}
1798
1799/// Join someone elses shared project.
1800async fn join_project(
1801    request: proto::JoinProject,
1802    response: Response<proto::JoinProject>,
1803    session: Session,
1804) -> Result<()> {
1805    let project_id = ProjectId::from_proto(request.project_id);
1806
1807    tracing::info!(%project_id, "join project");
1808
1809    let db = session.db().await;
1810    let (project, replica_id) = &mut *db
1811        .join_project(project_id, session.connection_id, session.user_id())
1812        .await?;
1813    drop(db);
1814    tracing::info!(%project_id, "join remote project");
1815    join_project_internal(response, session, project, replica_id)
1816}
1817
1818trait JoinProjectInternalResponse {
1819    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1820}
1821impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1822    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1823        Response::<proto::JoinProject>::send(self, result)
1824    }
1825}
1826
1827fn join_project_internal(
1828    response: impl JoinProjectInternalResponse,
1829    session: Session,
1830    project: &mut Project,
1831    replica_id: &ReplicaId,
1832) -> Result<()> {
1833    let collaborators = project
1834        .collaborators
1835        .iter()
1836        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1837        .map(|collaborator| collaborator.to_proto())
1838        .collect::<Vec<_>>();
1839    let project_id = project.id;
1840    let guest_user_id = session.user_id();
1841
1842    let worktrees = project
1843        .worktrees
1844        .iter()
1845        .map(|(id, worktree)| proto::WorktreeMetadata {
1846            id: *id,
1847            root_name: worktree.root_name.clone(),
1848            visible: worktree.visible,
1849            abs_path: worktree.abs_path.clone(),
1850        })
1851        .collect::<Vec<_>>();
1852
1853    let add_project_collaborator = proto::AddProjectCollaborator {
1854        project_id: project_id.to_proto(),
1855        collaborator: Some(proto::Collaborator {
1856            peer_id: Some(session.connection_id.into()),
1857            replica_id: replica_id.0 as u32,
1858            user_id: guest_user_id.to_proto(),
1859            is_host: false,
1860        }),
1861    };
1862
1863    for collaborator in &collaborators {
1864        session
1865            .peer
1866            .send(
1867                collaborator.peer_id.unwrap().into(),
1868                add_project_collaborator.clone(),
1869            )
1870            .trace_err();
1871    }
1872
1873    // First, we send the metadata associated with each worktree.
1874    response.send(proto::JoinProjectResponse {
1875        project_id: project.id.0 as u64,
1876        worktrees: worktrees.clone(),
1877        replica_id: replica_id.0 as u32,
1878        collaborators: collaborators.clone(),
1879        language_servers: project.language_servers.clone(),
1880        role: project.role.into(),
1881    })?;
1882
1883    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1884        // Stream this worktree's entries.
1885        let message = proto::UpdateWorktree {
1886            project_id: project_id.to_proto(),
1887            worktree_id,
1888            abs_path: worktree.abs_path.clone(),
1889            root_name: worktree.root_name,
1890            updated_entries: worktree.entries,
1891            removed_entries: Default::default(),
1892            scan_id: worktree.scan_id,
1893            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1894            updated_repositories: worktree.repository_entries.into_values().collect(),
1895            removed_repositories: Default::default(),
1896        };
1897        for update in proto::split_worktree_update(message) {
1898            session.peer.send(session.connection_id, update.clone())?;
1899        }
1900
1901        // Stream this worktree's diagnostics.
1902        for summary in worktree.diagnostic_summaries {
1903            session.peer.send(
1904                session.connection_id,
1905                proto::UpdateDiagnosticSummary {
1906                    project_id: project_id.to_proto(),
1907                    worktree_id: worktree.id,
1908                    summary: Some(summary),
1909                },
1910            )?;
1911        }
1912
1913        for settings_file in worktree.settings_files {
1914            session.peer.send(
1915                session.connection_id,
1916                proto::UpdateWorktreeSettings {
1917                    project_id: project_id.to_proto(),
1918                    worktree_id: worktree.id,
1919                    path: settings_file.path,
1920                    content: Some(settings_file.content),
1921                    kind: Some(settings_file.kind.to_proto() as i32),
1922                },
1923            )?;
1924        }
1925    }
1926
1927    for language_server in &project.language_servers {
1928        session.peer.send(
1929            session.connection_id,
1930            proto::UpdateLanguageServer {
1931                project_id: project_id.to_proto(),
1932                language_server_id: language_server.id,
1933                variant: Some(
1934                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1935                        proto::LspDiskBasedDiagnosticsUpdated {},
1936                    ),
1937                ),
1938            },
1939        )?;
1940    }
1941
1942    Ok(())
1943}
1944
1945/// Leave someone elses shared project.
1946async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1947    let sender_id = session.connection_id;
1948    let project_id = ProjectId::from_proto(request.project_id);
1949    let db = session.db().await;
1950
1951    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1952    tracing::info!(
1953        %project_id,
1954        "leave project"
1955    );
1956
1957    project_left(project, &session);
1958    if let Some(room) = room {
1959        room_updated(room, &session.peer);
1960    }
1961
1962    Ok(())
1963}
1964
1965/// Updates other participants with changes to the project
1966async fn update_project(
1967    request: proto::UpdateProject,
1968    response: Response<proto::UpdateProject>,
1969    session: Session,
1970) -> Result<()> {
1971    let project_id = ProjectId::from_proto(request.project_id);
1972    let (room, guest_connection_ids) = &*session
1973        .db()
1974        .await
1975        .update_project(project_id, session.connection_id, &request.worktrees)
1976        .await?;
1977    broadcast(
1978        Some(session.connection_id),
1979        guest_connection_ids.iter().copied(),
1980        |connection_id| {
1981            session
1982                .peer
1983                .forward_send(session.connection_id, connection_id, request.clone())
1984        },
1985    );
1986    if let Some(room) = room {
1987        room_updated(room, &session.peer);
1988    }
1989    response.send(proto::Ack {})?;
1990
1991    Ok(())
1992}
1993
1994/// Updates other participants with changes to the worktree
1995async fn update_worktree(
1996    request: proto::UpdateWorktree,
1997    response: Response<proto::UpdateWorktree>,
1998    session: Session,
1999) -> Result<()> {
2000    let guest_connection_ids = session
2001        .db()
2002        .await
2003        .update_worktree(&request, session.connection_id)
2004        .await?;
2005
2006    broadcast(
2007        Some(session.connection_id),
2008        guest_connection_ids.iter().copied(),
2009        |connection_id| {
2010            session
2011                .peer
2012                .forward_send(session.connection_id, connection_id, request.clone())
2013        },
2014    );
2015    response.send(proto::Ack {})?;
2016    Ok(())
2017}
2018
2019/// Updates other participants with changes to the diagnostics
2020async fn update_diagnostic_summary(
2021    message: proto::UpdateDiagnosticSummary,
2022    session: Session,
2023) -> Result<()> {
2024    let guest_connection_ids = session
2025        .db()
2026        .await
2027        .update_diagnostic_summary(&message, session.connection_id)
2028        .await?;
2029
2030    broadcast(
2031        Some(session.connection_id),
2032        guest_connection_ids.iter().copied(),
2033        |connection_id| {
2034            session
2035                .peer
2036                .forward_send(session.connection_id, connection_id, message.clone())
2037        },
2038    );
2039
2040    Ok(())
2041}
2042
2043/// Updates other participants with changes to the worktree settings
2044async fn update_worktree_settings(
2045    message: proto::UpdateWorktreeSettings,
2046    session: Session,
2047) -> Result<()> {
2048    let guest_connection_ids = session
2049        .db()
2050        .await
2051        .update_worktree_settings(&message, session.connection_id)
2052        .await?;
2053
2054    broadcast(
2055        Some(session.connection_id),
2056        guest_connection_ids.iter().copied(),
2057        |connection_id| {
2058            session
2059                .peer
2060                .forward_send(session.connection_id, connection_id, message.clone())
2061        },
2062    );
2063
2064    Ok(())
2065}
2066
2067/// Notify other participants that a  language server has started.
2068async fn start_language_server(
2069    request: proto::StartLanguageServer,
2070    session: Session,
2071) -> Result<()> {
2072    let guest_connection_ids = session
2073        .db()
2074        .await
2075        .start_language_server(&request, session.connection_id)
2076        .await?;
2077
2078    broadcast(
2079        Some(session.connection_id),
2080        guest_connection_ids.iter().copied(),
2081        |connection_id| {
2082            session
2083                .peer
2084                .forward_send(session.connection_id, connection_id, request.clone())
2085        },
2086    );
2087    Ok(())
2088}
2089
2090/// Notify other participants that a language server has changed.
2091async fn update_language_server(
2092    request: proto::UpdateLanguageServer,
2093    session: Session,
2094) -> Result<()> {
2095    let project_id = ProjectId::from_proto(request.project_id);
2096    let project_connection_ids = session
2097        .db()
2098        .await
2099        .project_connection_ids(project_id, session.connection_id, true)
2100        .await?;
2101    broadcast(
2102        Some(session.connection_id),
2103        project_connection_ids.iter().copied(),
2104        |connection_id| {
2105            session
2106                .peer
2107                .forward_send(session.connection_id, connection_id, request.clone())
2108        },
2109    );
2110    Ok(())
2111}
2112
2113/// forward a project request to the host. These requests should be read only
2114/// as guests are allowed to send them.
2115async fn forward_read_only_project_request<T>(
2116    request: T,
2117    response: Response<T>,
2118    session: Session,
2119) -> Result<()>
2120where
2121    T: EntityMessage + RequestMessage,
2122{
2123    let project_id = ProjectId::from_proto(request.remote_entity_id());
2124    let host_connection_id = session
2125        .db()
2126        .await
2127        .host_for_read_only_project_request(project_id, session.connection_id)
2128        .await?;
2129    let payload = session
2130        .peer
2131        .forward_request(session.connection_id, host_connection_id, request)
2132        .await?;
2133    response.send(payload)?;
2134    Ok(())
2135}
2136
2137async fn forward_find_search_candidates_request(
2138    request: proto::FindSearchCandidates,
2139    response: Response<proto::FindSearchCandidates>,
2140    session: Session,
2141) -> Result<()> {
2142    let project_id = ProjectId::from_proto(request.remote_entity_id());
2143    let host_connection_id = session
2144        .db()
2145        .await
2146        .host_for_read_only_project_request(project_id, session.connection_id)
2147        .await?;
2148    let payload = session
2149        .peer
2150        .forward_request(session.connection_id, host_connection_id, request)
2151        .await?;
2152    response.send(payload)?;
2153    Ok(())
2154}
2155
2156/// forward a project request to the host. These requests are disallowed
2157/// for guests.
2158async fn forward_mutating_project_request<T>(
2159    request: T,
2160    response: Response<T>,
2161    session: Session,
2162) -> Result<()>
2163where
2164    T: EntityMessage + RequestMessage,
2165{
2166    let project_id = ProjectId::from_proto(request.remote_entity_id());
2167
2168    let host_connection_id = session
2169        .db()
2170        .await
2171        .host_for_mutating_project_request(project_id, session.connection_id)
2172        .await?;
2173    let payload = session
2174        .peer
2175        .forward_request(session.connection_id, host_connection_id, request)
2176        .await?;
2177    response.send(payload)?;
2178    Ok(())
2179}
2180
2181/// Notify other participants that a new buffer has been created
2182async fn create_buffer_for_peer(
2183    request: proto::CreateBufferForPeer,
2184    session: Session,
2185) -> Result<()> {
2186    session
2187        .db()
2188        .await
2189        .check_user_is_project_host(
2190            ProjectId::from_proto(request.project_id),
2191            session.connection_id,
2192        )
2193        .await?;
2194    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2195    session
2196        .peer
2197        .forward_send(session.connection_id, peer_id.into(), request)?;
2198    Ok(())
2199}
2200
2201/// Notify other participants that a buffer has been updated. This is
2202/// allowed for guests as long as the update is limited to selections.
2203async fn update_buffer(
2204    request: proto::UpdateBuffer,
2205    response: Response<proto::UpdateBuffer>,
2206    session: Session,
2207) -> Result<()> {
2208    let project_id = ProjectId::from_proto(request.project_id);
2209    let mut capability = Capability::ReadOnly;
2210
2211    for op in request.operations.iter() {
2212        match op.variant {
2213            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2214            Some(_) => capability = Capability::ReadWrite,
2215        }
2216    }
2217
2218    let host = {
2219        let guard = session
2220            .db()
2221            .await
2222            .connections_for_buffer_update(project_id, session.connection_id, capability)
2223            .await?;
2224
2225        let (host, guests) = &*guard;
2226
2227        broadcast(
2228            Some(session.connection_id),
2229            guests.clone(),
2230            |connection_id| {
2231                session
2232                    .peer
2233                    .forward_send(session.connection_id, connection_id, request.clone())
2234            },
2235        );
2236
2237        *host
2238    };
2239
2240    if host != session.connection_id {
2241        session
2242            .peer
2243            .forward_request(session.connection_id, host, request.clone())
2244            .await?;
2245    }
2246
2247    response.send(proto::Ack {})?;
2248    Ok(())
2249}
2250
2251async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2252    let project_id = ProjectId::from_proto(message.project_id);
2253
2254    let operation = message.operation.as_ref().context("invalid operation")?;
2255    let capability = match operation.variant.as_ref() {
2256        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2257            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2258                match buffer_op.variant {
2259                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2260                        Capability::ReadOnly
2261                    }
2262                    _ => Capability::ReadWrite,
2263                }
2264            } else {
2265                Capability::ReadWrite
2266            }
2267        }
2268        Some(_) => Capability::ReadWrite,
2269        None => Capability::ReadOnly,
2270    };
2271
2272    let guard = session
2273        .db()
2274        .await
2275        .connections_for_buffer_update(project_id, session.connection_id, capability)
2276        .await?;
2277
2278    let (host, guests) = &*guard;
2279
2280    broadcast(
2281        Some(session.connection_id),
2282        guests.iter().chain([host]).copied(),
2283        |connection_id| {
2284            session
2285                .peer
2286                .forward_send(session.connection_id, connection_id, message.clone())
2287        },
2288    );
2289
2290    Ok(())
2291}
2292
2293/// Notify other participants that a project has been updated.
2294async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2295    request: T,
2296    session: Session,
2297) -> Result<()> {
2298    let project_id = ProjectId::from_proto(request.remote_entity_id());
2299    let project_connection_ids = session
2300        .db()
2301        .await
2302        .project_connection_ids(project_id, session.connection_id, false)
2303        .await?;
2304
2305    broadcast(
2306        Some(session.connection_id),
2307        project_connection_ids.iter().copied(),
2308        |connection_id| {
2309            session
2310                .peer
2311                .forward_send(session.connection_id, connection_id, request.clone())
2312        },
2313    );
2314    Ok(())
2315}
2316
2317/// Start following another user in a call.
2318async fn follow(
2319    request: proto::Follow,
2320    response: Response<proto::Follow>,
2321    session: Session,
2322) -> Result<()> {
2323    let room_id = RoomId::from_proto(request.room_id);
2324    let project_id = request.project_id.map(ProjectId::from_proto);
2325    let leader_id = request
2326        .leader_id
2327        .ok_or_else(|| anyhow!("invalid leader id"))?
2328        .into();
2329    let follower_id = session.connection_id;
2330
2331    session
2332        .db()
2333        .await
2334        .check_room_participants(room_id, leader_id, session.connection_id)
2335        .await?;
2336
2337    let response_payload = session
2338        .peer
2339        .forward_request(session.connection_id, leader_id, request)
2340        .await?;
2341    response.send(response_payload)?;
2342
2343    if let Some(project_id) = project_id {
2344        let room = session
2345            .db()
2346            .await
2347            .follow(room_id, project_id, leader_id, follower_id)
2348            .await?;
2349        room_updated(&room, &session.peer);
2350    }
2351
2352    Ok(())
2353}
2354
2355/// Stop following another user in a call.
2356async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2357    let room_id = RoomId::from_proto(request.room_id);
2358    let project_id = request.project_id.map(ProjectId::from_proto);
2359    let leader_id = request
2360        .leader_id
2361        .ok_or_else(|| anyhow!("invalid leader id"))?
2362        .into();
2363    let follower_id = session.connection_id;
2364
2365    session
2366        .db()
2367        .await
2368        .check_room_participants(room_id, leader_id, session.connection_id)
2369        .await?;
2370
2371    session
2372        .peer
2373        .forward_send(session.connection_id, leader_id, request)?;
2374
2375    if let Some(project_id) = project_id {
2376        let room = session
2377            .db()
2378            .await
2379            .unfollow(room_id, project_id, leader_id, follower_id)
2380            .await?;
2381        room_updated(&room, &session.peer);
2382    }
2383
2384    Ok(())
2385}
2386
2387/// Notify everyone following you of your current location.
2388async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2389    let room_id = RoomId::from_proto(request.room_id);
2390    let database = session.db.lock().await;
2391
2392    let connection_ids = if let Some(project_id) = request.project_id {
2393        let project_id = ProjectId::from_proto(project_id);
2394        database
2395            .project_connection_ids(project_id, session.connection_id, true)
2396            .await?
2397    } else {
2398        database
2399            .room_connection_ids(room_id, session.connection_id)
2400            .await?
2401    };
2402
2403    // For now, don't send view update messages back to that view's current leader.
2404    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2405        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2406        _ => None,
2407    });
2408
2409    for connection_id in connection_ids.iter().cloned() {
2410        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2411            session
2412                .peer
2413                .forward_send(session.connection_id, connection_id, request.clone())?;
2414        }
2415    }
2416    Ok(())
2417}
2418
2419/// Get public data about users.
2420async fn get_users(
2421    request: proto::GetUsers,
2422    response: Response<proto::GetUsers>,
2423    session: Session,
2424) -> Result<()> {
2425    let user_ids = request
2426        .user_ids
2427        .into_iter()
2428        .map(UserId::from_proto)
2429        .collect();
2430    let users = session
2431        .db()
2432        .await
2433        .get_users_by_ids(user_ids)
2434        .await?
2435        .into_iter()
2436        .map(|user| proto::User {
2437            id: user.id.to_proto(),
2438            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2439            github_login: user.github_login,
2440            email: user.email_address,
2441            name: user.name,
2442        })
2443        .collect();
2444    response.send(proto::UsersResponse { users })?;
2445    Ok(())
2446}
2447
2448/// Search for users (to invite) buy Github login
2449async fn fuzzy_search_users(
2450    request: proto::FuzzySearchUsers,
2451    response: Response<proto::FuzzySearchUsers>,
2452    session: Session,
2453) -> Result<()> {
2454    let query = request.query;
2455    let users = match query.len() {
2456        0 => vec![],
2457        1 | 2 => session
2458            .db()
2459            .await
2460            .get_user_by_github_login(&query)
2461            .await?
2462            .into_iter()
2463            .collect(),
2464        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2465    };
2466    let users = users
2467        .into_iter()
2468        .filter(|user| user.id != session.user_id())
2469        .map(|user| proto::User {
2470            id: user.id.to_proto(),
2471            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2472            github_login: user.github_login,
2473            name: user.name,
2474            email: user.email_address,
2475        })
2476        .collect();
2477    response.send(proto::UsersResponse { users })?;
2478    Ok(())
2479}
2480
2481/// Send a contact request to another user.
2482async fn request_contact(
2483    request: proto::RequestContact,
2484    response: Response<proto::RequestContact>,
2485    session: Session,
2486) -> Result<()> {
2487    let requester_id = session.user_id();
2488    let responder_id = UserId::from_proto(request.responder_id);
2489    if requester_id == responder_id {
2490        return Err(anyhow!("cannot add yourself as a contact"))?;
2491    }
2492
2493    let notifications = session
2494        .db()
2495        .await
2496        .send_contact_request(requester_id, responder_id)
2497        .await?;
2498
2499    // Update outgoing contact requests of requester
2500    let mut update = proto::UpdateContacts::default();
2501    update.outgoing_requests.push(responder_id.to_proto());
2502    for connection_id in session
2503        .connection_pool()
2504        .await
2505        .user_connection_ids(requester_id)
2506    {
2507        session.peer.send(connection_id, update.clone())?;
2508    }
2509
2510    // Update incoming contact requests of responder
2511    let mut update = proto::UpdateContacts::default();
2512    update
2513        .incoming_requests
2514        .push(proto::IncomingContactRequest {
2515            requester_id: requester_id.to_proto(),
2516        });
2517    let connection_pool = session.connection_pool().await;
2518    for connection_id in connection_pool.user_connection_ids(responder_id) {
2519        session.peer.send(connection_id, update.clone())?;
2520    }
2521
2522    send_notifications(&connection_pool, &session.peer, notifications);
2523
2524    response.send(proto::Ack {})?;
2525    Ok(())
2526}
2527
2528/// Accept or decline a contact request
2529async fn respond_to_contact_request(
2530    request: proto::RespondToContactRequest,
2531    response: Response<proto::RespondToContactRequest>,
2532    session: Session,
2533) -> Result<()> {
2534    let responder_id = session.user_id();
2535    let requester_id = UserId::from_proto(request.requester_id);
2536    let db = session.db().await;
2537    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2538        db.dismiss_contact_notification(responder_id, requester_id)
2539            .await?;
2540    } else {
2541        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2542
2543        let notifications = db
2544            .respond_to_contact_request(responder_id, requester_id, accept)
2545            .await?;
2546        let requester_busy = db.is_user_busy(requester_id).await?;
2547        let responder_busy = db.is_user_busy(responder_id).await?;
2548
2549        let pool = session.connection_pool().await;
2550        // Update responder with new contact
2551        let mut update = proto::UpdateContacts::default();
2552        if accept {
2553            update
2554                .contacts
2555                .push(contact_for_user(requester_id, requester_busy, &pool));
2556        }
2557        update
2558            .remove_incoming_requests
2559            .push(requester_id.to_proto());
2560        for connection_id in pool.user_connection_ids(responder_id) {
2561            session.peer.send(connection_id, update.clone())?;
2562        }
2563
2564        // Update requester with new contact
2565        let mut update = proto::UpdateContacts::default();
2566        if accept {
2567            update
2568                .contacts
2569                .push(contact_for_user(responder_id, responder_busy, &pool));
2570        }
2571        update
2572            .remove_outgoing_requests
2573            .push(responder_id.to_proto());
2574
2575        for connection_id in pool.user_connection_ids(requester_id) {
2576            session.peer.send(connection_id, update.clone())?;
2577        }
2578
2579        send_notifications(&pool, &session.peer, notifications);
2580    }
2581
2582    response.send(proto::Ack {})?;
2583    Ok(())
2584}
2585
2586/// Remove a contact.
2587async fn remove_contact(
2588    request: proto::RemoveContact,
2589    response: Response<proto::RemoveContact>,
2590    session: Session,
2591) -> Result<()> {
2592    let requester_id = session.user_id();
2593    let responder_id = UserId::from_proto(request.user_id);
2594    let db = session.db().await;
2595    let (contact_accepted, deleted_notification_id) =
2596        db.remove_contact(requester_id, responder_id).await?;
2597
2598    let pool = session.connection_pool().await;
2599    // Update outgoing contact requests of requester
2600    let mut update = proto::UpdateContacts::default();
2601    if contact_accepted {
2602        update.remove_contacts.push(responder_id.to_proto());
2603    } else {
2604        update
2605            .remove_outgoing_requests
2606            .push(responder_id.to_proto());
2607    }
2608    for connection_id in pool.user_connection_ids(requester_id) {
2609        session.peer.send(connection_id, update.clone())?;
2610    }
2611
2612    // Update incoming contact requests of responder
2613    let mut update = proto::UpdateContacts::default();
2614    if contact_accepted {
2615        update.remove_contacts.push(requester_id.to_proto());
2616    } else {
2617        update
2618            .remove_incoming_requests
2619            .push(requester_id.to_proto());
2620    }
2621    for connection_id in pool.user_connection_ids(responder_id) {
2622        session.peer.send(connection_id, update.clone())?;
2623        if let Some(notification_id) = deleted_notification_id {
2624            session.peer.send(
2625                connection_id,
2626                proto::DeleteNotification {
2627                    notification_id: notification_id.to_proto(),
2628                },
2629            )?;
2630        }
2631    }
2632
2633    response.send(proto::Ack {})?;
2634    Ok(())
2635}
2636
2637fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2638    version.0.minor() < 139
2639}
2640
2641async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2642    let plan = session.current_plan(&session.db().await).await?;
2643
2644    session
2645        .peer
2646        .send(
2647            session.connection_id,
2648            proto::UpdateUserPlan { plan: plan.into() },
2649        )
2650        .trace_err();
2651
2652    Ok(())
2653}
2654
2655async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2656    subscribe_user_to_channels(session.user_id(), &session).await?;
2657    Ok(())
2658}
2659
2660async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2661    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2662    let mut pool = session.connection_pool().await;
2663    for membership in &channels_for_user.channel_memberships {
2664        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2665    }
2666    session.peer.send(
2667        session.connection_id,
2668        build_update_user_channels(&channels_for_user),
2669    )?;
2670    session.peer.send(
2671        session.connection_id,
2672        build_channels_update(channels_for_user),
2673    )?;
2674    Ok(())
2675}
2676
2677/// Creates a new channel.
2678async fn create_channel(
2679    request: proto::CreateChannel,
2680    response: Response<proto::CreateChannel>,
2681    session: Session,
2682) -> Result<()> {
2683    let db = session.db().await;
2684
2685    let parent_id = request.parent_id.map(ChannelId::from_proto);
2686    let (channel, membership) = db
2687        .create_channel(&request.name, parent_id, session.user_id())
2688        .await?;
2689
2690    let root_id = channel.root_id();
2691    let channel = Channel::from_model(channel);
2692
2693    response.send(proto::CreateChannelResponse {
2694        channel: Some(channel.to_proto()),
2695        parent_id: request.parent_id,
2696    })?;
2697
2698    let mut connection_pool = session.connection_pool().await;
2699    if let Some(membership) = membership {
2700        connection_pool.subscribe_to_channel(
2701            membership.user_id,
2702            membership.channel_id,
2703            membership.role,
2704        );
2705        let update = proto::UpdateUserChannels {
2706            channel_memberships: vec![proto::ChannelMembership {
2707                channel_id: membership.channel_id.to_proto(),
2708                role: membership.role.into(),
2709            }],
2710            ..Default::default()
2711        };
2712        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2713            session.peer.send(connection_id, update.clone())?;
2714        }
2715    }
2716
2717    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2718        if !role.can_see_channel(channel.visibility) {
2719            continue;
2720        }
2721
2722        let update = proto::UpdateChannels {
2723            channels: vec![channel.to_proto()],
2724            ..Default::default()
2725        };
2726        session.peer.send(connection_id, update.clone())?;
2727    }
2728
2729    Ok(())
2730}
2731
2732/// Delete a channel
2733async fn delete_channel(
2734    request: proto::DeleteChannel,
2735    response: Response<proto::DeleteChannel>,
2736    session: Session,
2737) -> Result<()> {
2738    let db = session.db().await;
2739
2740    let channel_id = request.channel_id;
2741    let (root_channel, removed_channels) = db
2742        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2743        .await?;
2744    response.send(proto::Ack {})?;
2745
2746    // Notify members of removed channels
2747    let mut update = proto::UpdateChannels::default();
2748    update
2749        .delete_channels
2750        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2751
2752    let connection_pool = session.connection_pool().await;
2753    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2754        session.peer.send(connection_id, update.clone())?;
2755    }
2756
2757    Ok(())
2758}
2759
2760/// Invite someone to join a channel.
2761async fn invite_channel_member(
2762    request: proto::InviteChannelMember,
2763    response: Response<proto::InviteChannelMember>,
2764    session: Session,
2765) -> Result<()> {
2766    let db = session.db().await;
2767    let channel_id = ChannelId::from_proto(request.channel_id);
2768    let invitee_id = UserId::from_proto(request.user_id);
2769    let InviteMemberResult {
2770        channel,
2771        notifications,
2772    } = db
2773        .invite_channel_member(
2774            channel_id,
2775            invitee_id,
2776            session.user_id(),
2777            request.role().into(),
2778        )
2779        .await?;
2780
2781    let update = proto::UpdateChannels {
2782        channel_invitations: vec![channel.to_proto()],
2783        ..Default::default()
2784    };
2785
2786    let connection_pool = session.connection_pool().await;
2787    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2788        session.peer.send(connection_id, update.clone())?;
2789    }
2790
2791    send_notifications(&connection_pool, &session.peer, notifications);
2792
2793    response.send(proto::Ack {})?;
2794    Ok(())
2795}
2796
2797/// remove someone from a channel
2798async fn remove_channel_member(
2799    request: proto::RemoveChannelMember,
2800    response: Response<proto::RemoveChannelMember>,
2801    session: Session,
2802) -> Result<()> {
2803    let db = session.db().await;
2804    let channel_id = ChannelId::from_proto(request.channel_id);
2805    let member_id = UserId::from_proto(request.user_id);
2806
2807    let RemoveChannelMemberResult {
2808        membership_update,
2809        notification_id,
2810    } = db
2811        .remove_channel_member(channel_id, member_id, session.user_id())
2812        .await?;
2813
2814    let mut connection_pool = session.connection_pool().await;
2815    notify_membership_updated(
2816        &mut connection_pool,
2817        membership_update,
2818        member_id,
2819        &session.peer,
2820    );
2821    for connection_id in connection_pool.user_connection_ids(member_id) {
2822        if let Some(notification_id) = notification_id {
2823            session
2824                .peer
2825                .send(
2826                    connection_id,
2827                    proto::DeleteNotification {
2828                        notification_id: notification_id.to_proto(),
2829                    },
2830                )
2831                .trace_err();
2832        }
2833    }
2834
2835    response.send(proto::Ack {})?;
2836    Ok(())
2837}
2838
2839/// Toggle the channel between public and private.
2840/// Care is taken to maintain the invariant that public channels only descend from public channels,
2841/// (though members-only channels can appear at any point in the hierarchy).
2842async fn set_channel_visibility(
2843    request: proto::SetChannelVisibility,
2844    response: Response<proto::SetChannelVisibility>,
2845    session: Session,
2846) -> Result<()> {
2847    let db = session.db().await;
2848    let channel_id = ChannelId::from_proto(request.channel_id);
2849    let visibility = request.visibility().into();
2850
2851    let channel_model = db
2852        .set_channel_visibility(channel_id, visibility, session.user_id())
2853        .await?;
2854    let root_id = channel_model.root_id();
2855    let channel = Channel::from_model(channel_model);
2856
2857    let mut connection_pool = session.connection_pool().await;
2858    for (user_id, role) in connection_pool
2859        .channel_user_ids(root_id)
2860        .collect::<Vec<_>>()
2861        .into_iter()
2862    {
2863        let update = if role.can_see_channel(channel.visibility) {
2864            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2865            proto::UpdateChannels {
2866                channels: vec![channel.to_proto()],
2867                ..Default::default()
2868            }
2869        } else {
2870            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2871            proto::UpdateChannels {
2872                delete_channels: vec![channel.id.to_proto()],
2873                ..Default::default()
2874            }
2875        };
2876
2877        for connection_id in connection_pool.user_connection_ids(user_id) {
2878            session.peer.send(connection_id, update.clone())?;
2879        }
2880    }
2881
2882    response.send(proto::Ack {})?;
2883    Ok(())
2884}
2885
2886/// Alter the role for a user in the channel.
2887async fn set_channel_member_role(
2888    request: proto::SetChannelMemberRole,
2889    response: Response<proto::SetChannelMemberRole>,
2890    session: Session,
2891) -> Result<()> {
2892    let db = session.db().await;
2893    let channel_id = ChannelId::from_proto(request.channel_id);
2894    let member_id = UserId::from_proto(request.user_id);
2895    let result = db
2896        .set_channel_member_role(
2897            channel_id,
2898            session.user_id(),
2899            member_id,
2900            request.role().into(),
2901        )
2902        .await?;
2903
2904    match result {
2905        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2906            let mut connection_pool = session.connection_pool().await;
2907            notify_membership_updated(
2908                &mut connection_pool,
2909                membership_update,
2910                member_id,
2911                &session.peer,
2912            )
2913        }
2914        db::SetMemberRoleResult::InviteUpdated(channel) => {
2915            let update = proto::UpdateChannels {
2916                channel_invitations: vec![channel.to_proto()],
2917                ..Default::default()
2918            };
2919
2920            for connection_id in session
2921                .connection_pool()
2922                .await
2923                .user_connection_ids(member_id)
2924            {
2925                session.peer.send(connection_id, update.clone())?;
2926            }
2927        }
2928    }
2929
2930    response.send(proto::Ack {})?;
2931    Ok(())
2932}
2933
2934/// Change the name of a channel
2935async fn rename_channel(
2936    request: proto::RenameChannel,
2937    response: Response<proto::RenameChannel>,
2938    session: Session,
2939) -> Result<()> {
2940    let db = session.db().await;
2941    let channel_id = ChannelId::from_proto(request.channel_id);
2942    let channel_model = db
2943        .rename_channel(channel_id, session.user_id(), &request.name)
2944        .await?;
2945    let root_id = channel_model.root_id();
2946    let channel = Channel::from_model(channel_model);
2947
2948    response.send(proto::RenameChannelResponse {
2949        channel: Some(channel.to_proto()),
2950    })?;
2951
2952    let connection_pool = session.connection_pool().await;
2953    let update = proto::UpdateChannels {
2954        channels: vec![channel.to_proto()],
2955        ..Default::default()
2956    };
2957    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2958        if role.can_see_channel(channel.visibility) {
2959            session.peer.send(connection_id, update.clone())?;
2960        }
2961    }
2962
2963    Ok(())
2964}
2965
2966/// Move a channel to a new parent.
2967async fn move_channel(
2968    request: proto::MoveChannel,
2969    response: Response<proto::MoveChannel>,
2970    session: Session,
2971) -> Result<()> {
2972    let channel_id = ChannelId::from_proto(request.channel_id);
2973    let to = ChannelId::from_proto(request.to);
2974
2975    let (root_id, channels) = session
2976        .db()
2977        .await
2978        .move_channel(channel_id, to, session.user_id())
2979        .await?;
2980
2981    let connection_pool = session.connection_pool().await;
2982    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2983        let channels = channels
2984            .iter()
2985            .filter_map(|channel| {
2986                if role.can_see_channel(channel.visibility) {
2987                    Some(channel.to_proto())
2988                } else {
2989                    None
2990                }
2991            })
2992            .collect::<Vec<_>>();
2993        if channels.is_empty() {
2994            continue;
2995        }
2996
2997        let update = proto::UpdateChannels {
2998            channels,
2999            ..Default::default()
3000        };
3001
3002        session.peer.send(connection_id, update.clone())?;
3003    }
3004
3005    response.send(Ack {})?;
3006    Ok(())
3007}
3008
3009/// Get the list of channel members
3010async fn get_channel_members(
3011    request: proto::GetChannelMembers,
3012    response: Response<proto::GetChannelMembers>,
3013    session: Session,
3014) -> Result<()> {
3015    let db = session.db().await;
3016    let channel_id = ChannelId::from_proto(request.channel_id);
3017    let limit = if request.limit == 0 {
3018        u16::MAX as u64
3019    } else {
3020        request.limit
3021    };
3022    let (members, users) = db
3023        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3024        .await?;
3025    response.send(proto::GetChannelMembersResponse { members, users })?;
3026    Ok(())
3027}
3028
3029/// Accept or decline a channel invitation.
3030async fn respond_to_channel_invite(
3031    request: proto::RespondToChannelInvite,
3032    response: Response<proto::RespondToChannelInvite>,
3033    session: Session,
3034) -> Result<()> {
3035    let db = session.db().await;
3036    let channel_id = ChannelId::from_proto(request.channel_id);
3037    let RespondToChannelInvite {
3038        membership_update,
3039        notifications,
3040    } = db
3041        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3042        .await?;
3043
3044    let mut connection_pool = session.connection_pool().await;
3045    if let Some(membership_update) = membership_update {
3046        notify_membership_updated(
3047            &mut connection_pool,
3048            membership_update,
3049            session.user_id(),
3050            &session.peer,
3051        );
3052    } else {
3053        let update = proto::UpdateChannels {
3054            remove_channel_invitations: vec![channel_id.to_proto()],
3055            ..Default::default()
3056        };
3057
3058        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3059            session.peer.send(connection_id, update.clone())?;
3060        }
3061    };
3062
3063    send_notifications(&connection_pool, &session.peer, notifications);
3064
3065    response.send(proto::Ack {})?;
3066
3067    Ok(())
3068}
3069
3070/// Join the channels' room
3071async fn join_channel(
3072    request: proto::JoinChannel,
3073    response: Response<proto::JoinChannel>,
3074    session: Session,
3075) -> Result<()> {
3076    let channel_id = ChannelId::from_proto(request.channel_id);
3077    join_channel_internal(channel_id, Box::new(response), session).await
3078}
3079
3080trait JoinChannelInternalResponse {
3081    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3082}
3083impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3084    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3085        Response::<proto::JoinChannel>::send(self, result)
3086    }
3087}
3088impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3089    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3090        Response::<proto::JoinRoom>::send(self, result)
3091    }
3092}
3093
/// Joins the room associated with `channel_id` for the session's user and replies through
/// `response` (any `JoinChannelInternalResponse` implementor).
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel in the database, which also yields the user's role.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token and `can_publish: false`;
///    - everyone else gets a room token and `can_publish: true`.
///    If minting fails, the error goes through `trace_err` and no connection info is sent.
/// 4. Send the `JoinRoomResponse`.
/// 5. Broadcast any membership change and the room update.
/// 6. Once the database and pool guards are released, broadcast the channel update and
///    refresh the user's contacts.
///
/// # Errors
/// Propagates database, `leave_room_for_session`, send and `update_user_contacts` errors.
/// Also fails with "channel not returned" if the joined room has no channel; note that by
/// then the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before calling `leave_room_for_session`, which
            // takes `&session` and presumably locks the database itself (NOTE(review):
            // confirm). The guard is re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token failed
        // (the `?` after `trace_err()` exits the closure with `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The block above has ended, so its `db` and `connection_pool` guards are released
    // before the pool is locked again here.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3189
3190/// Start editing the channel notes
3191async fn join_channel_buffer(
3192    request: proto::JoinChannelBuffer,
3193    response: Response<proto::JoinChannelBuffer>,
3194    session: Session,
3195) -> Result<()> {
3196    let db = session.db().await;
3197    let channel_id = ChannelId::from_proto(request.channel_id);
3198
3199    let open_response = db
3200        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3201        .await?;
3202
3203    let collaborators = open_response.collaborators.clone();
3204    response.send(open_response)?;
3205
3206    let update = UpdateChannelBufferCollaborators {
3207        channel_id: channel_id.to_proto(),
3208        collaborators: collaborators.clone(),
3209    };
3210    channel_buffer_updated(
3211        session.connection_id,
3212        collaborators
3213            .iter()
3214            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3215        &update,
3216        &session.peer,
3217    );
3218
3219    Ok(())
3220}
3221
3222/// Edit the channel notes
3223async fn update_channel_buffer(
3224    request: proto::UpdateChannelBuffer,
3225    session: Session,
3226) -> Result<()> {
3227    let db = session.db().await;
3228    let channel_id = ChannelId::from_proto(request.channel_id);
3229
3230    let (collaborators, epoch, version) = db
3231        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3232        .await?;
3233
3234    channel_buffer_updated(
3235        session.connection_id,
3236        collaborators.clone(),
3237        &proto::UpdateChannelBuffer {
3238            channel_id: channel_id.to_proto(),
3239            operations: request.operations,
3240        },
3241        &session.peer,
3242    );
3243
3244    let pool = &*session.connection_pool().await;
3245
3246    let non_collaborators =
3247        pool.channel_connection_ids(channel_id)
3248            .filter_map(|(connection_id, _)| {
3249                if collaborators.contains(&connection_id) {
3250                    None
3251                } else {
3252                    Some(connection_id)
3253                }
3254            });
3255
3256    broadcast(None, non_collaborators, |peer_id| {
3257        session.peer.send(
3258            peer_id,
3259            proto::UpdateChannels {
3260                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3261                    channel_id: channel_id.to_proto(),
3262                    epoch: epoch as u64,
3263                    version: version.clone(),
3264                }],
3265                ..Default::default()
3266            },
3267        )
3268    });
3269
3270    Ok(())
3271}
3272
3273/// Rejoin the channel notes after a connection blip
3274async fn rejoin_channel_buffers(
3275    request: proto::RejoinChannelBuffers,
3276    response: Response<proto::RejoinChannelBuffers>,
3277    session: Session,
3278) -> Result<()> {
3279    let db = session.db().await;
3280    let buffers = db
3281        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3282        .await?;
3283
3284    for rejoined_buffer in &buffers {
3285        let collaborators_to_notify = rejoined_buffer
3286            .buffer
3287            .collaborators
3288            .iter()
3289            .filter_map(|c| Some(c.peer_id?.into()));
3290        channel_buffer_updated(
3291            session.connection_id,
3292            collaborators_to_notify,
3293            &proto::UpdateChannelBufferCollaborators {
3294                channel_id: rejoined_buffer.buffer.channel_id,
3295                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3296            },
3297            &session.peer,
3298        );
3299    }
3300
3301    response.send(proto::RejoinChannelBuffersResponse {
3302        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3303    })?;
3304
3305    Ok(())
3306}
3307
3308/// Stop editing the channel notes
3309async fn leave_channel_buffer(
3310    request: proto::LeaveChannelBuffer,
3311    response: Response<proto::LeaveChannelBuffer>,
3312    session: Session,
3313) -> Result<()> {
3314    let db = session.db().await;
3315    let channel_id = ChannelId::from_proto(request.channel_id);
3316
3317    let left_buffer = db
3318        .leave_channel_buffer(channel_id, session.connection_id)
3319        .await?;
3320
3321    response.send(Ack {})?;
3322
3323    channel_buffer_updated(
3324        session.connection_id,
3325        left_buffer.connections,
3326        &proto::UpdateChannelBufferCollaborators {
3327            channel_id: channel_id.to_proto(),
3328            collaborators: left_buffer.collaborators,
3329        },
3330        &session.peer,
3331    );
3332
3333    Ok(())
3334}
3335
3336fn channel_buffer_updated<T: EnvelopedMessage>(
3337    sender_id: ConnectionId,
3338    collaborators: impl IntoIterator<Item = ConnectionId>,
3339    message: &T,
3340    peer: &Peer,
3341) {
3342    broadcast(Some(sender_id), collaborators, |peer_id| {
3343        peer.send(peer_id, message.clone())
3344    });
3345}
3346
3347fn send_notifications(
3348    connection_pool: &ConnectionPool,
3349    peer: &Peer,
3350    notifications: db::NotificationBatch,
3351) {
3352    for (user_id, notification) in notifications {
3353        for connection_id in connection_pool.user_connection_ids(user_id) {
3354            if let Err(error) = peer.send(
3355                connection_id,
3356                proto::AddNotification {
3357                    notification: Some(notification.clone()),
3358                },
3359            ) {
3360                tracing::error!(
3361                    "failed to send notification to {:?} {}",
3362                    connection_id,
3363                    error
3364                );
3365            }
3366        }
3367    }
3368}
3369
3370/// Send a message to the channel
3371async fn send_channel_message(
3372    request: proto::SendChannelMessage,
3373    response: Response<proto::SendChannelMessage>,
3374    session: Session,
3375) -> Result<()> {
3376    // Validate the message body.
3377    let body = request.body.trim().to_string();
3378    if body.len() > MAX_MESSAGE_LEN {
3379        return Err(anyhow!("message is too long"))?;
3380    }
3381    if body.is_empty() {
3382        return Err(anyhow!("message can't be blank"))?;
3383    }
3384
3385    // TODO: adjust mentions if body is trimmed
3386
3387    let timestamp = OffsetDateTime::now_utc();
3388    let nonce = request
3389        .nonce
3390        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3391
3392    let channel_id = ChannelId::from_proto(request.channel_id);
3393    let CreatedChannelMessage {
3394        message_id,
3395        participant_connection_ids,
3396        notifications,
3397    } = session
3398        .db()
3399        .await
3400        .create_channel_message(
3401            channel_id,
3402            session.user_id(),
3403            &body,
3404            &request.mentions,
3405            timestamp,
3406            nonce.clone().into(),
3407            request.reply_to_message_id.map(MessageId::from_proto),
3408        )
3409        .await?;
3410
3411    let message = proto::ChannelMessage {
3412        sender_id: session.user_id().to_proto(),
3413        id: message_id.to_proto(),
3414        body,
3415        mentions: request.mentions,
3416        timestamp: timestamp.unix_timestamp() as u64,
3417        nonce: Some(nonce),
3418        reply_to_message_id: request.reply_to_message_id,
3419        edited_at: None,
3420    };
3421    broadcast(
3422        Some(session.connection_id),
3423        participant_connection_ids.clone(),
3424        |connection| {
3425            session.peer.send(
3426                connection,
3427                proto::ChannelMessageSent {
3428                    channel_id: channel_id.to_proto(),
3429                    message: Some(message.clone()),
3430                },
3431            )
3432        },
3433    );
3434    response.send(proto::SendChannelMessageResponse {
3435        message: Some(message),
3436    })?;
3437
3438    let pool = &*session.connection_pool().await;
3439    let non_participants =
3440        pool.channel_connection_ids(channel_id)
3441            .filter_map(|(connection_id, _)| {
3442                if participant_connection_ids.contains(&connection_id) {
3443                    None
3444                } else {
3445                    Some(connection_id)
3446                }
3447            });
3448    broadcast(None, non_participants, |peer_id| {
3449        session.peer.send(
3450            peer_id,
3451            proto::UpdateChannels {
3452                latest_channel_message_ids: vec![proto::ChannelMessageId {
3453                    channel_id: channel_id.to_proto(),
3454                    message_id: message_id.to_proto(),
3455                }],
3456                ..Default::default()
3457            },
3458        )
3459    });
3460    send_notifications(pool, &session.peer, notifications);
3461
3462    Ok(())
3463}
3464
3465/// Delete a channel message
3466async fn remove_channel_message(
3467    request: proto::RemoveChannelMessage,
3468    response: Response<proto::RemoveChannelMessage>,
3469    session: Session,
3470) -> Result<()> {
3471    let channel_id = ChannelId::from_proto(request.channel_id);
3472    let message_id = MessageId::from_proto(request.message_id);
3473    let (connection_ids, existing_notification_ids) = session
3474        .db()
3475        .await
3476        .remove_channel_message(channel_id, message_id, session.user_id())
3477        .await?;
3478
3479    broadcast(
3480        Some(session.connection_id),
3481        connection_ids,
3482        move |connection| {
3483            session.peer.send(connection, request.clone())?;
3484
3485            for notification_id in &existing_notification_ids {
3486                session.peer.send(
3487                    connection,
3488                    proto::DeleteNotification {
3489                        notification_id: (*notification_id).to_proto(),
3490                    },
3491                )?;
3492            }
3493
3494            Ok(())
3495        },
3496    );
3497    response.send(proto::Ack {})?;
3498    Ok(())
3499}
3500
3501async fn update_channel_message(
3502    request: proto::UpdateChannelMessage,
3503    response: Response<proto::UpdateChannelMessage>,
3504    session: Session,
3505) -> Result<()> {
3506    let channel_id = ChannelId::from_proto(request.channel_id);
3507    let message_id = MessageId::from_proto(request.message_id);
3508    let updated_at = OffsetDateTime::now_utc();
3509    let UpdatedChannelMessage {
3510        message_id,
3511        participant_connection_ids,
3512        notifications,
3513        reply_to_message_id,
3514        timestamp,
3515        deleted_mention_notification_ids,
3516        updated_mention_notifications,
3517    } = session
3518        .db()
3519        .await
3520        .update_channel_message(
3521            channel_id,
3522            message_id,
3523            session.user_id(),
3524            request.body.as_str(),
3525            &request.mentions,
3526            updated_at,
3527        )
3528        .await?;
3529
3530    let nonce = request
3531        .nonce
3532        .clone()
3533        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3534
3535    let message = proto::ChannelMessage {
3536        sender_id: session.user_id().to_proto(),
3537        id: message_id.to_proto(),
3538        body: request.body.clone(),
3539        mentions: request.mentions.clone(),
3540        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3541        nonce: Some(nonce),
3542        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3543        edited_at: Some(updated_at.unix_timestamp() as u64),
3544    };
3545
3546    response.send(proto::Ack {})?;
3547
3548    let pool = &*session.connection_pool().await;
3549    broadcast(
3550        Some(session.connection_id),
3551        participant_connection_ids,
3552        |connection| {
3553            session.peer.send(
3554                connection,
3555                proto::ChannelMessageUpdate {
3556                    channel_id: channel_id.to_proto(),
3557                    message: Some(message.clone()),
3558                },
3559            )?;
3560
3561            for notification_id in &deleted_mention_notification_ids {
3562                session.peer.send(
3563                    connection,
3564                    proto::DeleteNotification {
3565                        notification_id: (*notification_id).to_proto(),
3566                    },
3567                )?;
3568            }
3569
3570            for notification in &updated_mention_notifications {
3571                session.peer.send(
3572                    connection,
3573                    proto::UpdateNotification {
3574                        notification: Some(notification.clone()),
3575                    },
3576                )?;
3577            }
3578
3579            Ok(())
3580        },
3581    );
3582
3583    send_notifications(pool, &session.peer, notifications);
3584
3585    Ok(())
3586}
3587
3588/// Mark a channel message as read
3589async fn acknowledge_channel_message(
3590    request: proto::AckChannelMessage,
3591    session: Session,
3592) -> Result<()> {
3593    let channel_id = ChannelId::from_proto(request.channel_id);
3594    let message_id = MessageId::from_proto(request.message_id);
3595    let notifications = session
3596        .db()
3597        .await
3598        .observe_channel_message(channel_id, session.user_id(), message_id)
3599        .await?;
3600    send_notifications(
3601        &*session.connection_pool().await,
3602        &session.peer,
3603        notifications,
3604    );
3605    Ok(())
3606}
3607
3608/// Mark a buffer version as synced
3609async fn acknowledge_buffer_version(
3610    request: proto::AckBufferOperation,
3611    session: Session,
3612) -> Result<()> {
3613    let buffer_id = BufferId::from_proto(request.buffer_id);
3614    session
3615        .db()
3616        .await
3617        .observe_buffer_version(
3618            buffer_id,
3619            session.user_id(),
3620            request.epoch as i32,
3621            &request.version,
3622        )
3623        .await?;
3624    Ok(())
3625}
3626
3627async fn count_language_model_tokens(
3628    request: proto::CountLanguageModelTokens,
3629    response: Response<proto::CountLanguageModelTokens>,
3630    session: Session,
3631    config: &Config,
3632) -> Result<()> {
3633    authorize_access_to_legacy_llm_endpoints(&session).await?;
3634
3635    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3636        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3637        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3638    };
3639
3640    session
3641        .app_state
3642        .rate_limiter
3643        .check(&*rate_limit, session.user_id())
3644        .await?;
3645
3646    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3647        Some(proto::LanguageModelProvider::Google) => {
3648            let api_key = config
3649                .google_ai_api_key
3650                .as_ref()
3651                .context("no Google AI API key configured on the server")?;
3652            google_ai::count_tokens(
3653                session.http_client.as_ref(),
3654                google_ai::API_URL,
3655                api_key,
3656                serde_json::from_str(&request.request)?,
3657            )
3658            .await?
3659        }
3660        _ => return Err(anyhow!("unsupported provider"))?,
3661    };
3662
3663    response.send(proto::CountLanguageModelTokensResponse {
3664        token_count: result.total_tokens as u32,
3665    })?;
3666
3667    Ok(())
3668}
3669
3670struct ZedProCountLanguageModelTokensRateLimit;
3671
3672impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3673    fn capacity(&self) -> usize {
3674        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3675            .ok()
3676            .and_then(|v| v.parse().ok())
3677            .unwrap_or(600) // Picked arbitrarily
3678    }
3679
3680    fn refill_duration(&self) -> chrono::Duration {
3681        chrono::Duration::hours(1)
3682    }
3683
3684    fn db_name(&self) -> &'static str {
3685        "zed-pro:count-language-model-tokens"
3686    }
3687}
3688
3689struct FreeCountLanguageModelTokensRateLimit;
3690
3691impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3692    fn capacity(&self) -> usize {
3693        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3694            .ok()
3695            .and_then(|v| v.parse().ok())
3696            .unwrap_or(600 / 10) // Picked arbitrarily
3697    }
3698
3699    fn refill_duration(&self) -> chrono::Duration {
3700        chrono::Duration::hours(1)
3701    }
3702
3703    fn db_name(&self) -> &'static str {
3704        "free:count-language-model-tokens"
3705    }
3706}
3707
3708struct ZedProComputeEmbeddingsRateLimit;
3709
3710impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3711    fn capacity(&self) -> usize {
3712        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3713            .ok()
3714            .and_then(|v| v.parse().ok())
3715            .unwrap_or(5000) // Picked arbitrarily
3716    }
3717
3718    fn refill_duration(&self) -> chrono::Duration {
3719        chrono::Duration::hours(1)
3720    }
3721
3722    fn db_name(&self) -> &'static str {
3723        "zed-pro:compute-embeddings"
3724    }
3725}
3726
3727struct FreeComputeEmbeddingsRateLimit;
3728
3729impl RateLimit for FreeComputeEmbeddingsRateLimit {
3730    fn capacity(&self) -> usize {
3731        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3732            .ok()
3733            .and_then(|v| v.parse().ok())
3734            .unwrap_or(5000 / 10) // Picked arbitrarily
3735    }
3736
3737    fn refill_duration(&self) -> chrono::Duration {
3738        chrono::Duration::hours(1)
3739    }
3740
3741    fn db_name(&self) -> &'static str {
3742        "free:compute-embeddings"
3743    }
3744}
3745
3746async fn compute_embeddings(
3747    request: proto::ComputeEmbeddings,
3748    response: Response<proto::ComputeEmbeddings>,
3749    session: Session,
3750    api_key: Option<Arc<str>>,
3751) -> Result<()> {
3752    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3753    authorize_access_to_legacy_llm_endpoints(&session).await?;
3754
3755    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3756        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3757        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3758    };
3759
3760    session
3761        .app_state
3762        .rate_limiter
3763        .check(&*rate_limit, session.user_id())
3764        .await?;
3765
3766    let embeddings = match request.model.as_str() {
3767        "openai/text-embedding-3-small" => {
3768            open_ai::embed(
3769                session.http_client.as_ref(),
3770                OPEN_AI_API_URL,
3771                &api_key,
3772                OpenAiEmbeddingModel::TextEmbedding3Small,
3773                request.texts.iter().map(|text| text.as_str()),
3774            )
3775            .await?
3776        }
3777        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3778    };
3779
3780    let embeddings = request
3781        .texts
3782        .iter()
3783        .map(|text| {
3784            let mut hasher = sha2::Sha256::new();
3785            hasher.update(text.as_bytes());
3786            let result = hasher.finalize();
3787            result.to_vec()
3788        })
3789        .zip(
3790            embeddings
3791                .data
3792                .into_iter()
3793                .map(|embedding| embedding.embedding),
3794        )
3795        .collect::<HashMap<_, _>>();
3796
3797    let db = session.db().await;
3798    db.save_embeddings(&request.model, &embeddings)
3799        .await
3800        .context("failed to save embeddings")
3801        .trace_err();
3802
3803    response.send(proto::ComputeEmbeddingsResponse {
3804        embeddings: embeddings
3805            .into_iter()
3806            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3807            .collect(),
3808    })?;
3809    Ok(())
3810}
3811
3812async fn get_cached_embeddings(
3813    request: proto::GetCachedEmbeddings,
3814    response: Response<proto::GetCachedEmbeddings>,
3815    session: Session,
3816) -> Result<()> {
3817    authorize_access_to_legacy_llm_endpoints(&session).await?;
3818
3819    let db = session.db().await;
3820    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3821
3822    response.send(proto::GetCachedEmbeddingsResponse {
3823        embeddings: embeddings
3824            .into_iter()
3825            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3826            .collect(),
3827    })?;
3828    Ok(())
3829}
3830
3831/// This is leftover from before the LLM service.
3832///
3833/// The endpoints protected by this check will be moved there eventually.
3834async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3835    if session.is_staff() {
3836        Ok(())
3837    } else {
3838        Err(anyhow!("permission denied"))?
3839    }
3840}
3841
3842/// Get a Supermaven API key for the user
3843async fn get_supermaven_api_key(
3844    _request: proto::GetSupermavenApiKey,
3845    response: Response<proto::GetSupermavenApiKey>,
3846    session: Session,
3847) -> Result<()> {
3848    let user_id: String = session.user_id().to_string();
3849    if !session.is_staff() {
3850        return Err(anyhow!("supermaven not enabled for this account"))?;
3851    }
3852
3853    let email = session
3854        .email()
3855        .ok_or_else(|| anyhow!("user must have an email"))?;
3856
3857    let supermaven_admin_api = session
3858        .supermaven_client
3859        .as_ref()
3860        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3861
3862    let result = supermaven_admin_api
3863        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3864        .await?;
3865
3866    response.send(proto::GetSupermavenApiKeyResponse {
3867        api_key: result.api_key,
3868    })?;
3869
3870    Ok(())
3871}
3872
3873/// Start receiving chat updates for a channel
3874async fn join_channel_chat(
3875    request: proto::JoinChannelChat,
3876    response: Response<proto::JoinChannelChat>,
3877    session: Session,
3878) -> Result<()> {
3879    let channel_id = ChannelId::from_proto(request.channel_id);
3880
3881    let db = session.db().await;
3882    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3883        .await?;
3884    let messages = db
3885        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3886        .await?;
3887    response.send(proto::JoinChannelChatResponse {
3888        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3889        messages,
3890    })?;
3891    Ok(())
3892}
3893
3894/// Stop receiving chat updates for a channel
3895async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3896    let channel_id = ChannelId::from_proto(request.channel_id);
3897    session
3898        .db()
3899        .await
3900        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3901        .await?;
3902    Ok(())
3903}
3904
3905/// Retrieve the chat history for a channel
3906async fn get_channel_messages(
3907    request: proto::GetChannelMessages,
3908    response: Response<proto::GetChannelMessages>,
3909    session: Session,
3910) -> Result<()> {
3911    let channel_id = ChannelId::from_proto(request.channel_id);
3912    let messages = session
3913        .db()
3914        .await
3915        .get_channel_messages(
3916            channel_id,
3917            session.user_id(),
3918            MESSAGE_COUNT_PER_PAGE,
3919            Some(MessageId::from_proto(request.before_message_id)),
3920        )
3921        .await?;
3922    response.send(proto::GetChannelMessagesResponse {
3923        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3924        messages,
3925    })?;
3926    Ok(())
3927}
3928
3929/// Retrieve specific chat messages
3930async fn get_channel_messages_by_id(
3931    request: proto::GetChannelMessagesById,
3932    response: Response<proto::GetChannelMessagesById>,
3933    session: Session,
3934) -> Result<()> {
3935    let message_ids = request
3936        .message_ids
3937        .iter()
3938        .map(|id| MessageId::from_proto(*id))
3939        .collect::<Vec<_>>();
3940    let messages = session
3941        .db()
3942        .await
3943        .get_channel_messages_by_id(session.user_id(), &message_ids)
3944        .await?;
3945    response.send(proto::GetChannelMessagesResponse {
3946        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3947        messages,
3948    })?;
3949    Ok(())
3950}
3951
3952/// Retrieve the current users notifications
3953async fn get_notifications(
3954    request: proto::GetNotifications,
3955    response: Response<proto::GetNotifications>,
3956    session: Session,
3957) -> Result<()> {
3958    let notifications = session
3959        .db()
3960        .await
3961        .get_notifications(
3962            session.user_id(),
3963            NOTIFICATION_COUNT_PER_PAGE,
3964            request.before_id.map(db::NotificationId::from_proto),
3965        )
3966        .await?;
3967    response.send(proto::GetNotificationsResponse {
3968        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3969        notifications,
3970    })?;
3971    Ok(())
3972}
3973
3974/// Mark notifications as read
3975async fn mark_notification_as_read(
3976    request: proto::MarkNotificationRead,
3977    response: Response<proto::MarkNotificationRead>,
3978    session: Session,
3979) -> Result<()> {
3980    let database = &session.db().await;
3981    let notifications = database
3982        .mark_notification_as_read_by_id(
3983            session.user_id(),
3984            NotificationId::from_proto(request.notification_id),
3985        )
3986        .await?;
3987    send_notifications(
3988        &*session.connection_pool().await,
3989        &session.peer,
3990        notifications,
3991    );
3992    response.send(proto::Ack {})?;
3993    Ok(())
3994}
3995
3996/// Get the current users information
3997async fn get_private_user_info(
3998    _request: proto::GetPrivateUserInfo,
3999    response: Response<proto::GetPrivateUserInfo>,
4000    session: Session,
4001) -> Result<()> {
4002    let db = session.db().await;
4003
4004    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4005    let user = db
4006        .get_user_by_id(session.user_id())
4007        .await?
4008        .ok_or_else(|| anyhow!("user not found"))?;
4009    let flags = db.get_user_flags(session.user_id()).await?;
4010
4011    response.send(proto::GetPrivateUserInfoResponse {
4012        metrics_id,
4013        staff: user.admin,
4014        flags,
4015        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4016    })?;
4017    Ok(())
4018}
4019
4020/// Accept the terms of service (tos) on behalf of the current user
4021async fn accept_terms_of_service(
4022    _request: proto::AcceptTermsOfService,
4023    response: Response<proto::AcceptTermsOfService>,
4024    session: Session,
4025) -> Result<()> {
4026    let db = session.db().await;
4027
4028    let accepted_tos_at = Utc::now();
4029    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4030        .await?;
4031
4032    response.send(proto::AcceptTermsOfServiceResponse {
4033        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4034    })?;
4035    Ok(())
4036}
4037
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days, expressed as a `chrono::Duration`. It is `pub`, presumably so
/// that other modules can apply the same threshold (it is not referenced by the
/// handlers in this part of the file).
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4040
4041async fn get_llm_api_token(
4042    _request: proto::GetLlmToken,
4043    response: Response<proto::GetLlmToken>,
4044    session: Session,
4045) -> Result<()> {
4046    let db = session.db().await;
4047
4048    let flags = db.get_user_flags(session.user_id()).await?;
4049    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4050
4051    if !session.is_staff() && !has_language_models_feature_flag {
4052        Err(anyhow!("permission denied"))?
4053    }
4054
4055    let user_id = session.user_id();
4056    let user = db
4057        .get_user_by_id(user_id)
4058        .await?
4059        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4060
4061    if user.accepted_tos_at.is_none() {
4062        Err(anyhow!("terms of service not accepted"))?
4063    }
4064
4065    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4066    let billing_preferences = db.get_billing_preferences(user.id).await?;
4067
4068    let token = LlmTokenClaims::create(
4069        &user,
4070        session.is_staff(),
4071        billing_preferences,
4072        &flags,
4073        has_llm_subscription,
4074        session.current_plan(&db).await?,
4075        session.system_id.clone(),
4076        &session.app_state.config,
4077    )?;
4078    response.send(proto::GetLlmTokenResponse { token })?;
4079    Ok(())
4080}
4081
4082fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4083    let message = match message {
4084        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4085        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4086        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4087        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4088        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4089            code: frame.code.into(),
4090            reason: frame.reason,
4091        })),
4092        // We should never receive a frame while reading the message, according
4093        // to the `tungstenite` maintainers:
4094        //
4095        // > It cannot occur when you read messages from the WebSocket, but it
4096        // > can be used when you want to send the raw frames (e.g. you want to
4097        // > send the frames to the WebSocket without composing the full message first).
4098        // >
4099        // > — https://github.com/snapview/tungstenite-rs/issues/268
4100        TungsteniteMessage::Frame(_) => {
4101            bail!("received an unexpected frame while reading the message")
4102        }
4103    };
4104
4105    Ok(message)
4106}
4107
4108fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4109    match message {
4110        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4111        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4112        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4113        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4114        AxumMessage::Close(frame) => {
4115            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4116                code: frame.code.into(),
4117                reason: frame.reason,
4118            }))
4119        }
4120    }
4121}
4122
4123fn notify_membership_updated(
4124    connection_pool: &mut ConnectionPool,
4125    result: MembershipUpdated,
4126    user_id: UserId,
4127    peer: &Peer,
4128) {
4129    for membership in &result.new_channels.channel_memberships {
4130        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4131    }
4132    for channel_id in &result.removed_channels {
4133        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4134    }
4135
4136    let user_channels_update = proto::UpdateUserChannels {
4137        channel_memberships: result
4138            .new_channels
4139            .channel_memberships
4140            .iter()
4141            .map(|cm| proto::ChannelMembership {
4142                channel_id: cm.channel_id.to_proto(),
4143                role: cm.role.into(),
4144            })
4145            .collect(),
4146        ..Default::default()
4147    };
4148
4149    let mut update = build_channels_update(result.new_channels);
4150    update.delete_channels = result
4151        .removed_channels
4152        .into_iter()
4153        .map(|id| id.to_proto())
4154        .collect();
4155    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4156
4157    for connection_id in connection_pool.user_connection_ids(user_id) {
4158        peer.send(connection_id, user_channels_update.clone())
4159            .trace_err();
4160        peer.send(connection_id, update.clone()).trace_err();
4161    }
4162}
4163
4164fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4165    proto::UpdateUserChannels {
4166        channel_memberships: channels
4167            .channel_memberships
4168            .iter()
4169            .map(|m| proto::ChannelMembership {
4170                channel_id: m.channel_id.to_proto(),
4171                role: m.role.into(),
4172            })
4173            .collect(),
4174        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4175        observed_channel_message_id: channels.observed_channel_messages.clone(),
4176    }
4177}
4178
4179fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4180    let mut update = proto::UpdateChannels::default();
4181
4182    for channel in channels.channels {
4183        update.channels.push(channel.to_proto());
4184    }
4185
4186    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4187    update.latest_channel_message_ids = channels.latest_channel_messages;
4188
4189    for (channel_id, participants) in channels.channel_participants {
4190        update
4191            .channel_participants
4192            .push(proto::ChannelParticipants {
4193                channel_id: channel_id.to_proto(),
4194                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4195            });
4196    }
4197
4198    for channel in channels.invited_channels {
4199        update.channel_invitations.push(channel.to_proto());
4200    }
4201
4202    update
4203}
4204
4205fn build_initial_contacts_update(
4206    contacts: Vec<db::Contact>,
4207    pool: &ConnectionPool,
4208) -> proto::UpdateContacts {
4209    let mut update = proto::UpdateContacts::default();
4210
4211    for contact in contacts {
4212        match contact {
4213            db::Contact::Accepted { user_id, busy } => {
4214                update.contacts.push(contact_for_user(user_id, busy, pool));
4215            }
4216            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4217            db::Contact::Incoming { user_id } => {
4218                update
4219                    .incoming_requests
4220                    .push(proto::IncomingContactRequest {
4221                        requester_id: user_id.to_proto(),
4222                    })
4223            }
4224        }
4225    }
4226
4227    update
4228}
4229
4230fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4231    proto::Contact {
4232        user_id: user_id.to_proto(),
4233        online: pool.is_user_online(user_id),
4234        busy,
4235    }
4236}
4237
4238fn room_updated(room: &proto::Room, peer: &Peer) {
4239    broadcast(
4240        None,
4241        room.participants
4242            .iter()
4243            .filter_map(|participant| Some(participant.peer_id?.into())),
4244        |peer_id| {
4245            peer.send(
4246                peer_id,
4247                proto::RoomUpdated {
4248                    room: Some(room.clone()),
4249                },
4250            )
4251        },
4252    );
4253}
4254
4255fn channel_updated(
4256    channel: &db::channel::Model,
4257    room: &proto::Room,
4258    peer: &Peer,
4259    pool: &ConnectionPool,
4260) {
4261    let participants = room
4262        .participants
4263        .iter()
4264        .map(|p| p.user_id)
4265        .collect::<Vec<_>>();
4266
4267    broadcast(
4268        None,
4269        pool.channel_connection_ids(channel.root_id())
4270            .filter_map(|(channel_id, role)| {
4271                role.can_see_channel(channel.visibility)
4272                    .then_some(channel_id)
4273            }),
4274        |peer_id| {
4275            peer.send(
4276                peer_id,
4277                proto::UpdateChannels {
4278                    channel_participants: vec![proto::ChannelParticipants {
4279                        channel_id: channel.id.to_proto(),
4280                        participant_user_ids: participants.clone(),
4281                    }],
4282                    ..Default::default()
4283                },
4284            )
4285        },
4286    );
4287}
4288
4289async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4290    let db = session.db().await;
4291
4292    let contacts = db.get_contacts(user_id).await?;
4293    let busy = db.is_user_busy(user_id).await?;
4294
4295    let pool = session.connection_pool().await;
4296    let updated_contact = contact_for_user(user_id, busy, &pool);
4297    for contact in contacts {
4298        if let db::Contact::Accepted {
4299            user_id: contact_user_id,
4300            ..
4301        } = contact
4302        {
4303            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4304                session
4305                    .peer
4306                    .send(
4307                        contact_conn_id,
4308                        proto::UpdateContacts {
4309                            contacts: vec![updated_contact.clone()],
4310                            remove_contacts: Default::default(),
4311                            incoming_requests: Default::default(),
4312                            remove_incoming_requests: Default::default(),
4313                            outgoing_requests: Default::default(),
4314                            remove_outgoing_requests: Default::default(),
4315                        },
4316                    )
4317                    .trace_err();
4318            }
4319        }
4320    }
4321    Ok(())
4322}
4323
/// Removes `connection_id` from the room it is in and performs the follow-up
/// notifications:
/// - notifies collaborators of every project left as a result (`project_left`),
/// - broadcasts the updated room to its participants and, when the room
///   belongs to a channel, the updated participant list to that channel,
/// - sends `CallCanceled` to every user whose pending call was canceled,
/// - refreshes contact status for this session's user and those callees,
/// - removes the user from the LiveKit room, and deletes that room when the
///   database reported it as deleted.
///
/// Does nothing and returns `Ok(())` when `leave_room` returns `None`
/// (presumably: the connection was not in a room).
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. An error
/// from `update_user_contacts` returns early, skipping the LiveKit cleanup.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // A set, so a user is refreshed at most once even if listed twice.
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values
    // outlive it; the `else` branch returns, so they are always initialized.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): under pre-2024-edition temporary scoping, the value
    // returned by `session.db().await` lives until the end of this `if let`,
    // so whatever it guards stays held while `project_left` and
    // `room_updated` run. Keep that in mind before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below holds
        // the default value in `livekit_room` rather than the real name.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection-pool temporary is dropped at the end of this
        // statement, before the block below acquires it again.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is released before `update_user_contacts`, which calls
    // `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Note: this `connection_id` shadows the function parameter.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // Cloned because the name is needed again for `delete_room`.
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4397
4398async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4399    let left_channel_buffers = session
4400        .db()
4401        .await
4402        .leave_channel_buffers(session.connection_id)
4403        .await?;
4404
4405    for left_buffer in left_channel_buffers {
4406        channel_buffer_updated(
4407            session.connection_id,
4408            left_buffer.connections,
4409            &proto::UpdateChannelBufferCollaborators {
4410                channel_id: left_buffer.channel_id.to_proto(),
4411                collaborators: left_buffer.collaborators,
4412            },
4413            &session.peer,
4414        );
4415    }
4416
4417    Ok(())
4418}
4419
4420fn project_left(project: &db::LeftProject, session: &Session) {
4421    for connection_id in &project.connection_ids {
4422        if project.should_unshare {
4423            session
4424                .peer
4425                .send(
4426                    *connection_id,
4427                    proto::UnshareProject {
4428                        project_id: project.id.to_proto(),
4429                    },
4430                )
4431                .trace_err();
4432        } else {
4433            session
4434                .peer
4435                .send(
4436                    *connection_id,
4437                    proto::RemoveProjectCollaborator {
4438                        project_id: project.id.to_proto(),
4439                        peer_id: Some(session.connection_id.into()),
4440                    },
4441                )
4442                .trace_err();
4443        }
4444    }
4445}
4446
4447pub trait ResultExt {
4448    type Ok;
4449
4450    fn trace_err(self) -> Option<Self::Ok>;
4451}
4452
4453impl<T, E> ResultExt for Result<T, E>
4454where
4455    E: std::fmt::Debug,
4456{
4457    type Ok = T;
4458
4459    #[track_caller]
4460    fn trace_err(self) -> Option<T> {
4461        match self {
4462            Ok(value) => Some(value),
4463            Err(error) => {
4464                tracing::error!("{:?}", error);
4465                None
4466            }
4467        }
4468    }
4469}