rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    http_client: Arc<dyn HttpClient>,
 137    /// The GeoIP country code for the user.
 138    #[allow(unused)]
 139    geoip_country_code: Option<String>,
 140    system_id: Option<String>,
 141    _executor: Executor,
 142}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
 236pub struct Server {
 237    id: parking_lot::Mutex<ServerId>,
 238    peer: Arc<Peer>,
 239    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 240    app_state: Arc<AppState>,
 241    handlers: HashMap<TypeId, MessageHandler>,
 242    teardown: watch::Sender<bool>,
 243}
 244
 245pub(crate) struct ConnectionPoolGuard<'a> {
 246    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 247    _not_send: PhantomData<Rc<()>>,
 248}
 249
 250#[derive(Serialize)]
 251pub struct ServerSnapshot<'a> {
 252    peer: &'a Peer,
 253    #[serde(serialize_with = "serialize_deref")]
 254    connection_pool: ConnectionPoolGuard<'a>,
 255}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 333            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 341            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 342            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 344            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 345            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 346            .add_message_handler(create_buffer_for_peer)
 347            .add_request_handler(update_buffer)
 348            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 353            .add_request_handler(get_users)
 354            .add_request_handler(fuzzy_search_users)
 355            .add_request_handler(request_contact)
 356            .add_request_handler(remove_contact)
 357            .add_request_handler(respond_to_contact_request)
 358            .add_message_handler(subscribe_to_channels)
 359            .add_request_handler(create_channel)
 360            .add_request_handler(delete_channel)
 361            .add_request_handler(invite_channel_member)
 362            .add_request_handler(remove_channel_member)
 363            .add_request_handler(set_channel_member_role)
 364            .add_request_handler(set_channel_visibility)
 365            .add_request_handler(rename_channel)
 366            .add_request_handler(join_channel_buffer)
 367            .add_request_handler(leave_channel_buffer)
 368            .add_message_handler(update_channel_buffer)
 369            .add_request_handler(rejoin_channel_buffers)
 370            .add_request_handler(get_channel_members)
 371            .add_request_handler(respond_to_channel_invite)
 372            .add_request_handler(join_channel)
 373            .add_request_handler(join_channel_chat)
 374            .add_message_handler(leave_channel_chat)
 375            .add_request_handler(send_channel_message)
 376            .add_request_handler(remove_channel_message)
 377            .add_request_handler(update_channel_message)
 378            .add_request_handler(get_channel_messages)
 379            .add_request_handler(get_channel_messages_by_id)
 380            .add_request_handler(get_notifications)
 381            .add_request_handler(mark_notification_as_read)
 382            .add_request_handler(move_channel)
 383            .add_request_handler(follow)
 384            .add_message_handler(unfollow)
 385            .add_message_handler(update_followers)
 386            .add_request_handler(get_private_user_info)
 387            .add_request_handler(get_llm_api_token)
 388            .add_request_handler(accept_terms_of_service)
 389            .add_message_handler(acknowledge_channel_message)
 390            .add_message_handler(acknowledge_buffer_version)
 391            .add_request_handler(get_supermaven_api_key)
 392            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 395            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 398            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 399            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 400            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 401            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 402            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 404            .add_message_handler(update_context)
 405            .add_request_handler({
 406                let app_state = app_state.clone();
 407                move |request, response, session| {
 408                    let app_state = app_state.clone();
 409                    async move {
 410                        count_language_model_tokens(request, response, session, &app_state.config)
 411                            .await
 412                    }
 413                }
 414            })
 415            .add_request_handler(get_cached_embeddings)
 416            .add_request_handler({
 417                let app_state = app_state.clone();
 418                move |request, response, session| {
 419                    compute_embeddings(
 420                        request,
 421                        response,
 422                        session,
 423                        app_state.config.openai_api_key.clone(),
 424                    )
 425                }
 426            });
 427
 428        Arc::new(server)
 429    }
 430
 431    pub async fn start(&self) -> Result<()> {
 432        let server_id = *self.id.lock();
 433        let app_state = self.app_state.clone();
 434        let peer = self.peer.clone();
 435        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 436        let pool = self.connection_pool.clone();
 437        let livekit_client = self.app_state.livekit_client.clone();
 438
 439        let span = info_span!("start server");
 440        self.app_state.executor.spawn_detached(
 441            async move {
 442                tracing::info!("waiting for cleanup timeout");
 443                timeout.await;
 444                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 445                if let Some((room_ids, channel_ids)) = app_state
 446                    .db
 447                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 448                    .await
 449                    .trace_err()
 450                {
 451                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 452                    tracing::info!(
 453                        stale_channel_buffer_count = channel_ids.len(),
 454                        "retrieved stale channel buffers"
 455                    );
 456
 457                    for channel_id in channel_ids {
 458                        if let Some(refreshed_channel_buffer) = app_state
 459                            .db
 460                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 461                            .await
 462                            .trace_err()
 463                        {
 464                            for connection_id in refreshed_channel_buffer.connection_ids {
 465                                peer.send(
 466                                    connection_id,
 467                                    proto::UpdateChannelBufferCollaborators {
 468                                        channel_id: channel_id.to_proto(),
 469                                        collaborators: refreshed_channel_buffer
 470                                            .collaborators
 471                                            .clone(),
 472                                    },
 473                                )
 474                                .trace_err();
 475                            }
 476                        }
 477                    }
 478
 479                    for room_id in room_ids {
 480                        let mut contacts_to_update = HashSet::default();
 481                        let mut canceled_calls_to_user_ids = Vec::new();
 482                        let mut livekit_room = String::new();
 483                        let mut delete_livekit_room = false;
 484
 485                        if let Some(mut refreshed_room) = app_state
 486                            .db
 487                            .clear_stale_room_participants(room_id, server_id)
 488                            .await
 489                            .trace_err()
 490                        {
 491                            tracing::info!(
 492                                room_id = room_id.0,
 493                                new_participant_count = refreshed_room.room.participants.len(),
 494                                "refreshed room"
 495                            );
 496                            room_updated(&refreshed_room.room, &peer);
 497                            if let Some(channel) = refreshed_room.channel.as_ref() {
 498                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 499                            }
 500                            contacts_to_update
 501                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 502                            contacts_to_update
 503                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 504                            canceled_calls_to_user_ids =
 505                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 506                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 507                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 508                        }
 509
 510                        {
 511                            let pool = pool.lock();
 512                            for canceled_user_id in canceled_calls_to_user_ids {
 513                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 514                                    peer.send(
 515                                        connection_id,
 516                                        proto::CallCanceled {
 517                                            room_id: room_id.to_proto(),
 518                                        },
 519                                    )
 520                                    .trace_err();
 521                                }
 522                            }
 523                        }
 524
 525                        for user_id in contacts_to_update {
 526                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 527                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 528                            if let Some((busy, contacts)) = busy.zip(contacts) {
 529                                let pool = pool.lock();
 530                                let updated_contact = contact_for_user(user_id, busy, &pool);
 531                                for contact in contacts {
 532                                    if let db::Contact::Accepted {
 533                                        user_id: contact_user_id,
 534                                        ..
 535                                    } = contact
 536                                    {
 537                                        for contact_conn_id in
 538                                            pool.user_connection_ids(contact_user_id)
 539                                        {
 540                                            peer.send(
 541                                                contact_conn_id,
 542                                                proto::UpdateContacts {
 543                                                    contacts: vec![updated_contact.clone()],
 544                                                    remove_contacts: Default::default(),
 545                                                    incoming_requests: Default::default(),
 546                                                    remove_incoming_requests: Default::default(),
 547                                                    outgoing_requests: Default::default(),
 548                                                    remove_outgoing_requests: Default::default(),
 549                                                },
 550                                            )
 551                                            .trace_err();
 552                                        }
 553                                    }
 554                                }
 555                            }
 556                        }
 557
 558                        if let Some(live_kit) = livekit_client.as_ref() {
 559                            if delete_livekit_room {
 560                                live_kit.delete_room(livekit_room).await.trace_err();
 561                            }
 562                        }
 563                    }
 564                }
 565
 566                app_state
 567                    .db
 568                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 569                    .await
 570                    .trace_err();
 571            }
 572            .instrument(span),
 573        );
 574        Ok(())
 575    }
 576
 577    pub fn teardown(&self) {
 578        self.peer.teardown();
 579        self.connection_pool.lock().reset();
 580        let _ = self.teardown.send(true);
 581    }
 582
 583    #[cfg(test)]
 584    pub fn reset(&self, id: ServerId) {
 585        self.teardown();
 586        *self.id.lock() = id;
 587        self.peer.reset(id.0 as u32);
 588        let _ = self.teardown.send(false);
 589    }
 590
 591    #[cfg(test)]
 592    pub fn id(&self) -> ServerId {
 593        *self.id.lock()
 594    }
 595
 596    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 597    where
 598        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 599        Fut: 'static + Send + Future<Output = Result<()>>,
 600        M: EnvelopedMessage,
 601    {
 602        let prev_handler = self.handlers.insert(
 603            TypeId::of::<M>(),
 604            Box::new(move |envelope, session| {
 605                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 606                let received_at = envelope.received_at;
 607                tracing::info!("message received");
 608                let start_time = Instant::now();
 609                let future = (handler)(*envelope, session);
 610                async move {
 611                    let result = future.await;
 612                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 613                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 614                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 615                    let payload_type = M::NAME;
 616
 617                    match result {
 618                        Err(error) => {
 619                            tracing::error!(
 620                                ?error,
 621                                total_duration_ms,
 622                                processing_duration_ms,
 623                                queue_duration_ms,
 624                                payload_type,
 625                                "error handling message"
 626                            )
 627                        }
 628                        Ok(()) => tracing::info!(
 629                            total_duration_ms,
 630                            processing_duration_ms,
 631                            queue_duration_ms,
 632                            "finished handling message"
 633                        ),
 634                    }
 635                }
 636                .boxed()
 637            }),
 638        );
 639        if prev_handler.is_some() {
 640            panic!("registered a handler for the same message twice");
 641        }
 642        self
 643    }
 644
 645    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 646    where
 647        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 648        Fut: 'static + Send + Future<Output = Result<()>>,
 649        M: EnvelopedMessage,
 650    {
 651        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 652        self
 653    }
 654
 655    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 656    where
 657        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 658        Fut: Send + Future<Output = Result<()>>,
 659        M: RequestMessage,
 660    {
 661        let handler = Arc::new(handler);
 662        self.add_handler(move |envelope, session| {
 663            let receipt = envelope.receipt();
 664            let handler = handler.clone();
 665            async move {
 666                let peer = session.peer.clone();
 667                let responded = Arc::new(AtomicBool::default());
 668                let response = Response {
 669                    peer: peer.clone(),
 670                    responded: responded.clone(),
 671                    receipt,
 672                };
 673                match (handler)(envelope.payload, response, session).await {
 674                    Ok(()) => {
 675                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 676                            Ok(())
 677                        } else {
 678                            Err(anyhow!("handler did not send a response"))?
 679                        }
 680                    }
 681                    Err(error) => {
 682                        let proto_err = match &error {
 683                            Error::Internal(err) => err.to_proto(),
 684                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 685                        };
 686                        peer.respond_with_error(receipt, proto_err)?;
 687                        Err(error)
 688                    }
 689                }
 690            }
 691        })
 692    }
 693
    /// Builds the future that drives one client connection from registration to
    /// sign-out.
    ///
    /// When polled, the returned future:
    /// 1. returns immediately (logging an error) if the server's teardown flag is
    ///    already set;
    /// 2. registers `connection` with the peer and records the assigned connection
    ///    id on the span;
    /// 3. builds the per-connection [`Session`] (HTTP client, optional Supermaven
    ///    client, database handle, …);
    /// 4. runs `send_initial_client_update`, which also hands the connection id to
    ///    `send_connection_id` when provided;
    /// 5. dispatches incoming messages to registered handlers until the connection
    ///    closes, the I/O future finishes, or teardown is signaled;
    /// 6. after a close or I/O completion (but not after teardown), cancels any
    ///    in-flight foreground handlers and calls [`connection_lost`].
    ///
    /// Everything runs inside a `handle connection` span carrying the address,
    /// principal fields and, when known, the GeoIP country code.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve the connection once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure gives the peer a way to sleep on this executor
            // (presumably for keepalive/timeout handling — confirm in `Peer`).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only constructed when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): this early return (like the HTTP-client one above) skips
            // `connection_lost`, so a connection-pool entry added by
            // `send_initial_client_update` before it failed is not removed here —
            // confirm whether that is intended.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256: a permit is
            // taken before reading each message and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // `acquire_owned` only fails on a closed semaphore; this local
                    // semaphore is never closed.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written (biased select).
                futures::select_biased! {
                    // Teardown: stop without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    // The connection's I/O finished: leave the loop and sign out.
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive in-flight foreground handlers forward.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // The handler future is created inside the span; it is then
                            // instrumented with the same span for the rest of its run.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 842
 843    async fn send_initial_client_update(
 844        &self,
 845        connection_id: ConnectionId,
 846        principal: &Principal,
 847        zed_version: ZedVersion,
 848        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 849        session: &Session,
 850    ) -> Result<()> {
 851        self.peer.send(
 852            connection_id,
 853            proto::Hello {
 854                peer_id: Some(connection_id.into()),
 855            },
 856        )?;
 857        tracing::info!("sent hello message");
 858        if let Some(send_connection_id) = send_connection_id.take() {
 859            let _ = send_connection_id.send(connection_id);
 860        }
 861
 862        match principal {
 863            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 864                if !user.connected_once {
 865                    self.peer.send(connection_id, proto::ShowContacts {})?;
 866                    self.app_state
 867                        .db
 868                        .set_user_connected_once(user.id, true)
 869                        .await?;
 870                }
 871
 872                update_user_plan(user.id, session).await?;
 873
 874                let contacts = self.app_state.db.get_contacts(user.id).await?;
 875
 876                {
 877                    let mut pool = self.connection_pool.lock();
 878                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 879                    self.peer.send(
 880                        connection_id,
 881                        build_initial_contacts_update(contacts, &pool),
 882                    )?;
 883                }
 884
 885                if should_auto_subscribe_to_channels(zed_version) {
 886                    subscribe_user_to_channels(user.id, session).await?;
 887                }
 888
 889                if let Some(incoming_call) =
 890                    self.app_state.db.incoming_call_for_user(user.id).await?
 891                {
 892                    self.peer.send(connection_id, incoming_call)?;
 893                }
 894
 895                update_user_contacts(user.id, session).await?;
 896            }
 897        }
 898
 899        Ok(())
 900    }
 901
 902    pub async fn invite_code_redeemed(
 903        self: &Arc<Self>,
 904        inviter_id: UserId,
 905        invitee_id: UserId,
 906    ) -> Result<()> {
 907        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 908            if let Some(code) = &user.invite_code {
 909                let pool = self.connection_pool.lock();
 910                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 911                for connection_id in pool.user_connection_ids(inviter_id) {
 912                    self.peer.send(
 913                        connection_id,
 914                        proto::UpdateContacts {
 915                            contacts: vec![invitee_contact.clone()],
 916                            ..Default::default()
 917                        },
 918                    )?;
 919                    self.peer.send(
 920                        connection_id,
 921                        proto::UpdateInviteInfo {
 922                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 923                            count: user.invite_count as u32,
 924                        },
 925                    )?;
 926                }
 927            }
 928        }
 929        Ok(())
 930    }
 931
 932    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 933        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 934            if let Some(invite_code) = &user.invite_code {
 935                let pool = self.connection_pool.lock();
 936                for connection_id in pool.user_connection_ids(user_id) {
 937                    self.peer.send(
 938                        connection_id,
 939                        proto::UpdateInviteInfo {
 940                            url: format!(
 941                                "{}{}",
 942                                self.app_state.config.invite_link_prefix, invite_code
 943                            ),
 944                            count: user.invite_count as u32,
 945                        },
 946                    )?;
 947                }
 948            }
 949        }
 950        Ok(())
 951    }
 952
 953    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 954        let pool = self.connection_pool.lock();
 955        for connection_id in pool.user_connection_ids(user_id) {
 956            self.peer
 957                .send(connection_id, proto::RefreshLlmToken {})
 958                .trace_err();
 959        }
 960    }
 961
 962    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 963        ServerSnapshot {
 964            connection_pool: ConnectionPoolGuard {
 965                guard: self.connection_pool.lock(),
 966                _not_send: PhantomData,
 967            },
 968            peer: &self.peer,
 969        }
 970    }
 971}
 972
 973impl<'a> Deref for ConnectionPoolGuard<'a> {
 974    type Target = ConnectionPool;
 975
 976    fn deref(&self) -> &Self::Target {
 977        &self.guard
 978    }
 979}
 980
 981impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 982    fn deref_mut(&mut self) -> &mut Self::Target {
 983        &mut self.guard
 984    }
 985}
 986
/// Runs `check_invariants` each time a guard is dropped, in test builds only;
/// outside `cfg(test)` this drop does nothing beyond releasing the fields.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled out entirely in non-test builds (statement-level `cfg`).
        #[cfg(test)]
        self.check_invariants();
    }
}
 993
 994fn broadcast<F>(
 995    sender_id: Option<ConnectionId>,
 996    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 997    mut f: F,
 998) where
 999    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1000{
1001    for receiver_id in receiver_ids {
1002        if Some(receiver_id) != sender_id {
1003            if let Err(error) = f(receiver_id) {
1004                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1005            }
1006        }
1007    }
1008}
1009
1010pub struct ProtocolVersion(u32);
1011
1012impl Header for ProtocolVersion {
1013    fn name() -> &'static HeaderName {
1014        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1015        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1016    }
1017
1018    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1019    where
1020        Self: Sized,
1021        I: Iterator<Item = &'i axum::http::HeaderValue>,
1022    {
1023        let version = values
1024            .next()
1025            .ok_or_else(axum::headers::Error::invalid)?
1026            .to_str()
1027            .map_err(|_| axum::headers::Error::invalid())?
1028            .parse()
1029            .map_err(|_| axum::headers::Error::invalid())?;
1030        Ok(Self(version))
1031    }
1032
1033    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1034        values.extend([self.0.to_string().parse().unwrap()]);
1035    }
1036}
1037
1038pub struct AppVersionHeader(SemanticVersion);
1039impl Header for AppVersionHeader {
1040    fn name() -> &'static HeaderName {
1041        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1042        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1043    }
1044
1045    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1046    where
1047        Self: Sized,
1048        I: Iterator<Item = &'i axum::http::HeaderValue>,
1049    {
1050        let version = values
1051            .next()
1052            .ok_or_else(axum::headers::Error::invalid)?
1053            .to_str()
1054            .map_err(|_| axum::headers::Error::invalid())?
1055            .parse()
1056            .map_err(|_| axum::headers::Error::invalid())?;
1057        Ok(Self(version))
1058    }
1059
1060    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1061        values.extend([self.0.to_string().parse().unwrap()]);
1062    }
1063}
1064
1065pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1066    Router::new()
1067        .route("/rpc", get(handle_websocket_request))
1068        .layer(
1069            ServiceBuilder::new()
1070                .layer(Extension(server.app_state.clone()))
1071                .layer(middleware::from_fn(auth::validate_header)),
1072        )
1073        .route("/metrics", get(handle_metrics))
1074        .layer(Extension(server))
1075}
1076
1077pub async fn handle_websocket_request(
1078    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1079    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1080    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1081    Extension(server): Extension<Arc<Server>>,
1082    Extension(principal): Extension<Principal>,
1083    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1084    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1085    ws: WebSocketUpgrade,
1086) -> axum::response::Response {
1087    if protocol_version != rpc::PROTOCOL_VERSION {
1088        return (
1089            StatusCode::UPGRADE_REQUIRED,
1090            "client must be upgraded".to_string(),
1091        )
1092            .into_response();
1093    }
1094
1095    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1096        return (
1097            StatusCode::UPGRADE_REQUIRED,
1098            "no version header found".to_string(),
1099        )
1100            .into_response();
1101    };
1102
1103    if !version.can_collaborate() {
1104        return (
1105            StatusCode::UPGRADE_REQUIRED,
1106            "client must be upgraded".to_string(),
1107        )
1108            .into_response();
1109    }
1110
1111    let socket_address = socket_address.to_string();
1112    ws.on_upgrade(move |socket| {
1113        let socket = socket
1114            .map_ok(to_tungstenite_message)
1115            .err_into()
1116            .with(|message| async move { to_axum_message(message) });
1117        let connection = Connection::new(Box::pin(socket));
1118        async move {
1119            server
1120                .handle_connection(
1121                    connection,
1122                    socket_address,
1123                    principal,
1124                    version,
1125                    country_code_header.map(|header| header.to_string()),
1126                    system_id_header.map(|header| header.to_string()),
1127                    None,
1128                    Executor::Production,
1129                )
1130                .await;
1131        }
1132    })
1133}
1134
1135pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1136    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1137    let connections_metric = CONNECTIONS_METRIC
1138        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1139
1140    let connections = server
1141        .connection_pool
1142        .lock()
1143        .connections()
1144        .filter(|connection| !connection.admin)
1145        .count();
1146    connections_metric.set(connections as _);
1147
1148    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1149    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1150        register_int_gauge!(
1151            "shared_projects",
1152            "number of open projects with one or more guests"
1153        )
1154        .unwrap()
1155    });
1156
1157    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1158    shared_projects_metric.set(shared_projects as _);
1159
1160    let encoder = prometheus::TextEncoder::new();
1161    let metric_families = prometheus::gather();
1162    let encoded_metrics = encoder
1163        .encode_to_string(&metric_families)
1164        .map_err(|err| anyhow!("{}", err))?;
1165    Ok(encoded_metrics)
1166}
1167
1168#[instrument(err, skip(executor))]
1169async fn connection_lost(
1170    session: Session,
1171    mut teardown: watch::Receiver<bool>,
1172    executor: Executor,
1173) -> Result<()> {
1174    session.peer.disconnect(session.connection_id);
1175    session
1176        .connection_pool()
1177        .await
1178        .remove_connection(session.connection_id)?;
1179
1180    session
1181        .db()
1182        .await
1183        .connection_lost(session.connection_id)
1184        .await
1185        .trace_err();
1186
1187    futures::select_biased! {
1188        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1189
1190            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1191            leave_room_for_session(&session, session.connection_id).await.trace_err();
1192            leave_channel_buffers_for_session(&session)
1193                .await
1194                .trace_err();
1195
1196            if !session
1197                .connection_pool()
1198                .await
1199                .is_user_online(session.user_id())
1200            {
1201                let db = session.db().await;
1202                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1203                    room_updated(&room, &session.peer);
1204                }
1205            }
1206
1207            update_user_contacts(session.user_id(), &session).await?;
1208        },
1209        _ = teardown.changed().fuse() => {}
1210    }
1211
1212    Ok(())
1213}
1214
1215/// Acknowledges a ping from a client, used to keep the connection alive.
1216async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1217    response.send(proto::Ack {})?;
1218    Ok(())
1219}
1220
1221/// Creates a new room for calling (outside of channels)
1222async fn create_room(
1223    _request: proto::CreateRoom,
1224    response: Response<proto::CreateRoom>,
1225    session: Session,
1226) -> Result<()> {
1227    let livekit_room = nanoid::nanoid!(30);
1228
1229    let live_kit_connection_info = util::maybe!(async {
1230        let live_kit = session.app_state.livekit_client.as_ref();
1231        let live_kit = live_kit?;
1232        let user_id = session.user_id().to_string();
1233
1234        let token = live_kit
1235            .room_token(&livekit_room, &user_id.to_string())
1236            .trace_err()?;
1237
1238        Some(proto::LiveKitConnectionInfo {
1239            server_url: live_kit.url().into(),
1240            token,
1241            can_publish: true,
1242        })
1243    })
1244    .await;
1245
1246    let room = session
1247        .db()
1248        .await
1249        .create_room(session.user_id(), session.connection_id, &livekit_room)
1250        .await?;
1251
1252    response.send(proto::CreateRoomResponse {
1253        room: Some(room.clone()),
1254        live_kit_connection_info,
1255    })?;
1256
1257    update_user_contacts(session.user_id(), &session).await?;
1258    Ok(())
1259}
1260
1261/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1262async fn join_room(
1263    request: proto::JoinRoom,
1264    response: Response<proto::JoinRoom>,
1265    session: Session,
1266) -> Result<()> {
1267    let room_id = RoomId::from_proto(request.id);
1268
1269    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1270
1271    if let Some(channel_id) = channel_id {
1272        return join_channel_internal(channel_id, Box::new(response), session).await;
1273    }
1274
1275    let joined_room = {
1276        let room = session
1277            .db()
1278            .await
1279            .join_room(room_id, session.user_id(), session.connection_id)
1280            .await?;
1281        room_updated(&room.room, &session.peer);
1282        room.into_inner()
1283    };
1284
1285    for connection_id in session
1286        .connection_pool()
1287        .await
1288        .user_connection_ids(session.user_id())
1289    {
1290        session
1291            .peer
1292            .send(
1293                connection_id,
1294                proto::CallCanceled {
1295                    room_id: room_id.to_proto(),
1296                },
1297            )
1298            .trace_err();
1299    }
1300
1301    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1302    {
1303        live_kit
1304            .room_token(
1305                &joined_room.room.livekit_room,
1306                &session.user_id().to_string(),
1307            )
1308            .trace_err()
1309            .map(|token| proto::LiveKitConnectionInfo {
1310                server_url: live_kit.url().into(),
1311                token,
1312                can_publish: true,
1313            })
1314    } else {
1315        None
1316    };
1317
1318    response.send(proto::JoinRoomResponse {
1319        room: Some(joined_room.room),
1320        channel_id: None,
1321        live_kit_connection_info,
1322    })?;
1323
1324    update_user_contacts(session.user_id(), &session).await?;
1325    Ok(())
1326}
1327
/// Reconnects a client to a room after connection errors.
///
/// While holding the value returned by the database's `rejoin_room`, this
/// handler does the following, in order:
/// 1. Replies with the room plus the projects this connection re-shared and
///    rejoined.
/// 2. Broadcasts the room update.
/// 3. For every re-shared project, tells each collaborator the rejoining
///    connection's new peer id and forwards the project's worktrees to the
///    other collaborators.
/// 4. Streams rejoined project state back to this client via
///    `notify_rejoined_projects`.
///
/// After that scope ends, it broadcasts an update for the room's channel (if
/// any) and refreshes the user's contacts.
///
/// # Errors
/// Propagates database errors, a failed response send, errors from
/// `notify_rejoined_projects`, and contact-update errors. Collaborator
/// notifications are best-effort.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Extracted from the rejoin result at the end of the scope below so the
    // channel broadcast runs only after `into_inner()` has consumed it
    // (presumably releasing whatever the database result guards — confirm).
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at this connection's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the re-shared worktree list to everyone but this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1419
1420fn notify_rejoined_projects(
1421    rejoined_projects: &mut Vec<RejoinedProject>,
1422    session: &Session,
1423) -> Result<()> {
1424    for project in rejoined_projects.iter() {
1425        for collaborator in &project.collaborators {
1426            session
1427                .peer
1428                .send(
1429                    collaborator.connection_id,
1430                    proto::UpdateProjectCollaborator {
1431                        project_id: project.id.to_proto(),
1432                        old_peer_id: Some(project.old_connection_id.into()),
1433                        new_peer_id: Some(session.connection_id.into()),
1434                    },
1435                )
1436                .trace_err();
1437        }
1438    }
1439
1440    for project in rejoined_projects {
1441        for worktree in mem::take(&mut project.worktrees) {
1442            // Stream this worktree's entries.
1443            let message = proto::UpdateWorktree {
1444                project_id: project.id.to_proto(),
1445                worktree_id: worktree.id,
1446                abs_path: worktree.abs_path.clone(),
1447                root_name: worktree.root_name,
1448                updated_entries: worktree.updated_entries,
1449                removed_entries: worktree.removed_entries,
1450                scan_id: worktree.scan_id,
1451                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1452                updated_repositories: worktree.updated_repositories,
1453                removed_repositories: worktree.removed_repositories,
1454            };
1455            for update in proto::split_worktree_update(message) {
1456                session.peer.send(session.connection_id, update.clone())?;
1457            }
1458
1459            // Stream this worktree's diagnostics.
1460            for summary in worktree.diagnostic_summaries {
1461                session.peer.send(
1462                    session.connection_id,
1463                    proto::UpdateDiagnosticSummary {
1464                        project_id: project.id.to_proto(),
1465                        worktree_id: worktree.id,
1466                        summary: Some(summary),
1467                    },
1468                )?;
1469            }
1470
1471            for settings_file in worktree.settings_files {
1472                session.peer.send(
1473                    session.connection_id,
1474                    proto::UpdateWorktreeSettings {
1475                        project_id: project.id.to_proto(),
1476                        worktree_id: worktree.id,
1477                        path: settings_file.path,
1478                        content: Some(settings_file.content),
1479                        kind: Some(settings_file.kind.to_proto().into()),
1480                    },
1481                )?;
1482            }
1483        }
1484
1485        for language_server in &project.language_servers {
1486            session.peer.send(
1487                session.connection_id,
1488                proto::UpdateLanguageServer {
1489                    project_id: project.id.to_proto(),
1490                    language_server_id: language_server.id,
1491                    variant: Some(
1492                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1493                            proto::LspDiskBasedDiagnosticsUpdated {},
1494                        ),
1495                    ),
1496                },
1497            )?;
1498        }
1499    }
1500    Ok(())
1501}
1502
1503/// leave room disconnects from the room.
1504async fn leave_room(
1505    _: proto::LeaveRoom,
1506    response: Response<proto::LeaveRoom>,
1507    session: Session,
1508) -> Result<()> {
1509    leave_room_for_session(&session, session.connection_id).await?;
1510    response.send(proto::Ack {})?;
1511    Ok(())
1512}
1513
1514/// Updates the permissions of someone else in the room.
1515async fn set_room_participant_role(
1516    request: proto::SetRoomParticipantRole,
1517    response: Response<proto::SetRoomParticipantRole>,
1518    session: Session,
1519) -> Result<()> {
1520    let user_id = UserId::from_proto(request.user_id);
1521    let role = ChannelRole::from(request.role());
1522
1523    let (livekit_room, can_publish) = {
1524        let room = session
1525            .db()
1526            .await
1527            .set_room_participant_role(
1528                session.user_id(),
1529                RoomId::from_proto(request.room_id),
1530                user_id,
1531                role,
1532            )
1533            .await?;
1534
1535        let livekit_room = room.livekit_room.clone();
1536        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1537        room_updated(&room, &session.peer);
1538        (livekit_room, can_publish)
1539    };
1540
1541    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1542        live_kit
1543            .update_participant(
1544                livekit_room.clone(),
1545                request.user_id.to_string(),
1546                livekit_api::proto::ParticipantPermission {
1547                    can_subscribe: true,
1548                    can_publish,
1549                    can_publish_data: can_publish,
1550                    hidden: false,
1551                    recorder: false,
1552                },
1553            )
1554            .await
1555            .trace_err();
1556    }
1557
1558    response.send(proto::Ack {})?;
1559    Ok(())
1560}
1561
1562/// Call someone else into the current room
1563async fn call(
1564    request: proto::Call,
1565    response: Response<proto::Call>,
1566    session: Session,
1567) -> Result<()> {
1568    let room_id = RoomId::from_proto(request.room_id);
1569    let calling_user_id = session.user_id();
1570    let calling_connection_id = session.connection_id;
1571    let called_user_id = UserId::from_proto(request.called_user_id);
1572    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1573    if !session
1574        .db()
1575        .await
1576        .has_contact(calling_user_id, called_user_id)
1577        .await?
1578    {
1579        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1580    }
1581
1582    let incoming_call = {
1583        let (room, incoming_call) = &mut *session
1584            .db()
1585            .await
1586            .call(
1587                room_id,
1588                calling_user_id,
1589                calling_connection_id,
1590                called_user_id,
1591                initial_project_id,
1592            )
1593            .await?;
1594        room_updated(room, &session.peer);
1595        mem::take(incoming_call)
1596    };
1597    update_user_contacts(called_user_id, &session).await?;
1598
1599    let mut calls = session
1600        .connection_pool()
1601        .await
1602        .user_connection_ids(called_user_id)
1603        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1604        .collect::<FuturesUnordered<_>>();
1605
1606    while let Some(call_response) = calls.next().await {
1607        match call_response.as_ref() {
1608            Ok(_) => {
1609                response.send(proto::Ack {})?;
1610                return Ok(());
1611            }
1612            Err(_) => {
1613                call_response.trace_err();
1614            }
1615        }
1616    }
1617
1618    {
1619        let room = session
1620            .db()
1621            .await
1622            .call_failed(room_id, called_user_id)
1623            .await?;
1624        room_updated(&room, &session.peer);
1625    }
1626    update_user_contacts(called_user_id, &session).await?;
1627
1628    Err(anyhow!("failed to ring user"))?
1629}
1630
1631/// Cancel an outgoing call.
1632async fn cancel_call(
1633    request: proto::CancelCall,
1634    response: Response<proto::CancelCall>,
1635    session: Session,
1636) -> Result<()> {
1637    let called_user_id = UserId::from_proto(request.called_user_id);
1638    let room_id = RoomId::from_proto(request.room_id);
1639    {
1640        let room = session
1641            .db()
1642            .await
1643            .cancel_call(room_id, session.connection_id, called_user_id)
1644            .await?;
1645        room_updated(&room, &session.peer);
1646    }
1647
1648    for connection_id in session
1649        .connection_pool()
1650        .await
1651        .user_connection_ids(called_user_id)
1652    {
1653        session
1654            .peer
1655            .send(
1656                connection_id,
1657                proto::CallCanceled {
1658                    room_id: room_id.to_proto(),
1659                },
1660            )
1661            .trace_err();
1662    }
1663    response.send(proto::Ack {})?;
1664
1665    update_user_contacts(called_user_id, &session).await?;
1666    Ok(())
1667}
1668
1669/// Decline an incoming call.
1670async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1671    let room_id = RoomId::from_proto(message.room_id);
1672    {
1673        let room = session
1674            .db()
1675            .await
1676            .decline_call(Some(room_id), session.user_id())
1677            .await?
1678            .ok_or_else(|| anyhow!("failed to decline call"))?;
1679        room_updated(&room, &session.peer);
1680    }
1681
1682    for connection_id in session
1683        .connection_pool()
1684        .await
1685        .user_connection_ids(session.user_id())
1686    {
1687        session
1688            .peer
1689            .send(
1690                connection_id,
1691                proto::CallCanceled {
1692                    room_id: room_id.to_proto(),
1693                },
1694            )
1695            .trace_err();
1696    }
1697    update_user_contacts(session.user_id(), &session).await?;
1698    Ok(())
1699}
1700
1701/// Updates other participants in the room with your current location.
1702async fn update_participant_location(
1703    request: proto::UpdateParticipantLocation,
1704    response: Response<proto::UpdateParticipantLocation>,
1705    session: Session,
1706) -> Result<()> {
1707    let room_id = RoomId::from_proto(request.room_id);
1708    let location = request
1709        .location
1710        .ok_or_else(|| anyhow!("invalid location"))?;
1711
1712    let db = session.db().await;
1713    let room = db
1714        .update_room_participant_location(room_id, session.connection_id, location)
1715        .await?;
1716
1717    room_updated(&room, &session.peer);
1718    response.send(proto::Ack {})?;
1719    Ok(())
1720}
1721
1722/// Share a project into the room.
1723async fn share_project(
1724    request: proto::ShareProject,
1725    response: Response<proto::ShareProject>,
1726    session: Session,
1727) -> Result<()> {
1728    let (project_id, room) = &*session
1729        .db()
1730        .await
1731        .share_project(
1732            RoomId::from_proto(request.room_id),
1733            session.connection_id,
1734            &request.worktrees,
1735            request.is_ssh_project,
1736        )
1737        .await?;
1738    response.send(proto::ShareProjectResponse {
1739        project_id: project_id.to_proto(),
1740    })?;
1741    room_updated(room, &session.peer);
1742
1743    Ok(())
1744}
1745
1746/// Unshare a project from the room.
1747async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1748    let project_id = ProjectId::from_proto(message.project_id);
1749    unshare_project_internal(project_id, session.connection_id, &session).await
1750}
1751
1752async fn unshare_project_internal(
1753    project_id: ProjectId,
1754    connection_id: ConnectionId,
1755    session: &Session,
1756) -> Result<()> {
1757    let delete = {
1758        let room_guard = session
1759            .db()
1760            .await
1761            .unshare_project(project_id, connection_id)
1762            .await?;
1763
1764        let (delete, room, guest_connection_ids) = &*room_guard;
1765
1766        let message = proto::UnshareProject {
1767            project_id: project_id.to_proto(),
1768        };
1769
1770        broadcast(
1771            Some(connection_id),
1772            guest_connection_ids.iter().copied(),
1773            |conn_id| session.peer.send(conn_id, message.clone()),
1774        );
1775        if let Some(room) = room {
1776            room_updated(room, &session.peer);
1777        }
1778
1779        *delete
1780    };
1781
1782    if delete {
1783        let db = session.db().await;
1784        db.delete_project(project_id).await?;
1785    }
1786
1787    Ok(())
1788}
1789
/// Join someone else's shared project.
///
/// Asks the database to add the caller's connection to `request.project_id`.
/// The returned project and replica id are then handed to
/// `join_project_internal`, which replies to the request and streams the
/// project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the guest;
    // `session` is moved into `join_project_internal` below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1808
/// A responder that can deliver a `JoinProjectResponse`.
///
/// `join_project_internal` accepts any implementor. The implementation
/// visible here is `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Lets a regular `JoinProject` RPC responder be used by `join_project_internal`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so it clearly delegates to the inherent
        // `Response::send` rather than reading like a recursive trait call.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1817
1818fn join_project_internal(
1819    response: impl JoinProjectInternalResponse,
1820    session: Session,
1821    project: &mut Project,
1822    replica_id: &ReplicaId,
1823) -> Result<()> {
1824    let collaborators = project
1825        .collaborators
1826        .iter()
1827        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1828        .map(|collaborator| collaborator.to_proto())
1829        .collect::<Vec<_>>();
1830    let project_id = project.id;
1831    let guest_user_id = session.user_id();
1832
1833    let worktrees = project
1834        .worktrees
1835        .iter()
1836        .map(|(id, worktree)| proto::WorktreeMetadata {
1837            id: *id,
1838            root_name: worktree.root_name.clone(),
1839            visible: worktree.visible,
1840            abs_path: worktree.abs_path.clone(),
1841        })
1842        .collect::<Vec<_>>();
1843
1844    let add_project_collaborator = proto::AddProjectCollaborator {
1845        project_id: project_id.to_proto(),
1846        collaborator: Some(proto::Collaborator {
1847            peer_id: Some(session.connection_id.into()),
1848            replica_id: replica_id.0 as u32,
1849            user_id: guest_user_id.to_proto(),
1850            is_host: false,
1851        }),
1852    };
1853
1854    for collaborator in &collaborators {
1855        session
1856            .peer
1857            .send(
1858                collaborator.peer_id.unwrap().into(),
1859                add_project_collaborator.clone(),
1860            )
1861            .trace_err();
1862    }
1863
1864    // First, we send the metadata associated with each worktree.
1865    response.send(proto::JoinProjectResponse {
1866        project_id: project.id.0 as u64,
1867        worktrees: worktrees.clone(),
1868        replica_id: replica_id.0 as u32,
1869        collaborators: collaborators.clone(),
1870        language_servers: project.language_servers.clone(),
1871        role: project.role.into(),
1872    })?;
1873
1874    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1875        // Stream this worktree's entries.
1876        let message = proto::UpdateWorktree {
1877            project_id: project_id.to_proto(),
1878            worktree_id,
1879            abs_path: worktree.abs_path.clone(),
1880            root_name: worktree.root_name,
1881            updated_entries: worktree.entries,
1882            removed_entries: Default::default(),
1883            scan_id: worktree.scan_id,
1884            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1885            updated_repositories: worktree.repository_entries.into_values().collect(),
1886            removed_repositories: Default::default(),
1887        };
1888        for update in proto::split_worktree_update(message) {
1889            session.peer.send(session.connection_id, update.clone())?;
1890        }
1891
1892        // Stream this worktree's diagnostics.
1893        for summary in worktree.diagnostic_summaries {
1894            session.peer.send(
1895                session.connection_id,
1896                proto::UpdateDiagnosticSummary {
1897                    project_id: project_id.to_proto(),
1898                    worktree_id: worktree.id,
1899                    summary: Some(summary),
1900                },
1901            )?;
1902        }
1903
1904        for settings_file in worktree.settings_files {
1905            session.peer.send(
1906                session.connection_id,
1907                proto::UpdateWorktreeSettings {
1908                    project_id: project_id.to_proto(),
1909                    worktree_id: worktree.id,
1910                    path: settings_file.path,
1911                    content: Some(settings_file.content),
1912                    kind: Some(settings_file.kind.to_proto() as i32),
1913                },
1914            )?;
1915        }
1916    }
1917
1918    for language_server in &project.language_servers {
1919        session.peer.send(
1920            session.connection_id,
1921            proto::UpdateLanguageServer {
1922                project_id: project_id.to_proto(),
1923                language_server_id: language_server.id,
1924                variant: Some(
1925                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1926                        proto::LspDiskBasedDiagnosticsUpdated {},
1927                    ),
1928                ),
1929            },
1930        )?;
1931    }
1932
1933    Ok(())
1934}
1935
1936/// Leave someone elses shared project.
1937async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1938    let sender_id = session.connection_id;
1939    let project_id = ProjectId::from_proto(request.project_id);
1940    let db = session.db().await;
1941
1942    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1943    tracing::info!(
1944        %project_id,
1945        "leave project"
1946    );
1947
1948    project_left(project, &session);
1949    if let Some(room) = room {
1950        room_updated(room, &session.peer);
1951    }
1952
1953    Ok(())
1954}
1955
1956/// Updates other participants with changes to the project
1957async fn update_project(
1958    request: proto::UpdateProject,
1959    response: Response<proto::UpdateProject>,
1960    session: Session,
1961) -> Result<()> {
1962    let project_id = ProjectId::from_proto(request.project_id);
1963    let (room, guest_connection_ids) = &*session
1964        .db()
1965        .await
1966        .update_project(project_id, session.connection_id, &request.worktrees)
1967        .await?;
1968    broadcast(
1969        Some(session.connection_id),
1970        guest_connection_ids.iter().copied(),
1971        |connection_id| {
1972            session
1973                .peer
1974                .forward_send(session.connection_id, connection_id, request.clone())
1975        },
1976    );
1977    if let Some(room) = room {
1978        room_updated(room, &session.peer);
1979    }
1980    response.send(proto::Ack {})?;
1981
1982    Ok(())
1983}
1984
1985/// Updates other participants with changes to the worktree
1986async fn update_worktree(
1987    request: proto::UpdateWorktree,
1988    response: Response<proto::UpdateWorktree>,
1989    session: Session,
1990) -> Result<()> {
1991    let guest_connection_ids = session
1992        .db()
1993        .await
1994        .update_worktree(&request, session.connection_id)
1995        .await?;
1996
1997    broadcast(
1998        Some(session.connection_id),
1999        guest_connection_ids.iter().copied(),
2000        |connection_id| {
2001            session
2002                .peer
2003                .forward_send(session.connection_id, connection_id, request.clone())
2004        },
2005    );
2006    response.send(proto::Ack {})?;
2007    Ok(())
2008}
2009
2010/// Updates other participants with changes to the diagnostics
2011async fn update_diagnostic_summary(
2012    message: proto::UpdateDiagnosticSummary,
2013    session: Session,
2014) -> Result<()> {
2015    let guest_connection_ids = session
2016        .db()
2017        .await
2018        .update_diagnostic_summary(&message, session.connection_id)
2019        .await?;
2020
2021    broadcast(
2022        Some(session.connection_id),
2023        guest_connection_ids.iter().copied(),
2024        |connection_id| {
2025            session
2026                .peer
2027                .forward_send(session.connection_id, connection_id, message.clone())
2028        },
2029    );
2030
2031    Ok(())
2032}
2033
2034/// Updates other participants with changes to the worktree settings
2035async fn update_worktree_settings(
2036    message: proto::UpdateWorktreeSettings,
2037    session: Session,
2038) -> Result<()> {
2039    let guest_connection_ids = session
2040        .db()
2041        .await
2042        .update_worktree_settings(&message, session.connection_id)
2043        .await?;
2044
2045    broadcast(
2046        Some(session.connection_id),
2047        guest_connection_ids.iter().copied(),
2048        |connection_id| {
2049            session
2050                .peer
2051                .forward_send(session.connection_id, connection_id, message.clone())
2052        },
2053    );
2054
2055    Ok(())
2056}
2057
2058/// Notify other participants that a  language server has started.
2059async fn start_language_server(
2060    request: proto::StartLanguageServer,
2061    session: Session,
2062) -> Result<()> {
2063    let guest_connection_ids = session
2064        .db()
2065        .await
2066        .start_language_server(&request, session.connection_id)
2067        .await?;
2068
2069    broadcast(
2070        Some(session.connection_id),
2071        guest_connection_ids.iter().copied(),
2072        |connection_id| {
2073            session
2074                .peer
2075                .forward_send(session.connection_id, connection_id, request.clone())
2076        },
2077    );
2078    Ok(())
2079}
2080
2081/// Notify other participants that a language server has changed.
2082async fn update_language_server(
2083    request: proto::UpdateLanguageServer,
2084    session: Session,
2085) -> Result<()> {
2086    let project_id = ProjectId::from_proto(request.project_id);
2087    let project_connection_ids = session
2088        .db()
2089        .await
2090        .project_connection_ids(project_id, session.connection_id, true)
2091        .await?;
2092    broadcast(
2093        Some(session.connection_id),
2094        project_connection_ids.iter().copied(),
2095        |connection_id| {
2096            session
2097                .peer
2098                .forward_send(session.connection_id, connection_id, request.clone())
2099        },
2100    );
2101    Ok(())
2102}
2103
2104/// forward a project request to the host. These requests should be read only
2105/// as guests are allowed to send them.
2106async fn forward_read_only_project_request<T>(
2107    request: T,
2108    response: Response<T>,
2109    session: Session,
2110) -> Result<()>
2111where
2112    T: EntityMessage + RequestMessage,
2113{
2114    let project_id = ProjectId::from_proto(request.remote_entity_id());
2115    let host_connection_id = session
2116        .db()
2117        .await
2118        .host_for_read_only_project_request(project_id, session.connection_id)
2119        .await?;
2120    let payload = session
2121        .peer
2122        .forward_request(session.connection_id, host_connection_id, request)
2123        .await?;
2124    response.send(payload)?;
2125    Ok(())
2126}
2127
2128async fn forward_find_search_candidates_request(
2129    request: proto::FindSearchCandidates,
2130    response: Response<proto::FindSearchCandidates>,
2131    session: Session,
2132) -> Result<()> {
2133    let project_id = ProjectId::from_proto(request.remote_entity_id());
2134    let host_connection_id = session
2135        .db()
2136        .await
2137        .host_for_read_only_project_request(project_id, session.connection_id)
2138        .await?;
2139    let payload = session
2140        .peer
2141        .forward_request(session.connection_id, host_connection_id, request)
2142        .await?;
2143    response.send(payload)?;
2144    Ok(())
2145}
2146
2147/// forward a project request to the host. These requests are disallowed
2148/// for guests.
2149async fn forward_mutating_project_request<T>(
2150    request: T,
2151    response: Response<T>,
2152    session: Session,
2153) -> Result<()>
2154where
2155    T: EntityMessage + RequestMessage,
2156{
2157    let project_id = ProjectId::from_proto(request.remote_entity_id());
2158
2159    let host_connection_id = session
2160        .db()
2161        .await
2162        .host_for_mutating_project_request(project_id, session.connection_id)
2163        .await?;
2164    let payload = session
2165        .peer
2166        .forward_request(session.connection_id, host_connection_id, request)
2167        .await?;
2168    response.send(payload)?;
2169    Ok(())
2170}
2171
2172/// Notify other participants that a new buffer has been created
2173async fn create_buffer_for_peer(
2174    request: proto::CreateBufferForPeer,
2175    session: Session,
2176) -> Result<()> {
2177    session
2178        .db()
2179        .await
2180        .check_user_is_project_host(
2181            ProjectId::from_proto(request.project_id),
2182            session.connection_id,
2183        )
2184        .await?;
2185    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2186    session
2187        .peer
2188        .forward_send(session.connection_id, peer_id.into(), request)?;
2189    Ok(())
2190}
2191
2192/// Notify other participants that a buffer has been updated. This is
2193/// allowed for guests as long as the update is limited to selections.
2194async fn update_buffer(
2195    request: proto::UpdateBuffer,
2196    response: Response<proto::UpdateBuffer>,
2197    session: Session,
2198) -> Result<()> {
2199    let project_id = ProjectId::from_proto(request.project_id);
2200    let mut capability = Capability::ReadOnly;
2201
2202    for op in request.operations.iter() {
2203        match op.variant {
2204            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2205            Some(_) => capability = Capability::ReadWrite,
2206        }
2207    }
2208
2209    let host = {
2210        let guard = session
2211            .db()
2212            .await
2213            .connections_for_buffer_update(project_id, session.connection_id, capability)
2214            .await?;
2215
2216        let (host, guests) = &*guard;
2217
2218        broadcast(
2219            Some(session.connection_id),
2220            guests.clone(),
2221            |connection_id| {
2222                session
2223                    .peer
2224                    .forward_send(session.connection_id, connection_id, request.clone())
2225            },
2226        );
2227
2228        *host
2229    };
2230
2231    if host != session.connection_id {
2232        session
2233            .peer
2234            .forward_request(session.connection_id, host, request.clone())
2235            .await?;
2236    }
2237
2238    response.send(proto::Ack {})?;
2239    Ok(())
2240}
2241
2242async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2243    let project_id = ProjectId::from_proto(message.project_id);
2244
2245    let operation = message.operation.as_ref().context("invalid operation")?;
2246    let capability = match operation.variant.as_ref() {
2247        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2248            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2249                match buffer_op.variant {
2250                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2251                        Capability::ReadOnly
2252                    }
2253                    _ => Capability::ReadWrite,
2254                }
2255            } else {
2256                Capability::ReadWrite
2257            }
2258        }
2259        Some(_) => Capability::ReadWrite,
2260        None => Capability::ReadOnly,
2261    };
2262
2263    let guard = session
2264        .db()
2265        .await
2266        .connections_for_buffer_update(project_id, session.connection_id, capability)
2267        .await?;
2268
2269    let (host, guests) = &*guard;
2270
2271    broadcast(
2272        Some(session.connection_id),
2273        guests.iter().chain([host]).copied(),
2274        |connection_id| {
2275            session
2276                .peer
2277                .forward_send(session.connection_id, connection_id, message.clone())
2278        },
2279    );
2280
2281    Ok(())
2282}
2283
2284/// Notify other participants that a project has been updated.
2285async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2286    request: T,
2287    session: Session,
2288) -> Result<()> {
2289    let project_id = ProjectId::from_proto(request.remote_entity_id());
2290    let project_connection_ids = session
2291        .db()
2292        .await
2293        .project_connection_ids(project_id, session.connection_id, false)
2294        .await?;
2295
2296    broadcast(
2297        Some(session.connection_id),
2298        project_connection_ids.iter().copied(),
2299        |connection_id| {
2300            session
2301                .peer
2302                .forward_send(session.connection_id, connection_id, request.clone())
2303        },
2304    );
2305    Ok(())
2306}
2307
2308/// Start following another user in a call.
2309async fn follow(
2310    request: proto::Follow,
2311    response: Response<proto::Follow>,
2312    session: Session,
2313) -> Result<()> {
2314    let room_id = RoomId::from_proto(request.room_id);
2315    let project_id = request.project_id.map(ProjectId::from_proto);
2316    let leader_id = request
2317        .leader_id
2318        .ok_or_else(|| anyhow!("invalid leader id"))?
2319        .into();
2320    let follower_id = session.connection_id;
2321
2322    session
2323        .db()
2324        .await
2325        .check_room_participants(room_id, leader_id, session.connection_id)
2326        .await?;
2327
2328    let response_payload = session
2329        .peer
2330        .forward_request(session.connection_id, leader_id, request)
2331        .await?;
2332    response.send(response_payload)?;
2333
2334    if let Some(project_id) = project_id {
2335        let room = session
2336            .db()
2337            .await
2338            .follow(room_id, project_id, leader_id, follower_id)
2339            .await?;
2340        room_updated(&room, &session.peer);
2341    }
2342
2343    Ok(())
2344}
2345
2346/// Stop following another user in a call.
2347async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2348    let room_id = RoomId::from_proto(request.room_id);
2349    let project_id = request.project_id.map(ProjectId::from_proto);
2350    let leader_id = request
2351        .leader_id
2352        .ok_or_else(|| anyhow!("invalid leader id"))?
2353        .into();
2354    let follower_id = session.connection_id;
2355
2356    session
2357        .db()
2358        .await
2359        .check_room_participants(room_id, leader_id, session.connection_id)
2360        .await?;
2361
2362    session
2363        .peer
2364        .forward_send(session.connection_id, leader_id, request)?;
2365
2366    if let Some(project_id) = project_id {
2367        let room = session
2368            .db()
2369            .await
2370            .unfollow(room_id, project_id, leader_id, follower_id)
2371            .await?;
2372        room_updated(&room, &session.peer);
2373    }
2374
2375    Ok(())
2376}
2377
2378/// Notify everyone following you of your current location.
2379async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2380    let room_id = RoomId::from_proto(request.room_id);
2381    let database = session.db.lock().await;
2382
2383    let connection_ids = if let Some(project_id) = request.project_id {
2384        let project_id = ProjectId::from_proto(project_id);
2385        database
2386            .project_connection_ids(project_id, session.connection_id, true)
2387            .await?
2388    } else {
2389        database
2390            .room_connection_ids(room_id, session.connection_id)
2391            .await?
2392    };
2393
2394    // For now, don't send view update messages back to that view's current leader.
2395    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2396        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2397        _ => None,
2398    });
2399
2400    for connection_id in connection_ids.iter().cloned() {
2401        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2402            session
2403                .peer
2404                .forward_send(session.connection_id, connection_id, request.clone())?;
2405        }
2406    }
2407    Ok(())
2408}
2409
2410/// Get public data about users.
2411async fn get_users(
2412    request: proto::GetUsers,
2413    response: Response<proto::GetUsers>,
2414    session: Session,
2415) -> Result<()> {
2416    let user_ids = request
2417        .user_ids
2418        .into_iter()
2419        .map(UserId::from_proto)
2420        .collect();
2421    let users = session
2422        .db()
2423        .await
2424        .get_users_by_ids(user_ids)
2425        .await?
2426        .into_iter()
2427        .map(|user| proto::User {
2428            id: user.id.to_proto(),
2429            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2430            github_login: user.github_login,
2431            email: user.email_address,
2432            name: user.name,
2433        })
2434        .collect();
2435    response.send(proto::UsersResponse { users })?;
2436    Ok(())
2437}
2438
2439/// Search for users (to invite) buy Github login
2440async fn fuzzy_search_users(
2441    request: proto::FuzzySearchUsers,
2442    response: Response<proto::FuzzySearchUsers>,
2443    session: Session,
2444) -> Result<()> {
2445    let query = request.query;
2446    let users = match query.len() {
2447        0 => vec![],
2448        1 | 2 => session
2449            .db()
2450            .await
2451            .get_user_by_github_login(&query)
2452            .await?
2453            .into_iter()
2454            .collect(),
2455        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2456    };
2457    let users = users
2458        .into_iter()
2459        .filter(|user| user.id != session.user_id())
2460        .map(|user| proto::User {
2461            id: user.id.to_proto(),
2462            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2463            github_login: user.github_login,
2464            name: user.name,
2465            email: user.email_address,
2466        })
2467        .collect();
2468    response.send(proto::UsersResponse { users })?;
2469    Ok(())
2470}
2471
2472/// Send a contact request to another user.
2473async fn request_contact(
2474    request: proto::RequestContact,
2475    response: Response<proto::RequestContact>,
2476    session: Session,
2477) -> Result<()> {
2478    let requester_id = session.user_id();
2479    let responder_id = UserId::from_proto(request.responder_id);
2480    if requester_id == responder_id {
2481        return Err(anyhow!("cannot add yourself as a contact"))?;
2482    }
2483
2484    let notifications = session
2485        .db()
2486        .await
2487        .send_contact_request(requester_id, responder_id)
2488        .await?;
2489
2490    // Update outgoing contact requests of requester
2491    let mut update = proto::UpdateContacts::default();
2492    update.outgoing_requests.push(responder_id.to_proto());
2493    for connection_id in session
2494        .connection_pool()
2495        .await
2496        .user_connection_ids(requester_id)
2497    {
2498        session.peer.send(connection_id, update.clone())?;
2499    }
2500
2501    // Update incoming contact requests of responder
2502    let mut update = proto::UpdateContacts::default();
2503    update
2504        .incoming_requests
2505        .push(proto::IncomingContactRequest {
2506            requester_id: requester_id.to_proto(),
2507        });
2508    let connection_pool = session.connection_pool().await;
2509    for connection_id in connection_pool.user_connection_ids(responder_id) {
2510        session.peer.send(connection_id, update.clone())?;
2511    }
2512
2513    send_notifications(&connection_pool, &session.peer, notifications);
2514
2515    response.send(proto::Ack {})?;
2516    Ok(())
2517}
2518
2519/// Accept or decline a contact request
2520async fn respond_to_contact_request(
2521    request: proto::RespondToContactRequest,
2522    response: Response<proto::RespondToContactRequest>,
2523    session: Session,
2524) -> Result<()> {
2525    let responder_id = session.user_id();
2526    let requester_id = UserId::from_proto(request.requester_id);
2527    let db = session.db().await;
2528    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2529        db.dismiss_contact_notification(responder_id, requester_id)
2530            .await?;
2531    } else {
2532        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2533
2534        let notifications = db
2535            .respond_to_contact_request(responder_id, requester_id, accept)
2536            .await?;
2537        let requester_busy = db.is_user_busy(requester_id).await?;
2538        let responder_busy = db.is_user_busy(responder_id).await?;
2539
2540        let pool = session.connection_pool().await;
2541        // Update responder with new contact
2542        let mut update = proto::UpdateContacts::default();
2543        if accept {
2544            update
2545                .contacts
2546                .push(contact_for_user(requester_id, requester_busy, &pool));
2547        }
2548        update
2549            .remove_incoming_requests
2550            .push(requester_id.to_proto());
2551        for connection_id in pool.user_connection_ids(responder_id) {
2552            session.peer.send(connection_id, update.clone())?;
2553        }
2554
2555        // Update requester with new contact
2556        let mut update = proto::UpdateContacts::default();
2557        if accept {
2558            update
2559                .contacts
2560                .push(contact_for_user(responder_id, responder_busy, &pool));
2561        }
2562        update
2563            .remove_outgoing_requests
2564            .push(responder_id.to_proto());
2565
2566        for connection_id in pool.user_connection_ids(requester_id) {
2567            session.peer.send(connection_id, update.clone())?;
2568        }
2569
2570        send_notifications(&pool, &session.peer, notifications);
2571    }
2572
2573    response.send(proto::Ack {})?;
2574    Ok(())
2575}
2576
2577/// Remove a contact.
2578async fn remove_contact(
2579    request: proto::RemoveContact,
2580    response: Response<proto::RemoveContact>,
2581    session: Session,
2582) -> Result<()> {
2583    let requester_id = session.user_id();
2584    let responder_id = UserId::from_proto(request.user_id);
2585    let db = session.db().await;
2586    let (contact_accepted, deleted_notification_id) =
2587        db.remove_contact(requester_id, responder_id).await?;
2588
2589    let pool = session.connection_pool().await;
2590    // Update outgoing contact requests of requester
2591    let mut update = proto::UpdateContacts::default();
2592    if contact_accepted {
2593        update.remove_contacts.push(responder_id.to_proto());
2594    } else {
2595        update
2596            .remove_outgoing_requests
2597            .push(responder_id.to_proto());
2598    }
2599    for connection_id in pool.user_connection_ids(requester_id) {
2600        session.peer.send(connection_id, update.clone())?;
2601    }
2602
2603    // Update incoming contact requests of responder
2604    let mut update = proto::UpdateContacts::default();
2605    if contact_accepted {
2606        update.remove_contacts.push(requester_id.to_proto());
2607    } else {
2608        update
2609            .remove_incoming_requests
2610            .push(requester_id.to_proto());
2611    }
2612    for connection_id in pool.user_connection_ids(responder_id) {
2613        session.peer.send(connection_id, update.clone())?;
2614        if let Some(notification_id) = deleted_notification_id {
2615            session.peer.send(
2616                connection_id,
2617                proto::DeleteNotification {
2618                    notification_id: notification_id.to_proto(),
2619                },
2620            )?;
2621        }
2622    }
2623
2624    response.send(proto::Ack {})?;
2625    Ok(())
2626}
2627
2628fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2629    version.0.minor() < 139
2630}
2631
2632async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2633    let plan = session.current_plan(&session.db().await).await?;
2634
2635    session
2636        .peer
2637        .send(
2638            session.connection_id,
2639            proto::UpdateUserPlan { plan: plan.into() },
2640        )
2641        .trace_err();
2642
2643    Ok(())
2644}
2645
2646async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2647    subscribe_user_to_channels(session.user_id(), &session).await?;
2648    Ok(())
2649}
2650
2651async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2652    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2653    let mut pool = session.connection_pool().await;
2654    for membership in &channels_for_user.channel_memberships {
2655        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2656    }
2657    session.peer.send(
2658        session.connection_id,
2659        build_update_user_channels(&channels_for_user),
2660    )?;
2661    session.peer.send(
2662        session.connection_id,
2663        build_channels_update(channels_for_user),
2664    )?;
2665    Ok(())
2666}
2667
2668/// Creates a new channel.
2669async fn create_channel(
2670    request: proto::CreateChannel,
2671    response: Response<proto::CreateChannel>,
2672    session: Session,
2673) -> Result<()> {
2674    let db = session.db().await;
2675
2676    let parent_id = request.parent_id.map(ChannelId::from_proto);
2677    let (channel, membership) = db
2678        .create_channel(&request.name, parent_id, session.user_id())
2679        .await?;
2680
2681    let root_id = channel.root_id();
2682    let channel = Channel::from_model(channel);
2683
2684    response.send(proto::CreateChannelResponse {
2685        channel: Some(channel.to_proto()),
2686        parent_id: request.parent_id,
2687    })?;
2688
2689    let mut connection_pool = session.connection_pool().await;
2690    if let Some(membership) = membership {
2691        connection_pool.subscribe_to_channel(
2692            membership.user_id,
2693            membership.channel_id,
2694            membership.role,
2695        );
2696        let update = proto::UpdateUserChannels {
2697            channel_memberships: vec![proto::ChannelMembership {
2698                channel_id: membership.channel_id.to_proto(),
2699                role: membership.role.into(),
2700            }],
2701            ..Default::default()
2702        };
2703        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2704            session.peer.send(connection_id, update.clone())?;
2705        }
2706    }
2707
2708    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2709        if !role.can_see_channel(channel.visibility) {
2710            continue;
2711        }
2712
2713        let update = proto::UpdateChannels {
2714            channels: vec![channel.to_proto()],
2715            ..Default::default()
2716        };
2717        session.peer.send(connection_id, update.clone())?;
2718    }
2719
2720    Ok(())
2721}
2722
2723/// Delete a channel
2724async fn delete_channel(
2725    request: proto::DeleteChannel,
2726    response: Response<proto::DeleteChannel>,
2727    session: Session,
2728) -> Result<()> {
2729    let db = session.db().await;
2730
2731    let channel_id = request.channel_id;
2732    let (root_channel, removed_channels) = db
2733        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2734        .await?;
2735    response.send(proto::Ack {})?;
2736
2737    // Notify members of removed channels
2738    let mut update = proto::UpdateChannels::default();
2739    update
2740        .delete_channels
2741        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2742
2743    let connection_pool = session.connection_pool().await;
2744    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2745        session.peer.send(connection_id, update.clone())?;
2746    }
2747
2748    Ok(())
2749}
2750
2751/// Invite someone to join a channel.
2752async fn invite_channel_member(
2753    request: proto::InviteChannelMember,
2754    response: Response<proto::InviteChannelMember>,
2755    session: Session,
2756) -> Result<()> {
2757    let db = session.db().await;
2758    let channel_id = ChannelId::from_proto(request.channel_id);
2759    let invitee_id = UserId::from_proto(request.user_id);
2760    let InviteMemberResult {
2761        channel,
2762        notifications,
2763    } = db
2764        .invite_channel_member(
2765            channel_id,
2766            invitee_id,
2767            session.user_id(),
2768            request.role().into(),
2769        )
2770        .await?;
2771
2772    let update = proto::UpdateChannels {
2773        channel_invitations: vec![channel.to_proto()],
2774        ..Default::default()
2775    };
2776
2777    let connection_pool = session.connection_pool().await;
2778    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2779        session.peer.send(connection_id, update.clone())?;
2780    }
2781
2782    send_notifications(&connection_pool, &session.peer, notifications);
2783
2784    response.send(proto::Ack {})?;
2785    Ok(())
2786}
2787
2788/// remove someone from a channel
2789async fn remove_channel_member(
2790    request: proto::RemoveChannelMember,
2791    response: Response<proto::RemoveChannelMember>,
2792    session: Session,
2793) -> Result<()> {
2794    let db = session.db().await;
2795    let channel_id = ChannelId::from_proto(request.channel_id);
2796    let member_id = UserId::from_proto(request.user_id);
2797
2798    let RemoveChannelMemberResult {
2799        membership_update,
2800        notification_id,
2801    } = db
2802        .remove_channel_member(channel_id, member_id, session.user_id())
2803        .await?;
2804
2805    let mut connection_pool = session.connection_pool().await;
2806    notify_membership_updated(
2807        &mut connection_pool,
2808        membership_update,
2809        member_id,
2810        &session.peer,
2811    );
2812    for connection_id in connection_pool.user_connection_ids(member_id) {
2813        if let Some(notification_id) = notification_id {
2814            session
2815                .peer
2816                .send(
2817                    connection_id,
2818                    proto::DeleteNotification {
2819                        notification_id: notification_id.to_proto(),
2820                    },
2821                )
2822                .trace_err();
2823        }
2824    }
2825
2826    response.send(proto::Ack {})?;
2827    Ok(())
2828}
2829
2830/// Toggle the channel between public and private.
2831/// Care is taken to maintain the invariant that public channels only descend from public channels,
2832/// (though members-only channels can appear at any point in the hierarchy).
2833async fn set_channel_visibility(
2834    request: proto::SetChannelVisibility,
2835    response: Response<proto::SetChannelVisibility>,
2836    session: Session,
2837) -> Result<()> {
2838    let db = session.db().await;
2839    let channel_id = ChannelId::from_proto(request.channel_id);
2840    let visibility = request.visibility().into();
2841
2842    let channel_model = db
2843        .set_channel_visibility(channel_id, visibility, session.user_id())
2844        .await?;
2845    let root_id = channel_model.root_id();
2846    let channel = Channel::from_model(channel_model);
2847
2848    let mut connection_pool = session.connection_pool().await;
2849    for (user_id, role) in connection_pool
2850        .channel_user_ids(root_id)
2851        .collect::<Vec<_>>()
2852        .into_iter()
2853    {
2854        let update = if role.can_see_channel(channel.visibility) {
2855            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2856            proto::UpdateChannels {
2857                channels: vec![channel.to_proto()],
2858                ..Default::default()
2859            }
2860        } else {
2861            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2862            proto::UpdateChannels {
2863                delete_channels: vec![channel.id.to_proto()],
2864                ..Default::default()
2865            }
2866        };
2867
2868        for connection_id in connection_pool.user_connection_ids(user_id) {
2869            session.peer.send(connection_id, update.clone())?;
2870        }
2871    }
2872
2873    response.send(proto::Ack {})?;
2874    Ok(())
2875}
2876
2877/// Alter the role for a user in the channel.
2878async fn set_channel_member_role(
2879    request: proto::SetChannelMemberRole,
2880    response: Response<proto::SetChannelMemberRole>,
2881    session: Session,
2882) -> Result<()> {
2883    let db = session.db().await;
2884    let channel_id = ChannelId::from_proto(request.channel_id);
2885    let member_id = UserId::from_proto(request.user_id);
2886    let result = db
2887        .set_channel_member_role(
2888            channel_id,
2889            session.user_id(),
2890            member_id,
2891            request.role().into(),
2892        )
2893        .await?;
2894
2895    match result {
2896        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2897            let mut connection_pool = session.connection_pool().await;
2898            notify_membership_updated(
2899                &mut connection_pool,
2900                membership_update,
2901                member_id,
2902                &session.peer,
2903            )
2904        }
2905        db::SetMemberRoleResult::InviteUpdated(channel) => {
2906            let update = proto::UpdateChannels {
2907                channel_invitations: vec![channel.to_proto()],
2908                ..Default::default()
2909            };
2910
2911            for connection_id in session
2912                .connection_pool()
2913                .await
2914                .user_connection_ids(member_id)
2915            {
2916                session.peer.send(connection_id, update.clone())?;
2917            }
2918        }
2919    }
2920
2921    response.send(proto::Ack {})?;
2922    Ok(())
2923}
2924
2925/// Change the name of a channel
2926async fn rename_channel(
2927    request: proto::RenameChannel,
2928    response: Response<proto::RenameChannel>,
2929    session: Session,
2930) -> Result<()> {
2931    let db = session.db().await;
2932    let channel_id = ChannelId::from_proto(request.channel_id);
2933    let channel_model = db
2934        .rename_channel(channel_id, session.user_id(), &request.name)
2935        .await?;
2936    let root_id = channel_model.root_id();
2937    let channel = Channel::from_model(channel_model);
2938
2939    response.send(proto::RenameChannelResponse {
2940        channel: Some(channel.to_proto()),
2941    })?;
2942
2943    let connection_pool = session.connection_pool().await;
2944    let update = proto::UpdateChannels {
2945        channels: vec![channel.to_proto()],
2946        ..Default::default()
2947    };
2948    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2949        if role.can_see_channel(channel.visibility) {
2950            session.peer.send(connection_id, update.clone())?;
2951        }
2952    }
2953
2954    Ok(())
2955}
2956
2957/// Move a channel to a new parent.
2958async fn move_channel(
2959    request: proto::MoveChannel,
2960    response: Response<proto::MoveChannel>,
2961    session: Session,
2962) -> Result<()> {
2963    let channel_id = ChannelId::from_proto(request.channel_id);
2964    let to = ChannelId::from_proto(request.to);
2965
2966    let (root_id, channels) = session
2967        .db()
2968        .await
2969        .move_channel(channel_id, to, session.user_id())
2970        .await?;
2971
2972    let connection_pool = session.connection_pool().await;
2973    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2974        let channels = channels
2975            .iter()
2976            .filter_map(|channel| {
2977                if role.can_see_channel(channel.visibility) {
2978                    Some(channel.to_proto())
2979                } else {
2980                    None
2981                }
2982            })
2983            .collect::<Vec<_>>();
2984        if channels.is_empty() {
2985            continue;
2986        }
2987
2988        let update = proto::UpdateChannels {
2989            channels,
2990            ..Default::default()
2991        };
2992
2993        session.peer.send(connection_id, update.clone())?;
2994    }
2995
2996    response.send(Ack {})?;
2997    Ok(())
2998}
2999
3000/// Get the list of channel members
3001async fn get_channel_members(
3002    request: proto::GetChannelMembers,
3003    response: Response<proto::GetChannelMembers>,
3004    session: Session,
3005) -> Result<()> {
3006    let db = session.db().await;
3007    let channel_id = ChannelId::from_proto(request.channel_id);
3008    let limit = if request.limit == 0 {
3009        u16::MAX as u64
3010    } else {
3011        request.limit
3012    };
3013    let (members, users) = db
3014        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3015        .await?;
3016    response.send(proto::GetChannelMembersResponse { members, users })?;
3017    Ok(())
3018}
3019
3020/// Accept or decline a channel invitation.
3021async fn respond_to_channel_invite(
3022    request: proto::RespondToChannelInvite,
3023    response: Response<proto::RespondToChannelInvite>,
3024    session: Session,
3025) -> Result<()> {
3026    let db = session.db().await;
3027    let channel_id = ChannelId::from_proto(request.channel_id);
3028    let RespondToChannelInvite {
3029        membership_update,
3030        notifications,
3031    } = db
3032        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3033        .await?;
3034
3035    let mut connection_pool = session.connection_pool().await;
3036    if let Some(membership_update) = membership_update {
3037        notify_membership_updated(
3038            &mut connection_pool,
3039            membership_update,
3040            session.user_id(),
3041            &session.peer,
3042        );
3043    } else {
3044        let update = proto::UpdateChannels {
3045            remove_channel_invitations: vec![channel_id.to_proto()],
3046            ..Default::default()
3047        };
3048
3049        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3050            session.peer.send(connection_id, update.clone())?;
3051        }
3052    };
3053
3054    send_notifications(&connection_pool, &session.peer, notifications);
3055
3056    response.send(proto::Ack {})?;
3057
3058    Ok(())
3059}
3060
3061/// Join the channels' room
3062async fn join_channel(
3063    request: proto::JoinChannel,
3064    response: Response<proto::JoinChannel>,
3065    session: Session,
3066) -> Result<()> {
3067    let channel_id = ChannelId::from_proto(request.channel_id);
3068    join_channel_internal(channel_id, Box::new(response), session).await
3069}
3070
3071trait JoinChannelInternalResponse {
3072    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3073}
3074impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3075    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3076        Response::<proto::JoinChannel>::send(self, result)
3077    }
3078}
3079impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3080    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3081        Response::<proto::JoinRoom>::send(self, result)
3082    }
3083}
3084
3085async fn join_channel_internal(
3086    channel_id: ChannelId,
3087    response: Box<impl JoinChannelInternalResponse>,
3088    session: Session,
3089) -> Result<()> {
3090    let joined_room = {
3091        let mut db = session.db().await;
3092        // If zed quits without leaving the room, and the user re-opens zed before the
3093        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3094        // room they were in.
3095        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3096            tracing::info!(
3097                stale_connection_id = %connection,
3098                "cleaning up stale connection",
3099            );
3100            drop(db);
3101            leave_room_for_session(&session, connection).await?;
3102            db = session.db().await;
3103        }
3104
3105        let (joined_room, membership_updated, role) = db
3106            .join_channel(channel_id, session.user_id(), session.connection_id)
3107            .await?;
3108
3109        let live_kit_connection_info =
3110            session
3111                .app_state
3112                .livekit_client
3113                .as_ref()
3114                .and_then(|live_kit| {
3115                    let (can_publish, token) = if role == ChannelRole::Guest {
3116                        (
3117                            false,
3118                            live_kit
3119                                .guest_token(
3120                                    &joined_room.room.livekit_room,
3121                                    &session.user_id().to_string(),
3122                                )
3123                                .trace_err()?,
3124                        )
3125                    } else {
3126                        (
3127                            true,
3128                            live_kit
3129                                .room_token(
3130                                    &joined_room.room.livekit_room,
3131                                    &session.user_id().to_string(),
3132                                )
3133                                .trace_err()?,
3134                        )
3135                    };
3136
3137                    Some(LiveKitConnectionInfo {
3138                        server_url: live_kit.url().into(),
3139                        token,
3140                        can_publish,
3141                    })
3142                });
3143
3144        response.send(proto::JoinRoomResponse {
3145            room: Some(joined_room.room.clone()),
3146            channel_id: joined_room
3147                .channel
3148                .as_ref()
3149                .map(|channel| channel.id.to_proto()),
3150            live_kit_connection_info,
3151        })?;
3152
3153        let mut connection_pool = session.connection_pool().await;
3154        if let Some(membership_updated) = membership_updated {
3155            notify_membership_updated(
3156                &mut connection_pool,
3157                membership_updated,
3158                session.user_id(),
3159                &session.peer,
3160            );
3161        }
3162
3163        room_updated(&joined_room.room, &session.peer);
3164
3165        joined_room
3166    };
3167
3168    channel_updated(
3169        &joined_room
3170            .channel
3171            .ok_or_else(|| anyhow!("channel not returned"))?,
3172        &joined_room.room,
3173        &session.peer,
3174        &*session.connection_pool().await,
3175    );
3176
3177    update_user_contacts(session.user_id(), &session).await?;
3178    Ok(())
3179}
3180
3181/// Start editing the channel notes
3182async fn join_channel_buffer(
3183    request: proto::JoinChannelBuffer,
3184    response: Response<proto::JoinChannelBuffer>,
3185    session: Session,
3186) -> Result<()> {
3187    let db = session.db().await;
3188    let channel_id = ChannelId::from_proto(request.channel_id);
3189
3190    let open_response = db
3191        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3192        .await?;
3193
3194    let collaborators = open_response.collaborators.clone();
3195    response.send(open_response)?;
3196
3197    let update = UpdateChannelBufferCollaborators {
3198        channel_id: channel_id.to_proto(),
3199        collaborators: collaborators.clone(),
3200    };
3201    channel_buffer_updated(
3202        session.connection_id,
3203        collaborators
3204            .iter()
3205            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3206        &update,
3207        &session.peer,
3208    );
3209
3210    Ok(())
3211}
3212
3213/// Edit the channel notes
3214async fn update_channel_buffer(
3215    request: proto::UpdateChannelBuffer,
3216    session: Session,
3217) -> Result<()> {
3218    let db = session.db().await;
3219    let channel_id = ChannelId::from_proto(request.channel_id);
3220
3221    let (collaborators, epoch, version) = db
3222        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3223        .await?;
3224
3225    channel_buffer_updated(
3226        session.connection_id,
3227        collaborators.clone(),
3228        &proto::UpdateChannelBuffer {
3229            channel_id: channel_id.to_proto(),
3230            operations: request.operations,
3231        },
3232        &session.peer,
3233    );
3234
3235    let pool = &*session.connection_pool().await;
3236
3237    let non_collaborators =
3238        pool.channel_connection_ids(channel_id)
3239            .filter_map(|(connection_id, _)| {
3240                if collaborators.contains(&connection_id) {
3241                    None
3242                } else {
3243                    Some(connection_id)
3244                }
3245            });
3246
3247    broadcast(None, non_collaborators, |peer_id| {
3248        session.peer.send(
3249            peer_id,
3250            proto::UpdateChannels {
3251                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3252                    channel_id: channel_id.to_proto(),
3253                    epoch: epoch as u64,
3254                    version: version.clone(),
3255                }],
3256                ..Default::default()
3257            },
3258        )
3259    });
3260
3261    Ok(())
3262}
3263
3264/// Rejoin the channel notes after a connection blip
3265async fn rejoin_channel_buffers(
3266    request: proto::RejoinChannelBuffers,
3267    response: Response<proto::RejoinChannelBuffers>,
3268    session: Session,
3269) -> Result<()> {
3270    let db = session.db().await;
3271    let buffers = db
3272        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3273        .await?;
3274
3275    for rejoined_buffer in &buffers {
3276        let collaborators_to_notify = rejoined_buffer
3277            .buffer
3278            .collaborators
3279            .iter()
3280            .filter_map(|c| Some(c.peer_id?.into()));
3281        channel_buffer_updated(
3282            session.connection_id,
3283            collaborators_to_notify,
3284            &proto::UpdateChannelBufferCollaborators {
3285                channel_id: rejoined_buffer.buffer.channel_id,
3286                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3287            },
3288            &session.peer,
3289        );
3290    }
3291
3292    response.send(proto::RejoinChannelBuffersResponse {
3293        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3294    })?;
3295
3296    Ok(())
3297}
3298
3299/// Stop editing the channel notes
3300async fn leave_channel_buffer(
3301    request: proto::LeaveChannelBuffer,
3302    response: Response<proto::LeaveChannelBuffer>,
3303    session: Session,
3304) -> Result<()> {
3305    let db = session.db().await;
3306    let channel_id = ChannelId::from_proto(request.channel_id);
3307
3308    let left_buffer = db
3309        .leave_channel_buffer(channel_id, session.connection_id)
3310        .await?;
3311
3312    response.send(Ack {})?;
3313
3314    channel_buffer_updated(
3315        session.connection_id,
3316        left_buffer.connections,
3317        &proto::UpdateChannelBufferCollaborators {
3318            channel_id: channel_id.to_proto(),
3319            collaborators: left_buffer.collaborators,
3320        },
3321        &session.peer,
3322    );
3323
3324    Ok(())
3325}
3326
3327fn channel_buffer_updated<T: EnvelopedMessage>(
3328    sender_id: ConnectionId,
3329    collaborators: impl IntoIterator<Item = ConnectionId>,
3330    message: &T,
3331    peer: &Peer,
3332) {
3333    broadcast(Some(sender_id), collaborators, |peer_id| {
3334        peer.send(peer_id, message.clone())
3335    });
3336}
3337
3338fn send_notifications(
3339    connection_pool: &ConnectionPool,
3340    peer: &Peer,
3341    notifications: db::NotificationBatch,
3342) {
3343    for (user_id, notification) in notifications {
3344        for connection_id in connection_pool.user_connection_ids(user_id) {
3345            if let Err(error) = peer.send(
3346                connection_id,
3347                proto::AddNotification {
3348                    notification: Some(notification.clone()),
3349                },
3350            ) {
3351                tracing::error!(
3352                    "failed to send notification to {:?} {}",
3353                    connection_id,
3354                    error
3355                );
3356            }
3357        }
3358    }
3359}
3360
/// Send a message to the channel
///
/// Flow:
/// 1. Trims the body and rejects it if it is too long or empty.
/// 2. Requires a nonce.
/// 3. Persists the message.
/// 4. Broadcasts `ChannelMessageSent` to the chat participants returned by the
///    database, then replies to the sender with the stored message.
/// 5. Sends an `UpdateChannels` carrying the new latest message id to every
///    other connection subscribed to the channel.
/// 6. Delivers any notifications the database produced (e.g. for mentions).
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // Note: `len()` is the byte length of the trimmed body, not a char count.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    // The DB guard is a temporary here, so it is released at the end of this statement.
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Participants get the full message; `participant_connection_ids` is cloned
    // because it is needed again below to exclude them from the second broadcast.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel but not in the chat only learn the
    // new latest message id (e.g. to drive unread indicators — confirm on client).
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3455
3456/// Delete a channel message
3457async fn remove_channel_message(
3458    request: proto::RemoveChannelMessage,
3459    response: Response<proto::RemoveChannelMessage>,
3460    session: Session,
3461) -> Result<()> {
3462    let channel_id = ChannelId::from_proto(request.channel_id);
3463    let message_id = MessageId::from_proto(request.message_id);
3464    let (connection_ids, existing_notification_ids) = session
3465        .db()
3466        .await
3467        .remove_channel_message(channel_id, message_id, session.user_id())
3468        .await?;
3469
3470    broadcast(
3471        Some(session.connection_id),
3472        connection_ids,
3473        move |connection| {
3474            session.peer.send(connection, request.clone())?;
3475
3476            for notification_id in &existing_notification_ids {
3477                session.peer.send(
3478                    connection,
3479                    proto::DeleteNotification {
3480                        notification_id: (*notification_id).to_proto(),
3481                    },
3482                )?;
3483            }
3484
3485            Ok(())
3486        },
3487    );
3488    response.send(proto::Ack {})?;
3489    Ok(())
3490}
3491
3492async fn update_channel_message(
3493    request: proto::UpdateChannelMessage,
3494    response: Response<proto::UpdateChannelMessage>,
3495    session: Session,
3496) -> Result<()> {
3497    let channel_id = ChannelId::from_proto(request.channel_id);
3498    let message_id = MessageId::from_proto(request.message_id);
3499    let updated_at = OffsetDateTime::now_utc();
3500    let UpdatedChannelMessage {
3501        message_id,
3502        participant_connection_ids,
3503        notifications,
3504        reply_to_message_id,
3505        timestamp,
3506        deleted_mention_notification_ids,
3507        updated_mention_notifications,
3508    } = session
3509        .db()
3510        .await
3511        .update_channel_message(
3512            channel_id,
3513            message_id,
3514            session.user_id(),
3515            request.body.as_str(),
3516            &request.mentions,
3517            updated_at,
3518        )
3519        .await?;
3520
3521    let nonce = request
3522        .nonce
3523        .clone()
3524        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3525
3526    let message = proto::ChannelMessage {
3527        sender_id: session.user_id().to_proto(),
3528        id: message_id.to_proto(),
3529        body: request.body.clone(),
3530        mentions: request.mentions.clone(),
3531        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3532        nonce: Some(nonce),
3533        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3534        edited_at: Some(updated_at.unix_timestamp() as u64),
3535    };
3536
3537    response.send(proto::Ack {})?;
3538
3539    let pool = &*session.connection_pool().await;
3540    broadcast(
3541        Some(session.connection_id),
3542        participant_connection_ids,
3543        |connection| {
3544            session.peer.send(
3545                connection,
3546                proto::ChannelMessageUpdate {
3547                    channel_id: channel_id.to_proto(),
3548                    message: Some(message.clone()),
3549                },
3550            )?;
3551
3552            for notification_id in &deleted_mention_notification_ids {
3553                session.peer.send(
3554                    connection,
3555                    proto::DeleteNotification {
3556                        notification_id: (*notification_id).to_proto(),
3557                    },
3558                )?;
3559            }
3560
3561            for notification in &updated_mention_notifications {
3562                session.peer.send(
3563                    connection,
3564                    proto::UpdateNotification {
3565                        notification: Some(notification.clone()),
3566                    },
3567                )?;
3568            }
3569
3570            Ok(())
3571        },
3572    );
3573
3574    send_notifications(pool, &session.peer, notifications);
3575
3576    Ok(())
3577}
3578
3579/// Mark a channel message as read
3580async fn acknowledge_channel_message(
3581    request: proto::AckChannelMessage,
3582    session: Session,
3583) -> Result<()> {
3584    let channel_id = ChannelId::from_proto(request.channel_id);
3585    let message_id = MessageId::from_proto(request.message_id);
3586    let notifications = session
3587        .db()
3588        .await
3589        .observe_channel_message(channel_id, session.user_id(), message_id)
3590        .await?;
3591    send_notifications(
3592        &*session.connection_pool().await,
3593        &session.peer,
3594        notifications,
3595    );
3596    Ok(())
3597}
3598
3599/// Mark a buffer version as synced
3600async fn acknowledge_buffer_version(
3601    request: proto::AckBufferOperation,
3602    session: Session,
3603) -> Result<()> {
3604    let buffer_id = BufferId::from_proto(request.buffer_id);
3605    session
3606        .db()
3607        .await
3608        .observe_buffer_version(
3609            buffer_id,
3610            session.user_id(),
3611            request.epoch as i32,
3612            &request.version,
3613        )
3614        .await?;
3615    Ok(())
3616}
3617
3618async fn count_language_model_tokens(
3619    request: proto::CountLanguageModelTokens,
3620    response: Response<proto::CountLanguageModelTokens>,
3621    session: Session,
3622    config: &Config,
3623) -> Result<()> {
3624    authorize_access_to_legacy_llm_endpoints(&session).await?;
3625
3626    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3627        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3628        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3629    };
3630
3631    session
3632        .app_state
3633        .rate_limiter
3634        .check(&*rate_limit, session.user_id())
3635        .await?;
3636
3637    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3638        Some(proto::LanguageModelProvider::Google) => {
3639            let api_key = config
3640                .google_ai_api_key
3641                .as_ref()
3642                .context("no Google AI API key configured on the server")?;
3643            google_ai::count_tokens(
3644                session.http_client.as_ref(),
3645                google_ai::API_URL,
3646                api_key,
3647                serde_json::from_str(&request.request)?,
3648            )
3649            .await?
3650        }
3651        _ => return Err(anyhow!("unsupported provider"))?,
3652    };
3653
3654    response.send(proto::CountLanguageModelTokensResponse {
3655        token_count: result.total_tokens as u32,
3656    })?;
3657
3658    Ok(())
3659}
3660
3661struct ZedProCountLanguageModelTokensRateLimit;
3662
3663impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3664    fn capacity(&self) -> usize {
3665        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3666            .ok()
3667            .and_then(|v| v.parse().ok())
3668            .unwrap_or(600) // Picked arbitrarily
3669    }
3670
3671    fn refill_duration(&self) -> chrono::Duration {
3672        chrono::Duration::hours(1)
3673    }
3674
3675    fn db_name(&self) -> &'static str {
3676        "zed-pro:count-language-model-tokens"
3677    }
3678}
3679
3680struct FreeCountLanguageModelTokensRateLimit;
3681
3682impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3683    fn capacity(&self) -> usize {
3684        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3685            .ok()
3686            .and_then(|v| v.parse().ok())
3687            .unwrap_or(600 / 10) // Picked arbitrarily
3688    }
3689
3690    fn refill_duration(&self) -> chrono::Duration {
3691        chrono::Duration::hours(1)
3692    }
3693
3694    fn db_name(&self) -> &'static str {
3695        "free:count-language-model-tokens"
3696    }
3697}
3698
3699struct ZedProComputeEmbeddingsRateLimit;
3700
3701impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3702    fn capacity(&self) -> usize {
3703        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3704            .ok()
3705            .and_then(|v| v.parse().ok())
3706            .unwrap_or(5000) // Picked arbitrarily
3707    }
3708
3709    fn refill_duration(&self) -> chrono::Duration {
3710        chrono::Duration::hours(1)
3711    }
3712
3713    fn db_name(&self) -> &'static str {
3714        "zed-pro:compute-embeddings"
3715    }
3716}
3717
3718struct FreeComputeEmbeddingsRateLimit;
3719
3720impl RateLimit for FreeComputeEmbeddingsRateLimit {
3721    fn capacity(&self) -> usize {
3722        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3723            .ok()
3724            .and_then(|v| v.parse().ok())
3725            .unwrap_or(5000 / 10) // Picked arbitrarily
3726    }
3727
3728    fn refill_duration(&self) -> chrono::Duration {
3729        chrono::Duration::hours(1)
3730    }
3731
3732    fn db_name(&self) -> &'static str {
3733        "free:compute-embeddings"
3734    }
3735}
3736
3737async fn compute_embeddings(
3738    request: proto::ComputeEmbeddings,
3739    response: Response<proto::ComputeEmbeddings>,
3740    session: Session,
3741    api_key: Option<Arc<str>>,
3742) -> Result<()> {
3743    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3744    authorize_access_to_legacy_llm_endpoints(&session).await?;
3745
3746    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3747        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3748        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3749    };
3750
3751    session
3752        .app_state
3753        .rate_limiter
3754        .check(&*rate_limit, session.user_id())
3755        .await?;
3756
3757    let embeddings = match request.model.as_str() {
3758        "openai/text-embedding-3-small" => {
3759            open_ai::embed(
3760                session.http_client.as_ref(),
3761                OPEN_AI_API_URL,
3762                &api_key,
3763                OpenAiEmbeddingModel::TextEmbedding3Small,
3764                request.texts.iter().map(|text| text.as_str()),
3765            )
3766            .await?
3767        }
3768        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3769    };
3770
3771    let embeddings = request
3772        .texts
3773        .iter()
3774        .map(|text| {
3775            let mut hasher = sha2::Sha256::new();
3776            hasher.update(text.as_bytes());
3777            let result = hasher.finalize();
3778            result.to_vec()
3779        })
3780        .zip(
3781            embeddings
3782                .data
3783                .into_iter()
3784                .map(|embedding| embedding.embedding),
3785        )
3786        .collect::<HashMap<_, _>>();
3787
3788    let db = session.db().await;
3789    db.save_embeddings(&request.model, &embeddings)
3790        .await
3791        .context("failed to save embeddings")
3792        .trace_err();
3793
3794    response.send(proto::ComputeEmbeddingsResponse {
3795        embeddings: embeddings
3796            .into_iter()
3797            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3798            .collect(),
3799    })?;
3800    Ok(())
3801}
3802
3803async fn get_cached_embeddings(
3804    request: proto::GetCachedEmbeddings,
3805    response: Response<proto::GetCachedEmbeddings>,
3806    session: Session,
3807) -> Result<()> {
3808    authorize_access_to_legacy_llm_endpoints(&session).await?;
3809
3810    let db = session.db().await;
3811    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3812
3813    response.send(proto::GetCachedEmbeddingsResponse {
3814        embeddings: embeddings
3815            .into_iter()
3816            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3817            .collect(),
3818    })?;
3819    Ok(())
3820}
3821
3822/// This is leftover from before the LLM service.
3823///
3824/// The endpoints protected by this check will be moved there eventually.
3825async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3826    if session.is_staff() {
3827        Ok(())
3828    } else {
3829        Err(anyhow!("permission denied"))?
3830    }
3831}
3832
3833/// Get a Supermaven API key for the user
3834async fn get_supermaven_api_key(
3835    _request: proto::GetSupermavenApiKey,
3836    response: Response<proto::GetSupermavenApiKey>,
3837    session: Session,
3838) -> Result<()> {
3839    let user_id: String = session.user_id().to_string();
3840    if !session.is_staff() {
3841        return Err(anyhow!("supermaven not enabled for this account"))?;
3842    }
3843
3844    let email = session
3845        .email()
3846        .ok_or_else(|| anyhow!("user must have an email"))?;
3847
3848    let supermaven_admin_api = session
3849        .supermaven_client
3850        .as_ref()
3851        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3852
3853    let result = supermaven_admin_api
3854        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3855        .await?;
3856
3857    response.send(proto::GetSupermavenApiKeyResponse {
3858        api_key: result.api_key,
3859    })?;
3860
3861    Ok(())
3862}
3863
3864/// Start receiving chat updates for a channel
3865async fn join_channel_chat(
3866    request: proto::JoinChannelChat,
3867    response: Response<proto::JoinChannelChat>,
3868    session: Session,
3869) -> Result<()> {
3870    let channel_id = ChannelId::from_proto(request.channel_id);
3871
3872    let db = session.db().await;
3873    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3874        .await?;
3875    let messages = db
3876        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3877        .await?;
3878    response.send(proto::JoinChannelChatResponse {
3879        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3880        messages,
3881    })?;
3882    Ok(())
3883}
3884
3885/// Stop receiving chat updates for a channel
3886async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3887    let channel_id = ChannelId::from_proto(request.channel_id);
3888    session
3889        .db()
3890        .await
3891        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3892        .await?;
3893    Ok(())
3894}
3895
3896/// Retrieve the chat history for a channel
3897async fn get_channel_messages(
3898    request: proto::GetChannelMessages,
3899    response: Response<proto::GetChannelMessages>,
3900    session: Session,
3901) -> Result<()> {
3902    let channel_id = ChannelId::from_proto(request.channel_id);
3903    let messages = session
3904        .db()
3905        .await
3906        .get_channel_messages(
3907            channel_id,
3908            session.user_id(),
3909            MESSAGE_COUNT_PER_PAGE,
3910            Some(MessageId::from_proto(request.before_message_id)),
3911        )
3912        .await?;
3913    response.send(proto::GetChannelMessagesResponse {
3914        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3915        messages,
3916    })?;
3917    Ok(())
3918}
3919
3920/// Retrieve specific chat messages
3921async fn get_channel_messages_by_id(
3922    request: proto::GetChannelMessagesById,
3923    response: Response<proto::GetChannelMessagesById>,
3924    session: Session,
3925) -> Result<()> {
3926    let message_ids = request
3927        .message_ids
3928        .iter()
3929        .map(|id| MessageId::from_proto(*id))
3930        .collect::<Vec<_>>();
3931    let messages = session
3932        .db()
3933        .await
3934        .get_channel_messages_by_id(session.user_id(), &message_ids)
3935        .await?;
3936    response.send(proto::GetChannelMessagesResponse {
3937        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3938        messages,
3939    })?;
3940    Ok(())
3941}
3942
3943/// Retrieve the current users notifications
3944async fn get_notifications(
3945    request: proto::GetNotifications,
3946    response: Response<proto::GetNotifications>,
3947    session: Session,
3948) -> Result<()> {
3949    let notifications = session
3950        .db()
3951        .await
3952        .get_notifications(
3953            session.user_id(),
3954            NOTIFICATION_COUNT_PER_PAGE,
3955            request.before_id.map(db::NotificationId::from_proto),
3956        )
3957        .await?;
3958    response.send(proto::GetNotificationsResponse {
3959        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3960        notifications,
3961    })?;
3962    Ok(())
3963}
3964
3965/// Mark notifications as read
3966async fn mark_notification_as_read(
3967    request: proto::MarkNotificationRead,
3968    response: Response<proto::MarkNotificationRead>,
3969    session: Session,
3970) -> Result<()> {
3971    let database = &session.db().await;
3972    let notifications = database
3973        .mark_notification_as_read_by_id(
3974            session.user_id(),
3975            NotificationId::from_proto(request.notification_id),
3976        )
3977        .await?;
3978    send_notifications(
3979        &*session.connection_pool().await,
3980        &session.peer,
3981        notifications,
3982    );
3983    response.send(proto::Ack {})?;
3984    Ok(())
3985}
3986
3987/// Get the current users information
3988async fn get_private_user_info(
3989    _request: proto::GetPrivateUserInfo,
3990    response: Response<proto::GetPrivateUserInfo>,
3991    session: Session,
3992) -> Result<()> {
3993    let db = session.db().await;
3994
3995    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3996    let user = db
3997        .get_user_by_id(session.user_id())
3998        .await?
3999        .ok_or_else(|| anyhow!("user not found"))?;
4000    let flags = db.get_user_flags(session.user_id()).await?;
4001
4002    response.send(proto::GetPrivateUserInfoResponse {
4003        metrics_id,
4004        staff: user.admin,
4005        flags,
4006        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4007    })?;
4008    Ok(())
4009}
4010
4011/// Accept the terms of service (tos) on behalf of the current user
4012async fn accept_terms_of_service(
4013    _request: proto::AcceptTermsOfService,
4014    response: Response<proto::AcceptTermsOfService>,
4015    session: Session,
4016) -> Result<()> {
4017    let db = session.db().await;
4018
4019    let accepted_tos_at = Utc::now();
4020    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4021        .await?;
4022
4023    response.send(proto::AcceptTermsOfServiceResponse {
4024        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4025    })?;
4026    Ok(())
4027}
4028
/// The minimum account age an account must have in order to use the LLM service.
///
/// Checked by `get_llm_api_token`. Age is measured from the earlier of the
/// account's Zed and GitHub creation timestamps. Users with an LLM subscription
/// or the `bypass-account-age-check` flag skip the check.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4031
4032async fn get_llm_api_token(
4033    _request: proto::GetLlmToken,
4034    response: Response<proto::GetLlmToken>,
4035    session: Session,
4036) -> Result<()> {
4037    let db = session.db().await;
4038
4039    let flags = db.get_user_flags(session.user_id()).await?;
4040    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4041    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4042    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4043
4044    if !session.is_staff() && !has_language_models_feature_flag {
4045        Err(anyhow!("permission denied"))?
4046    }
4047
4048    let user_id = session.user_id();
4049    let user = db
4050        .get_user_by_id(user_id)
4051        .await?
4052        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4053
4054    if user.accepted_tos_at.is_none() {
4055        Err(anyhow!("terms of service not accepted"))?
4056    }
4057
4058    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4059
4060    let bypass_account_age_check =
4061        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4062    if !bypass_account_age_check {
4063        let mut account_created_at = user.created_at;
4064        if let Some(github_created_at) = user.github_user_created_at {
4065            account_created_at = account_created_at.min(github_created_at);
4066        }
4067        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4068            Err(anyhow!("account too young"))?
4069        }
4070    }
4071
4072    let billing_preferences = db.get_billing_preferences(user.id).await?;
4073
4074    let token = LlmTokenClaims::create(
4075        &user,
4076        session.is_staff(),
4077        billing_preferences,
4078        has_llm_closed_beta_feature_flag,
4079        has_predict_edits_feature_flag,
4080        has_llm_subscription,
4081        session.current_plan(&db).await?,
4082        session.system_id.clone(),
4083        &session.app_state.config,
4084    )?;
4085    response.send(proto::GetLlmTokenResponse { token })?;
4086    Ok(())
4087}
4088
4089fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4090    let message = match message {
4091        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4092        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4093        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4094        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4095        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4096            code: frame.code.into(),
4097            reason: frame.reason,
4098        })),
4099        // We should never receive a frame while reading the message, according
4100        // to the `tungstenite` maintainers:
4101        //
4102        // > It cannot occur when you read messages from the WebSocket, but it
4103        // > can be used when you want to send the raw frames (e.g. you want to
4104        // > send the frames to the WebSocket without composing the full message first).
4105        // >
4106        // > — https://github.com/snapview/tungstenite-rs/issues/268
4107        TungsteniteMessage::Frame(_) => {
4108            bail!("received an unexpected frame while reading the message")
4109        }
4110    };
4111
4112    Ok(message)
4113}
4114
4115fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4116    match message {
4117        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4118        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4119        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4120        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4121        AxumMessage::Close(frame) => {
4122            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4123                code: frame.code.into(),
4124                reason: frame.reason,
4125            }))
4126        }
4127    }
4128}
4129
4130fn notify_membership_updated(
4131    connection_pool: &mut ConnectionPool,
4132    result: MembershipUpdated,
4133    user_id: UserId,
4134    peer: &Peer,
4135) {
4136    for membership in &result.new_channels.channel_memberships {
4137        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4138    }
4139    for channel_id in &result.removed_channels {
4140        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4141    }
4142
4143    let user_channels_update = proto::UpdateUserChannels {
4144        channel_memberships: result
4145            .new_channels
4146            .channel_memberships
4147            .iter()
4148            .map(|cm| proto::ChannelMembership {
4149                channel_id: cm.channel_id.to_proto(),
4150                role: cm.role.into(),
4151            })
4152            .collect(),
4153        ..Default::default()
4154    };
4155
4156    let mut update = build_channels_update(result.new_channels);
4157    update.delete_channels = result
4158        .removed_channels
4159        .into_iter()
4160        .map(|id| id.to_proto())
4161        .collect();
4162    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4163
4164    for connection_id in connection_pool.user_connection_ids(user_id) {
4165        peer.send(connection_id, user_channels_update.clone())
4166            .trace_err();
4167        peer.send(connection_id, update.clone()).trace_err();
4168    }
4169}
4170
4171fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4172    proto::UpdateUserChannels {
4173        channel_memberships: channels
4174            .channel_memberships
4175            .iter()
4176            .map(|m| proto::ChannelMembership {
4177                channel_id: m.channel_id.to_proto(),
4178                role: m.role.into(),
4179            })
4180            .collect(),
4181        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4182        observed_channel_message_id: channels.observed_channel_messages.clone(),
4183    }
4184}
4185
4186fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4187    let mut update = proto::UpdateChannels::default();
4188
4189    for channel in channels.channels {
4190        update.channels.push(channel.to_proto());
4191    }
4192
4193    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4194    update.latest_channel_message_ids = channels.latest_channel_messages;
4195
4196    for (channel_id, participants) in channels.channel_participants {
4197        update
4198            .channel_participants
4199            .push(proto::ChannelParticipants {
4200                channel_id: channel_id.to_proto(),
4201                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4202            });
4203    }
4204
4205    for channel in channels.invited_channels {
4206        update.channel_invitations.push(channel.to_proto());
4207    }
4208
4209    update
4210}
4211
4212fn build_initial_contacts_update(
4213    contacts: Vec<db::Contact>,
4214    pool: &ConnectionPool,
4215) -> proto::UpdateContacts {
4216    let mut update = proto::UpdateContacts::default();
4217
4218    for contact in contacts {
4219        match contact {
4220            db::Contact::Accepted { user_id, busy } => {
4221                update.contacts.push(contact_for_user(user_id, busy, pool));
4222            }
4223            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4224            db::Contact::Incoming { user_id } => {
4225                update
4226                    .incoming_requests
4227                    .push(proto::IncomingContactRequest {
4228                        requester_id: user_id.to_proto(),
4229                    })
4230            }
4231        }
4232    }
4233
4234    update
4235}
4236
4237fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4238    proto::Contact {
4239        user_id: user_id.to_proto(),
4240        online: pool.is_user_online(user_id),
4241        busy,
4242    }
4243}
4244
4245fn room_updated(room: &proto::Room, peer: &Peer) {
4246    broadcast(
4247        None,
4248        room.participants
4249            .iter()
4250            .filter_map(|participant| Some(participant.peer_id?.into())),
4251        |peer_id| {
4252            peer.send(
4253                peer_id,
4254                proto::RoomUpdated {
4255                    room: Some(room.clone()),
4256                },
4257            )
4258        },
4259    );
4260}
4261
4262fn channel_updated(
4263    channel: &db::channel::Model,
4264    room: &proto::Room,
4265    peer: &Peer,
4266    pool: &ConnectionPool,
4267) {
4268    let participants = room
4269        .participants
4270        .iter()
4271        .map(|p| p.user_id)
4272        .collect::<Vec<_>>();
4273
4274    broadcast(
4275        None,
4276        pool.channel_connection_ids(channel.root_id())
4277            .filter_map(|(channel_id, role)| {
4278                role.can_see_channel(channel.visibility)
4279                    .then_some(channel_id)
4280            }),
4281        |peer_id| {
4282            peer.send(
4283                peer_id,
4284                proto::UpdateChannels {
4285                    channel_participants: vec![proto::ChannelParticipants {
4286                        channel_id: channel.id.to_proto(),
4287                        participant_user_ids: participants.clone(),
4288                    }],
4289                    ..Default::default()
4290                },
4291            )
4292        },
4293    );
4294}
4295
4296async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4297    let db = session.db().await;
4298
4299    let contacts = db.get_contacts(user_id).await?;
4300    let busy = db.is_user_busy(user_id).await?;
4301
4302    let pool = session.connection_pool().await;
4303    let updated_contact = contact_for_user(user_id, busy, &pool);
4304    for contact in contacts {
4305        if let db::Contact::Accepted {
4306            user_id: contact_user_id,
4307            ..
4308        } = contact
4309        {
4310            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4311                session
4312                    .peer
4313                    .send(
4314                        contact_conn_id,
4315                        proto::UpdateContacts {
4316                            contacts: vec![updated_contact.clone()],
4317                            remove_contacts: Default::default(),
4318                            incoming_requests: Default::default(),
4319                            remove_incoming_requests: Default::default(),
4320                            outgoing_requests: Default::default(),
4321                            remove_outgoing_requests: Default::default(),
4322                        },
4323                    )
4324                    .trace_err();
4325            }
4326        }
4327    }
4328    Ok(())
4329}
4330
/// Removes `connection_id` from whatever room it is in and fans out the consequences.
///
/// Returns `Ok(())` without doing anything else if `leave_room` reports that the
/// connection was not in a room. Otherwise, in order, it:
/// - notifies the other connections of every project that was left (`project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - broadcasts the channel's participant list, if the room belonged to a channel,
/// - sends `CallCanceled` to every connection of each user whose call was canceled,
/// - refreshes contact status for the leaving user and for the canceled callees,
/// - removes the user from the LiveKit room, if a LiveKit client is configured,
///   and deletes that room when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`.
/// Message-send and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // User ids whose contact entries must be re-broadcast after leaving.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each is assigned exactly once in the `if let` below;
    // the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the temporary produced by `session.db().await` lives until the end of
    // this whole `if let` statement (edition <= 2021 scrutinee temporary rules). If it is
    // a lock guard, it is held across `project_left` and `room_updated` below. Keep this
    // in mind before restructuring into `let ... else`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` before the room itself is taken,
        // so the `room` broadcast below carries the default `livekit_room` value.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The pool handle is scoped to this block so it is dropped before the `.await`s below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` here returns early on the first failure, which also skips the
    // LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4404
4405async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4406    let left_channel_buffers = session
4407        .db()
4408        .await
4409        .leave_channel_buffers(session.connection_id)
4410        .await?;
4411
4412    for left_buffer in left_channel_buffers {
4413        channel_buffer_updated(
4414            session.connection_id,
4415            left_buffer.connections,
4416            &proto::UpdateChannelBufferCollaborators {
4417                channel_id: left_buffer.channel_id.to_proto(),
4418                collaborators: left_buffer.collaborators,
4419            },
4420            &session.peer,
4421        );
4422    }
4423
4424    Ok(())
4425}
4426
4427fn project_left(project: &db::LeftProject, session: &Session) {
4428    for connection_id in &project.connection_ids {
4429        if project.should_unshare {
4430            session
4431                .peer
4432                .send(
4433                    *connection_id,
4434                    proto::UnshareProject {
4435                        project_id: project.id.to_proto(),
4436                    },
4437                )
4438                .trace_err();
4439        } else {
4440            session
4441                .peer
4442                .send(
4443                    *connection_id,
4444                    proto::RemoveProjectCollaborator {
4445                        project_id: project.id.to_proto(),
4446                        peer_id: Some(session.connection_id.into()),
4447                    },
4448                )
4449                .trace_err();
4450        }
4451    }
4452}
4453
4454pub trait ResultExt {
4455    type Ok;
4456
4457    fn trace_err(self) -> Option<Self::Ok>;
4458}
4459
4460impl<T, E> ResultExt for Result<T, E>
4461where
4462    E: std::fmt::Debug,
4463{
4464    type Ok = T;
4465
4466    #[track_caller]
4467    fn trace_err(self) -> Option<T> {
4468        match self {
4469            Ok(value) => Some(value),
4470            Err(error) => {
4471                tracing::error!("{:?}", error);
4472                None
4473            }
4474        }
4475    }
4476}