rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// Presumably the grace period a disconnected client has to reconnect before its state is
/// cleaned up; the uses of this constant are outside this view — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before the background task spawned by `Server::start` cleans up resources that the
/// database reports as stale.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting longer than that
/// ensures the old pods are gone before their leftover resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size for channel-message history queries; uses are outside this view.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum length of a channel message (units unconfirmed — check the use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification queries; uses are outside this view.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`.
///
/// It takes a boxed envelope plus the receiving connection's `Session` and returns a
/// `'static` boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as.
    principal: Principal,
    /// This connection's id on the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async (tokio) mutex; lock it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's shared peer.
    peer: Arc<Peer>,
    /// The server's shared connection pool; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client; `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection, used for outbound requests.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `#[allow(unused)]` silences the dead-code lint for this field; it is
    // presumably kept for future use — confirm whether anything still needs it.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied with the connection, if any (presumably from `SystemIdHeader`).
    system_id: Option<String>,
    /// Executor for this connection; held but not read directly (hence the underscore).
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer used to exchange messages with every connection.
    peer: Arc<Peer>,
    /// Registry of connected clients, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, LiveKit client, ...).
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle; filled by
    /// `add_handler`, which panics on duplicates.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sends `true` on it, and `handle_connection` subscribes to it
    /// and refuses to start a connection once it reads `true`.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard on the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. A future that must be `Send`
/// therefore cannot keep the guard alive across an `.await`, which keeps the synchronous
/// `parking_lot` lock from being held while a task is suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying guard; the pool stays locked for as long as this value lives.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Marker that opts the guard out of `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 249
/// A serializable snapshot of the server's peer and connection pool.
///
/// The pool is captured through a `ConnectionPoolGuard`, so it stays locked until the
/// snapshot is dropped. Fields serialize in declaration order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the pool the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
 313            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 314            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 315            .add_request_handler(
 316                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 317            )
 318            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 323            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 324            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 325            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 326            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 327            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 328            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 329            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 330            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 331            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 332            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 333            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 334            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 335            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 336            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 337            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 338            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 339            .add_message_handler(create_buffer_for_peer)
 340            .add_request_handler(update_buffer)
 341            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 342            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 343            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 344            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 345            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 346            .add_request_handler(get_users)
 347            .add_request_handler(fuzzy_search_users)
 348            .add_request_handler(request_contact)
 349            .add_request_handler(remove_contact)
 350            .add_request_handler(respond_to_contact_request)
 351            .add_message_handler(subscribe_to_channels)
 352            .add_request_handler(create_channel)
 353            .add_request_handler(delete_channel)
 354            .add_request_handler(invite_channel_member)
 355            .add_request_handler(remove_channel_member)
 356            .add_request_handler(set_channel_member_role)
 357            .add_request_handler(set_channel_visibility)
 358            .add_request_handler(rename_channel)
 359            .add_request_handler(join_channel_buffer)
 360            .add_request_handler(leave_channel_buffer)
 361            .add_message_handler(update_channel_buffer)
 362            .add_request_handler(rejoin_channel_buffers)
 363            .add_request_handler(get_channel_members)
 364            .add_request_handler(respond_to_channel_invite)
 365            .add_request_handler(join_channel)
 366            .add_request_handler(join_channel_chat)
 367            .add_message_handler(leave_channel_chat)
 368            .add_request_handler(send_channel_message)
 369            .add_request_handler(remove_channel_message)
 370            .add_request_handler(update_channel_message)
 371            .add_request_handler(get_channel_messages)
 372            .add_request_handler(get_channel_messages_by_id)
 373            .add_request_handler(get_notifications)
 374            .add_request_handler(mark_notification_as_read)
 375            .add_request_handler(move_channel)
 376            .add_request_handler(follow)
 377            .add_message_handler(unfollow)
 378            .add_message_handler(update_followers)
 379            .add_request_handler(get_private_user_info)
 380            .add_request_handler(get_llm_api_token)
 381            .add_request_handler(accept_terms_of_service)
 382            .add_message_handler(acknowledge_channel_message)
 383            .add_message_handler(acknowledge_buffer_version)
 384            .add_request_handler(get_supermaven_api_key)
 385            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 386            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 387            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 388            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 389            .add_message_handler(update_context)
 390            .add_request_handler({
 391                let app_state = app_state.clone();
 392                move |request, response, session| {
 393                    let app_state = app_state.clone();
 394                    async move {
 395                        count_language_model_tokens(request, response, session, &app_state.config)
 396                            .await
 397                    }
 398                }
 399            })
 400            .add_request_handler(get_cached_embeddings)
 401            .add_request_handler({
 402                let app_state = app_state.clone();
 403                move |request, response, session| {
 404                    compute_embeddings(
 405                        request,
 406                        response,
 407                        session,
 408                        app_state.config.openai_api_key.clone(),
 409                    )
 410                }
 411            });
 412
 413        Arc::new(server)
 414    }
 415
    /// Schedules background cleanup of resources the database reports as stale, then
    /// returns immediately.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`, then asks the database for stale room and
    /// channel-buffer ids (`stale_server_resource_ids`, scoped to this environment and
    /// `server_id`). It then:
    /// - clears stale collaborators and participants and notifies affected clients;
    /// - deletes LiveKit rooms that end up empty;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Each fallible step is logged via `trace_err` and skipped rather than aborting the task.
    ///
    /// Always returns `Ok(())`; errors inside the background task are not reported here.
    pub async fn start(&self) -> Result<()> {
        // Copy/clone everything the detached task needs so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed
                    // collaborator list to every connection the database still reports.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing this room fails: nothing to notify and
                        // no LiveKit room to delete. A `HashSet` so each user is updated once.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool guard is a temporary dropped at the end of this
                                // statement, so no lock is held across the awaits below.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled
                            // need their contacts told about a busy-status change.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped block: the synchronous pool lock is released before any
                        // further `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query first; the pool is locked only after both awaits complete.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the status update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and the
                        // refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 561
 562    pub fn teardown(&self) {
 563        self.peer.teardown();
 564        self.connection_pool.lock().reset();
 565        let _ = self.teardown.send(true);
 566    }
 567
 568    #[cfg(test)]
 569    pub fn reset(&self, id: ServerId) {
 570        self.teardown();
 571        *self.id.lock() = id;
 572        self.peer.reset(id.0 as u32);
 573        let _ = self.teardown.send(false);
 574    }
 575
 576    #[cfg(test)]
 577    pub fn id(&self) -> ServerId {
 578        *self.id.lock()
 579    }
 580
 581    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 582    where
 583        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 584        Fut: 'static + Send + Future<Output = Result<()>>,
 585        M: EnvelopedMessage,
 586    {
 587        let prev_handler = self.handlers.insert(
 588            TypeId::of::<M>(),
 589            Box::new(move |envelope, session| {
 590                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 591                let received_at = envelope.received_at;
 592                tracing::info!("message received");
 593                let start_time = Instant::now();
 594                let future = (handler)(*envelope, session);
 595                async move {
 596                    let result = future.await;
 597                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 598                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 599                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 600                    let payload_type = M::NAME;
 601
 602                    match result {
 603                        Err(error) => {
 604                            tracing::error!(
 605                                ?error,
 606                                total_duration_ms,
 607                                processing_duration_ms,
 608                                queue_duration_ms,
 609                                payload_type,
 610                                "error handling message"
 611                            )
 612                        }
 613                        Ok(()) => tracing::info!(
 614                            total_duration_ms,
 615                            processing_duration_ms,
 616                            queue_duration_ms,
 617                            "finished handling message"
 618                        ),
 619                    }
 620                }
 621                .boxed()
 622            }),
 623        );
 624        if prev_handler.is_some() {
 625            panic!("registered a handler for the same message twice");
 626        }
 627        self
 628    }
 629
 630    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 631    where
 632        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 633        Fut: 'static + Send + Future<Output = Result<()>>,
 634        M: EnvelopedMessage,
 635    {
 636        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 637        self
 638    }
 639
 640    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 641    where
 642        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 643        Fut: Send + Future<Output = Result<()>>,
 644        M: RequestMessage,
 645    {
 646        let handler = Arc::new(handler);
 647        self.add_handler(move |envelope, session| {
 648            let receipt = envelope.receipt();
 649            let handler = handler.clone();
 650            async move {
 651                let peer = session.peer.clone();
 652                let responded = Arc::new(AtomicBool::default());
 653                let response = Response {
 654                    peer: peer.clone(),
 655                    responded: responded.clone(),
 656                    receipt,
 657                };
 658                match (handler)(envelope.payload, response, session).await {
 659                    Ok(()) => {
 660                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 661                            Ok(())
 662                        } else {
 663                            Err(anyhow!("handler did not send a response"))?
 664                        }
 665                    }
 666                    Err(error) => {
 667                        let proto_err = match &error {
 668                            Error::Internal(err) => err.to_proto(),
 669                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 670                        };
 671                        peer.respond_with_error(receipt, proto_err)?;
 672                        Err(error)
 673                    }
 674                }
 675            }
 676        })
 677    }
 678
    /// Drives a single client connection from handshake until sign-out.
    ///
    /// Returns a future that:
    /// 1. bails out immediately (logging an error) if the server teardown flag is already set;
    /// 2. registers `connection` with the peer, obtaining its `ConnectionId`, an I/O future and
    ///    the stream of incoming messages;
    /// 3. builds the per-session HTTP client (and a Supermaven admin client when an admin API
    ///    key is configured) and the `Session`;
    /// 4. sends the initial client update (which also forwards the new id through
    ///    `send_connection_id`, when provided);
    /// 5. dispatches incoming messages to registered handlers until the I/O future finishes,
    ///    the incoming stream ends, or teardown is signalled;
    /// 6. on a normal exit (not teardown), runs `connection_lost` to clean up the session.
    ///
    /// All work is instrumented with a "handle connection" span carrying the peer address,
    /// principal fields and, when known, the GeoIP country code.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` here are filled in later, once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before entering the future so a teardown that starts later is still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only created when `supermaven_admin_api_key` is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message, so once 256 handlers
                // are running no further messages are pulled off the connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms top to bottom: teardown wins over everything,
                // then I/O completion, then progress on running foreground handlers, and only
                // then a new incoming message. On teardown we return without `connection_lost`.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released only
                                // when that handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers still pending are cancelled by dropping them; detached
            // background handlers are unaffected.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 828
 829    async fn send_initial_client_update(
 830        &self,
 831        connection_id: ConnectionId,
 832        principal: &Principal,
 833        zed_version: ZedVersion,
 834        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 835        session: &Session,
 836    ) -> Result<()> {
 837        self.peer.send(
 838            connection_id,
 839            proto::Hello {
 840                peer_id: Some(connection_id.into()),
 841            },
 842        )?;
 843        tracing::info!("sent hello message");
 844        if let Some(send_connection_id) = send_connection_id.take() {
 845            let _ = send_connection_id.send(connection_id);
 846        }
 847
 848        match principal {
 849            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 850                if !user.connected_once {
 851                    self.peer.send(connection_id, proto::ShowContacts {})?;
 852                    self.app_state
 853                        .db
 854                        .set_user_connected_once(user.id, true)
 855                        .await?;
 856                }
 857
 858                update_user_plan(user.id, session).await?;
 859
 860                let contacts = self.app_state.db.get_contacts(user.id).await?;
 861
 862                {
 863                    let mut pool = self.connection_pool.lock();
 864                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 865                    self.peer.send(
 866                        connection_id,
 867                        build_initial_contacts_update(contacts, &pool),
 868                    )?;
 869                }
 870
 871                if should_auto_subscribe_to_channels(zed_version) {
 872                    subscribe_user_to_channels(user.id, session).await?;
 873                }
 874
 875                if let Some(incoming_call) =
 876                    self.app_state.db.incoming_call_for_user(user.id).await?
 877                {
 878                    self.peer.send(connection_id, incoming_call)?;
 879                }
 880
 881                update_user_contacts(user.id, session).await?;
 882            }
 883        }
 884
 885        Ok(())
 886    }
 887
 888    pub async fn invite_code_redeemed(
 889        self: &Arc<Self>,
 890        inviter_id: UserId,
 891        invitee_id: UserId,
 892    ) -> Result<()> {
 893        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 894            if let Some(code) = &user.invite_code {
 895                let pool = self.connection_pool.lock();
 896                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 897                for connection_id in pool.user_connection_ids(inviter_id) {
 898                    self.peer.send(
 899                        connection_id,
 900                        proto::UpdateContacts {
 901                            contacts: vec![invitee_contact.clone()],
 902                            ..Default::default()
 903                        },
 904                    )?;
 905                    self.peer.send(
 906                        connection_id,
 907                        proto::UpdateInviteInfo {
 908                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 909                            count: user.invite_count as u32,
 910                        },
 911                    )?;
 912                }
 913            }
 914        }
 915        Ok(())
 916    }
 917
 918    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 919        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 920            if let Some(invite_code) = &user.invite_code {
 921                let pool = self.connection_pool.lock();
 922                for connection_id in pool.user_connection_ids(user_id) {
 923                    self.peer.send(
 924                        connection_id,
 925                        proto::UpdateInviteInfo {
 926                            url: format!(
 927                                "{}{}",
 928                                self.app_state.config.invite_link_prefix, invite_code
 929                            ),
 930                            count: user.invite_count as u32,
 931                        },
 932                    )?;
 933                }
 934            }
 935        }
 936        Ok(())
 937    }
 938
 939    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 940        let pool = self.connection_pool.lock();
 941        for connection_id in pool.user_connection_ids(user_id) {
 942            self.peer
 943                .send(connection_id, proto::RefreshLlmToken {})
 944                .trace_err();
 945        }
 946    }
 947
 948    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 949        ServerSnapshot {
 950            connection_pool: ConnectionPoolGuard {
 951                guard: self.connection_pool.lock(),
 952                _not_send: PhantomData,
 953            },
 954            peer: &self.peer,
 955        }
 956    }
 957}
 958
 959impl<'a> Deref for ConnectionPoolGuard<'a> {
 960    type Target = ConnectionPool;
 961
 962    fn deref(&self) -> &Self::Target {
 963        &self.guard
 964    }
 965}
 966
 967impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 968    fn deref_mut(&mut self) -> &mut Self::Target {
 969        &mut self.guard
 970    }
 971}
 972
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds only, calls `self.check_invariants()` while the pool is still
    /// locked (fields, including the lock guard, are dropped after this method returns).
    /// In non-test builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 979
 980fn broadcast<F>(
 981    sender_id: Option<ConnectionId>,
 982    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 983    mut f: F,
 984) where
 985    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 986{
 987    for receiver_id in receiver_ids {
 988        if Some(receiver_id) != sender_id {
 989            if let Err(error) = f(receiver_id) {
 990                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 991            }
 992        }
 993    }
 994}
 995
 996pub struct ProtocolVersion(u32);
 997
 998impl Header for ProtocolVersion {
 999    fn name() -> &'static HeaderName {
1000        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1001        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1002    }
1003
1004    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1005    where
1006        Self: Sized,
1007        I: Iterator<Item = &'i axum::http::HeaderValue>,
1008    {
1009        let version = values
1010            .next()
1011            .ok_or_else(axum::headers::Error::invalid)?
1012            .to_str()
1013            .map_err(|_| axum::headers::Error::invalid())?
1014            .parse()
1015            .map_err(|_| axum::headers::Error::invalid())?;
1016        Ok(Self(version))
1017    }
1018
1019    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1020        values.extend([self.0.to_string().parse().unwrap()]);
1021    }
1022}
1023
1024pub struct AppVersionHeader(SemanticVersion);
1025impl Header for AppVersionHeader {
1026    fn name() -> &'static HeaderName {
1027        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1028        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1029    }
1030
1031    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1032    where
1033        Self: Sized,
1034        I: Iterator<Item = &'i axum::http::HeaderValue>,
1035    {
1036        let version = values
1037            .next()
1038            .ok_or_else(axum::headers::Error::invalid)?
1039            .to_str()
1040            .map_err(|_| axum::headers::Error::invalid())?
1041            .parse()
1042            .map_err(|_| axum::headers::Error::invalid())?;
1043        Ok(Self(version))
1044    }
1045
1046    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1047        values.extend([self.0.to_string().parse().unwrap()]);
1048    }
1049}
1050
1051pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1052    Router::new()
1053        .route("/rpc", get(handle_websocket_request))
1054        .layer(
1055            ServiceBuilder::new()
1056                .layer(Extension(server.app_state.clone()))
1057                .layer(middleware::from_fn(auth::validate_header)),
1058        )
1059        .route("/metrics", get(handle_metrics))
1060        .layer(Extension(server))
1061}
1062
1063#[allow(clippy::too_many_arguments)]
1064pub async fn handle_websocket_request(
1065    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1066    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1067    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1068    Extension(server): Extension<Arc<Server>>,
1069    Extension(principal): Extension<Principal>,
1070    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1071    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1072    ws: WebSocketUpgrade,
1073) -> axum::response::Response {
1074    if protocol_version != rpc::PROTOCOL_VERSION {
1075        return (
1076            StatusCode::UPGRADE_REQUIRED,
1077            "client must be upgraded".to_string(),
1078        )
1079            .into_response();
1080    }
1081
1082    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1083        return (
1084            StatusCode::UPGRADE_REQUIRED,
1085            "no version header found".to_string(),
1086        )
1087            .into_response();
1088    };
1089
1090    if !version.can_collaborate() {
1091        return (
1092            StatusCode::UPGRADE_REQUIRED,
1093            "client must be upgraded".to_string(),
1094        )
1095            .into_response();
1096    }
1097
1098    let socket_address = socket_address.to_string();
1099    ws.on_upgrade(move |socket| {
1100        let socket = socket
1101            .map_ok(to_tungstenite_message)
1102            .err_into()
1103            .with(|message| async move { to_axum_message(message) });
1104        let connection = Connection::new(Box::pin(socket));
1105        async move {
1106            server
1107                .handle_connection(
1108                    connection,
1109                    socket_address,
1110                    principal,
1111                    version,
1112                    country_code_header.map(|header| header.to_string()),
1113                    system_id_header.map(|header| header.to_string()),
1114                    None,
1115                    Executor::Production,
1116                )
1117                .await;
1118        }
1119    })
1120}
1121
1122pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1123    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1124    let connections_metric = CONNECTIONS_METRIC
1125        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1126
1127    let connections = server
1128        .connection_pool
1129        .lock()
1130        .connections()
1131        .filter(|connection| !connection.admin)
1132        .count();
1133    connections_metric.set(connections as _);
1134
1135    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1136    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1137        register_int_gauge!(
1138            "shared_projects",
1139            "number of open projects with one or more guests"
1140        )
1141        .unwrap()
1142    });
1143
1144    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1145    shared_projects_metric.set(shared_projects as _);
1146
1147    let encoder = prometheus::TextEncoder::new();
1148    let metric_families = prometheus::gather();
1149    let encoded_metrics = encoder
1150        .encode_to_string(&metric_families)
1151        .map_err(|err| anyhow!("{}", err))?;
1152    Ok(encoded_metrics)
1153}
1154
1155#[instrument(err, skip(executor))]
1156async fn connection_lost(
1157    session: Session,
1158    mut teardown: watch::Receiver<bool>,
1159    executor: Executor,
1160) -> Result<()> {
1161    session.peer.disconnect(session.connection_id);
1162    session
1163        .connection_pool()
1164        .await
1165        .remove_connection(session.connection_id)?;
1166
1167    session
1168        .db()
1169        .await
1170        .connection_lost(session.connection_id)
1171        .await
1172        .trace_err();
1173
1174    futures::select_biased! {
1175        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1176
1177            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1178            leave_room_for_session(&session, session.connection_id).await.trace_err();
1179            leave_channel_buffers_for_session(&session)
1180                .await
1181                .trace_err();
1182
1183            if !session
1184                .connection_pool()
1185                .await
1186                .is_user_online(session.user_id())
1187            {
1188                let db = session.db().await;
1189                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1190                    room_updated(&room, &session.peer);
1191                }
1192            }
1193
1194            update_user_contacts(session.user_id(), &session).await?;
1195        },
1196        _ = teardown.changed().fuse() => {}
1197    }
1198
1199    Ok(())
1200}
1201
1202/// Acknowledges a ping from a client, used to keep the connection alive.
1203async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1204    response.send(proto::Ack {})?;
1205    Ok(())
1206}
1207
1208/// Creates a new room for calling (outside of channels)
1209async fn create_room(
1210    _request: proto::CreateRoom,
1211    response: Response<proto::CreateRoom>,
1212    session: Session,
1213) -> Result<()> {
1214    let livekit_room = nanoid::nanoid!(30);
1215
1216    let live_kit_connection_info = util::maybe!(async {
1217        let live_kit = session.app_state.livekit_client.as_ref();
1218        let live_kit = live_kit?;
1219        let user_id = session.user_id().to_string();
1220
1221        let token = live_kit
1222            .room_token(&livekit_room, &user_id.to_string())
1223            .trace_err()?;
1224
1225        Some(proto::LiveKitConnectionInfo {
1226            server_url: live_kit.url().into(),
1227            token,
1228            can_publish: true,
1229        })
1230    })
1231    .await;
1232
1233    let room = session
1234        .db()
1235        .await
1236        .create_room(session.user_id(), session.connection_id, &livekit_room)
1237        .await?;
1238
1239    response.send(proto::CreateRoomResponse {
1240        room: Some(room.clone()),
1241        live_kit_connection_info,
1242    })?;
1243
1244    update_user_contacts(session.user_id(), &session).await?;
1245    Ok(())
1246}
1247
1248/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1249async fn join_room(
1250    request: proto::JoinRoom,
1251    response: Response<proto::JoinRoom>,
1252    session: Session,
1253) -> Result<()> {
1254    let room_id = RoomId::from_proto(request.id);
1255
1256    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1257
1258    if let Some(channel_id) = channel_id {
1259        return join_channel_internal(channel_id, Box::new(response), session).await;
1260    }
1261
1262    let joined_room = {
1263        let room = session
1264            .db()
1265            .await
1266            .join_room(room_id, session.user_id(), session.connection_id)
1267            .await?;
1268        room_updated(&room.room, &session.peer);
1269        room.into_inner()
1270    };
1271
1272    for connection_id in session
1273        .connection_pool()
1274        .await
1275        .user_connection_ids(session.user_id())
1276    {
1277        session
1278            .peer
1279            .send(
1280                connection_id,
1281                proto::CallCanceled {
1282                    room_id: room_id.to_proto(),
1283                },
1284            )
1285            .trace_err();
1286    }
1287
1288    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1289    {
1290        live_kit
1291            .room_token(
1292                &joined_room.room.livekit_room,
1293                &session.user_id().to_string(),
1294            )
1295            .trace_err()
1296            .map(|token| proto::LiveKitConnectionInfo {
1297                server_url: live_kit.url().into(),
1298                token,
1299                can_publish: true,
1300            })
1301    } else {
1302        None
1303    };
1304
1305    response.send(proto::JoinRoomResponse {
1306        room: Some(joined_room.room),
1307        channel_id: None,
1308        live_kit_connection_info,
1309    })?;
1310
1311    update_user_contacts(session.user_id(), &session).await?;
1312    Ok(())
1313}
1314
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Replies with the room and the reshared/rejoined project state, broadcasts the room
/// update, moves collaborators of reshared projects from the old connection id to this
/// session's id and re-sends those projects' worktrees to them, then replays rejoined
/// project state to this client via `notify_rejoined_projects`. If the room belongs to
/// a channel, channel members are notified. Finally the user's contacts are updated.
///
/// # Errors
/// Propagates database errors, failures sending the response, errors from
/// `notify_rejoined_projects`, and contact-update errors. Sends to other collaborators
/// are best-effort (logged or ignored per recipient).
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the scope below so the guarded db result can be dropped
    // before the channel/contacts updates that follow.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this host's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to everyone except this session.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1406
1407fn notify_rejoined_projects(
1408    rejoined_projects: &mut Vec<RejoinedProject>,
1409    session: &Session,
1410) -> Result<()> {
1411    for project in rejoined_projects.iter() {
1412        for collaborator in &project.collaborators {
1413            session
1414                .peer
1415                .send(
1416                    collaborator.connection_id,
1417                    proto::UpdateProjectCollaborator {
1418                        project_id: project.id.to_proto(),
1419                        old_peer_id: Some(project.old_connection_id.into()),
1420                        new_peer_id: Some(session.connection_id.into()),
1421                    },
1422                )
1423                .trace_err();
1424        }
1425    }
1426
1427    for project in rejoined_projects {
1428        for worktree in mem::take(&mut project.worktrees) {
1429            // Stream this worktree's entries.
1430            let message = proto::UpdateWorktree {
1431                project_id: project.id.to_proto(),
1432                worktree_id: worktree.id,
1433                abs_path: worktree.abs_path.clone(),
1434                root_name: worktree.root_name,
1435                updated_entries: worktree.updated_entries,
1436                removed_entries: worktree.removed_entries,
1437                scan_id: worktree.scan_id,
1438                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1439                updated_repositories: worktree.updated_repositories,
1440                removed_repositories: worktree.removed_repositories,
1441            };
1442            for update in proto::split_worktree_update(message) {
1443                session.peer.send(session.connection_id, update.clone())?;
1444            }
1445
1446            // Stream this worktree's diagnostics.
1447            for summary in worktree.diagnostic_summaries {
1448                session.peer.send(
1449                    session.connection_id,
1450                    proto::UpdateDiagnosticSummary {
1451                        project_id: project.id.to_proto(),
1452                        worktree_id: worktree.id,
1453                        summary: Some(summary),
1454                    },
1455                )?;
1456            }
1457
1458            for settings_file in worktree.settings_files {
1459                session.peer.send(
1460                    session.connection_id,
1461                    proto::UpdateWorktreeSettings {
1462                        project_id: project.id.to_proto(),
1463                        worktree_id: worktree.id,
1464                        path: settings_file.path,
1465                        content: Some(settings_file.content),
1466                        kind: Some(settings_file.kind.to_proto().into()),
1467                    },
1468                )?;
1469            }
1470        }
1471
1472        for language_server in &project.language_servers {
1473            session.peer.send(
1474                session.connection_id,
1475                proto::UpdateLanguageServer {
1476                    project_id: project.id.to_proto(),
1477                    language_server_id: language_server.id,
1478                    variant: Some(
1479                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1480                            proto::LspDiskBasedDiagnosticsUpdated {},
1481                        ),
1482                    ),
1483                },
1484            )?;
1485        }
1486    }
1487    Ok(())
1488}
1489
1490/// leave room disconnects from the room.
1491async fn leave_room(
1492    _: proto::LeaveRoom,
1493    response: Response<proto::LeaveRoom>,
1494    session: Session,
1495) -> Result<()> {
1496    leave_room_for_session(&session, session.connection_id).await?;
1497    response.send(proto::Ack {})?;
1498    Ok(())
1499}
1500
1501/// Updates the permissions of someone else in the room.
1502async fn set_room_participant_role(
1503    request: proto::SetRoomParticipantRole,
1504    response: Response<proto::SetRoomParticipantRole>,
1505    session: Session,
1506) -> Result<()> {
1507    let user_id = UserId::from_proto(request.user_id);
1508    let role = ChannelRole::from(request.role());
1509
1510    let (livekit_room, can_publish) = {
1511        let room = session
1512            .db()
1513            .await
1514            .set_room_participant_role(
1515                session.user_id(),
1516                RoomId::from_proto(request.room_id),
1517                user_id,
1518                role,
1519            )
1520            .await?;
1521
1522        let livekit_room = room.livekit_room.clone();
1523        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1524        room_updated(&room, &session.peer);
1525        (livekit_room, can_publish)
1526    };
1527
1528    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1529        live_kit
1530            .update_participant(
1531                livekit_room.clone(),
1532                request.user_id.to_string(),
1533                livekit_server::proto::ParticipantPermission {
1534                    can_subscribe: true,
1535                    can_publish,
1536                    can_publish_data: can_publish,
1537                    hidden: false,
1538                    recorder: false,
1539                },
1540            )
1541            .await
1542            .trace_err();
1543    }
1544
1545    response.send(proto::Ack {})?;
1546    Ok(())
1547}
1548
1549/// Call someone else into the current room
1550async fn call(
1551    request: proto::Call,
1552    response: Response<proto::Call>,
1553    session: Session,
1554) -> Result<()> {
1555    let room_id = RoomId::from_proto(request.room_id);
1556    let calling_user_id = session.user_id();
1557    let calling_connection_id = session.connection_id;
1558    let called_user_id = UserId::from_proto(request.called_user_id);
1559    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1560    if !session
1561        .db()
1562        .await
1563        .has_contact(calling_user_id, called_user_id)
1564        .await?
1565    {
1566        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1567    }
1568
1569    let incoming_call = {
1570        let (room, incoming_call) = &mut *session
1571            .db()
1572            .await
1573            .call(
1574                room_id,
1575                calling_user_id,
1576                calling_connection_id,
1577                called_user_id,
1578                initial_project_id,
1579            )
1580            .await?;
1581        room_updated(room, &session.peer);
1582        mem::take(incoming_call)
1583    };
1584    update_user_contacts(called_user_id, &session).await?;
1585
1586    let mut calls = session
1587        .connection_pool()
1588        .await
1589        .user_connection_ids(called_user_id)
1590        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1591        .collect::<FuturesUnordered<_>>();
1592
1593    while let Some(call_response) = calls.next().await {
1594        match call_response.as_ref() {
1595            Ok(_) => {
1596                response.send(proto::Ack {})?;
1597                return Ok(());
1598            }
1599            Err(_) => {
1600                call_response.trace_err();
1601            }
1602        }
1603    }
1604
1605    {
1606        let room = session
1607            .db()
1608            .await
1609            .call_failed(room_id, called_user_id)
1610            .await?;
1611        room_updated(&room, &session.peer);
1612    }
1613    update_user_contacts(called_user_id, &session).await?;
1614
1615    Err(anyhow!("failed to ring user"))?
1616}
1617
1618/// Cancel an outgoing call.
1619async fn cancel_call(
1620    request: proto::CancelCall,
1621    response: Response<proto::CancelCall>,
1622    session: Session,
1623) -> Result<()> {
1624    let called_user_id = UserId::from_proto(request.called_user_id);
1625    let room_id = RoomId::from_proto(request.room_id);
1626    {
1627        let room = session
1628            .db()
1629            .await
1630            .cancel_call(room_id, session.connection_id, called_user_id)
1631            .await?;
1632        room_updated(&room, &session.peer);
1633    }
1634
1635    for connection_id in session
1636        .connection_pool()
1637        .await
1638        .user_connection_ids(called_user_id)
1639    {
1640        session
1641            .peer
1642            .send(
1643                connection_id,
1644                proto::CallCanceled {
1645                    room_id: room_id.to_proto(),
1646                },
1647            )
1648            .trace_err();
1649    }
1650    response.send(proto::Ack {})?;
1651
1652    update_user_contacts(called_user_id, &session).await?;
1653    Ok(())
1654}
1655
1656/// Decline an incoming call.
1657async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1658    let room_id = RoomId::from_proto(message.room_id);
1659    {
1660        let room = session
1661            .db()
1662            .await
1663            .decline_call(Some(room_id), session.user_id())
1664            .await?
1665            .ok_or_else(|| anyhow!("failed to decline call"))?;
1666        room_updated(&room, &session.peer);
1667    }
1668
1669    for connection_id in session
1670        .connection_pool()
1671        .await
1672        .user_connection_ids(session.user_id())
1673    {
1674        session
1675            .peer
1676            .send(
1677                connection_id,
1678                proto::CallCanceled {
1679                    room_id: room_id.to_proto(),
1680                },
1681            )
1682            .trace_err();
1683    }
1684    update_user_contacts(session.user_id(), &session).await?;
1685    Ok(())
1686}
1687
1688/// Updates other participants in the room with your current location.
1689async fn update_participant_location(
1690    request: proto::UpdateParticipantLocation,
1691    response: Response<proto::UpdateParticipantLocation>,
1692    session: Session,
1693) -> Result<()> {
1694    let room_id = RoomId::from_proto(request.room_id);
1695    let location = request
1696        .location
1697        .ok_or_else(|| anyhow!("invalid location"))?;
1698
1699    let db = session.db().await;
1700    let room = db
1701        .update_room_participant_location(room_id, session.connection_id, location)
1702        .await?;
1703
1704    room_updated(&room, &session.peer);
1705    response.send(proto::Ack {})?;
1706    Ok(())
1707}
1708
1709/// Share a project into the room.
1710async fn share_project(
1711    request: proto::ShareProject,
1712    response: Response<proto::ShareProject>,
1713    session: Session,
1714) -> Result<()> {
1715    let (project_id, room) = &*session
1716        .db()
1717        .await
1718        .share_project(
1719            RoomId::from_proto(request.room_id),
1720            session.connection_id,
1721            &request.worktrees,
1722            request.is_ssh_project,
1723        )
1724        .await?;
1725    response.send(proto::ShareProjectResponse {
1726        project_id: project_id.to_proto(),
1727    })?;
1728    room_updated(room, &session.peer);
1729
1730    Ok(())
1731}
1732
1733/// Unshare a project from the room.
1734async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1735    let project_id = ProjectId::from_proto(message.project_id);
1736    unshare_project_internal(project_id, session.connection_id, &session).await
1737}
1738
1739async fn unshare_project_internal(
1740    project_id: ProjectId,
1741    connection_id: ConnectionId,
1742    session: &Session,
1743) -> Result<()> {
1744    let delete = {
1745        let room_guard = session
1746            .db()
1747            .await
1748            .unshare_project(project_id, connection_id)
1749            .await?;
1750
1751        let (delete, room, guest_connection_ids) = &*room_guard;
1752
1753        let message = proto::UnshareProject {
1754            project_id: project_id.to_proto(),
1755        };
1756
1757        broadcast(
1758            Some(connection_id),
1759            guest_connection_ids.iter().copied(),
1760            |conn_id| session.peer.send(conn_id, message.clone()),
1761        );
1762        if let Some(room) = room {
1763            room_updated(room, &session.peer);
1764        }
1765
1766        *delete
1767    };
1768
1769    if delete {
1770        let db = session.db().await;
1771        db.delete_project(project_id).await?;
1772    }
1773
1774    Ok(())
1775}
1776
1777/// Join someone elses shared project.
1778async fn join_project(
1779    request: proto::JoinProject,
1780    response: Response<proto::JoinProject>,
1781    session: Session,
1782) -> Result<()> {
1783    let project_id = ProjectId::from_proto(request.project_id);
1784
1785    tracing::info!(%project_id, "join project");
1786
1787    let db = session.db().await;
1788    let (project, replica_id) = &mut *db
1789        .join_project(project_id, session.connection_id, session.user_id())
1790        .await?;
1791    drop(db);
1792    tracing::info!(%project_id, "join remote project");
1793    join_project_internal(response, session, project, replica_id)
1794}
1795
/// Destination for the `JoinProjectResponse` produced by [`join_project_internal`].
///
/// Abstracting over the response type lets `join_project_internal` accept any
/// sink, not just a `proto::JoinProject` RPC response. NOTE(review): only the
/// `Response<proto::JoinProject>` impl is visible in this chunk; confirm whether
/// other implementors exist.
trait JoinProjectInternalResponse {
    /// Delivers `result` to the joining client, consuming the sink.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Sends the join result as the reply to a `JoinProject` RPC.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s own (inherent) `send`, spelled out with an
        // explicit path so it reads unambiguously next to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1804
1805fn join_project_internal(
1806    response: impl JoinProjectInternalResponse,
1807    session: Session,
1808    project: &mut Project,
1809    replica_id: &ReplicaId,
1810) -> Result<()> {
1811    let collaborators = project
1812        .collaborators
1813        .iter()
1814        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1815        .map(|collaborator| collaborator.to_proto())
1816        .collect::<Vec<_>>();
1817    let project_id = project.id;
1818    let guest_user_id = session.user_id();
1819
1820    let worktrees = project
1821        .worktrees
1822        .iter()
1823        .map(|(id, worktree)| proto::WorktreeMetadata {
1824            id: *id,
1825            root_name: worktree.root_name.clone(),
1826            visible: worktree.visible,
1827            abs_path: worktree.abs_path.clone(),
1828        })
1829        .collect::<Vec<_>>();
1830
1831    let add_project_collaborator = proto::AddProjectCollaborator {
1832        project_id: project_id.to_proto(),
1833        collaborator: Some(proto::Collaborator {
1834            peer_id: Some(session.connection_id.into()),
1835            replica_id: replica_id.0 as u32,
1836            user_id: guest_user_id.to_proto(),
1837            is_host: false,
1838        }),
1839    };
1840
1841    for collaborator in &collaborators {
1842        session
1843            .peer
1844            .send(
1845                collaborator.peer_id.unwrap().into(),
1846                add_project_collaborator.clone(),
1847            )
1848            .trace_err();
1849    }
1850
1851    // First, we send the metadata associated with each worktree.
1852    response.send(proto::JoinProjectResponse {
1853        project_id: project.id.0 as u64,
1854        worktrees: worktrees.clone(),
1855        replica_id: replica_id.0 as u32,
1856        collaborators: collaborators.clone(),
1857        language_servers: project.language_servers.clone(),
1858        role: project.role.into(),
1859    })?;
1860
1861    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1862        // Stream this worktree's entries.
1863        let message = proto::UpdateWorktree {
1864            project_id: project_id.to_proto(),
1865            worktree_id,
1866            abs_path: worktree.abs_path.clone(),
1867            root_name: worktree.root_name,
1868            updated_entries: worktree.entries,
1869            removed_entries: Default::default(),
1870            scan_id: worktree.scan_id,
1871            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1872            updated_repositories: worktree.repository_entries.into_values().collect(),
1873            removed_repositories: Default::default(),
1874        };
1875        for update in proto::split_worktree_update(message) {
1876            session.peer.send(session.connection_id, update.clone())?;
1877        }
1878
1879        // Stream this worktree's diagnostics.
1880        for summary in worktree.diagnostic_summaries {
1881            session.peer.send(
1882                session.connection_id,
1883                proto::UpdateDiagnosticSummary {
1884                    project_id: project_id.to_proto(),
1885                    worktree_id: worktree.id,
1886                    summary: Some(summary),
1887                },
1888            )?;
1889        }
1890
1891        for settings_file in worktree.settings_files {
1892            session.peer.send(
1893                session.connection_id,
1894                proto::UpdateWorktreeSettings {
1895                    project_id: project_id.to_proto(),
1896                    worktree_id: worktree.id,
1897                    path: settings_file.path,
1898                    content: Some(settings_file.content),
1899                    kind: Some(settings_file.kind.to_proto() as i32),
1900                },
1901            )?;
1902        }
1903    }
1904
1905    for language_server in &project.language_servers {
1906        session.peer.send(
1907            session.connection_id,
1908            proto::UpdateLanguageServer {
1909                project_id: project_id.to_proto(),
1910                language_server_id: language_server.id,
1911                variant: Some(
1912                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1913                        proto::LspDiskBasedDiagnosticsUpdated {},
1914                    ),
1915                ),
1916            },
1917        )?;
1918    }
1919
1920    Ok(())
1921}
1922
1923/// Leave someone elses shared project.
1924async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1925    let sender_id = session.connection_id;
1926    let project_id = ProjectId::from_proto(request.project_id);
1927    let db = session.db().await;
1928
1929    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1930    tracing::info!(
1931        %project_id,
1932        "leave project"
1933    );
1934
1935    project_left(project, &session);
1936    if let Some(room) = room {
1937        room_updated(room, &session.peer);
1938    }
1939
1940    Ok(())
1941}
1942
1943/// Updates other participants with changes to the project
1944async fn update_project(
1945    request: proto::UpdateProject,
1946    response: Response<proto::UpdateProject>,
1947    session: Session,
1948) -> Result<()> {
1949    let project_id = ProjectId::from_proto(request.project_id);
1950    let (room, guest_connection_ids) = &*session
1951        .db()
1952        .await
1953        .update_project(project_id, session.connection_id, &request.worktrees)
1954        .await?;
1955    broadcast(
1956        Some(session.connection_id),
1957        guest_connection_ids.iter().copied(),
1958        |connection_id| {
1959            session
1960                .peer
1961                .forward_send(session.connection_id, connection_id, request.clone())
1962        },
1963    );
1964    if let Some(room) = room {
1965        room_updated(room, &session.peer);
1966    }
1967    response.send(proto::Ack {})?;
1968
1969    Ok(())
1970}
1971
1972/// Updates other participants with changes to the worktree
1973async fn update_worktree(
1974    request: proto::UpdateWorktree,
1975    response: Response<proto::UpdateWorktree>,
1976    session: Session,
1977) -> Result<()> {
1978    let guest_connection_ids = session
1979        .db()
1980        .await
1981        .update_worktree(&request, session.connection_id)
1982        .await?;
1983
1984    broadcast(
1985        Some(session.connection_id),
1986        guest_connection_ids.iter().copied(),
1987        |connection_id| {
1988            session
1989                .peer
1990                .forward_send(session.connection_id, connection_id, request.clone())
1991        },
1992    );
1993    response.send(proto::Ack {})?;
1994    Ok(())
1995}
1996
1997/// Updates other participants with changes to the diagnostics
1998async fn update_diagnostic_summary(
1999    message: proto::UpdateDiagnosticSummary,
2000    session: Session,
2001) -> Result<()> {
2002    let guest_connection_ids = session
2003        .db()
2004        .await
2005        .update_diagnostic_summary(&message, session.connection_id)
2006        .await?;
2007
2008    broadcast(
2009        Some(session.connection_id),
2010        guest_connection_ids.iter().copied(),
2011        |connection_id| {
2012            session
2013                .peer
2014                .forward_send(session.connection_id, connection_id, message.clone())
2015        },
2016    );
2017
2018    Ok(())
2019}
2020
2021/// Updates other participants with changes to the worktree settings
2022async fn update_worktree_settings(
2023    message: proto::UpdateWorktreeSettings,
2024    session: Session,
2025) -> Result<()> {
2026    let guest_connection_ids = session
2027        .db()
2028        .await
2029        .update_worktree_settings(&message, session.connection_id)
2030        .await?;
2031
2032    broadcast(
2033        Some(session.connection_id),
2034        guest_connection_ids.iter().copied(),
2035        |connection_id| {
2036            session
2037                .peer
2038                .forward_send(session.connection_id, connection_id, message.clone())
2039        },
2040    );
2041
2042    Ok(())
2043}
2044
2045/// Notify other participants that a  language server has started.
2046async fn start_language_server(
2047    request: proto::StartLanguageServer,
2048    session: Session,
2049) -> Result<()> {
2050    let guest_connection_ids = session
2051        .db()
2052        .await
2053        .start_language_server(&request, session.connection_id)
2054        .await?;
2055
2056    broadcast(
2057        Some(session.connection_id),
2058        guest_connection_ids.iter().copied(),
2059        |connection_id| {
2060            session
2061                .peer
2062                .forward_send(session.connection_id, connection_id, request.clone())
2063        },
2064    );
2065    Ok(())
2066}
2067
2068/// Notify other participants that a language server has changed.
2069async fn update_language_server(
2070    request: proto::UpdateLanguageServer,
2071    session: Session,
2072) -> Result<()> {
2073    let project_id = ProjectId::from_proto(request.project_id);
2074    let project_connection_ids = session
2075        .db()
2076        .await
2077        .project_connection_ids(project_id, session.connection_id, true)
2078        .await?;
2079    broadcast(
2080        Some(session.connection_id),
2081        project_connection_ids.iter().copied(),
2082        |connection_id| {
2083            session
2084                .peer
2085                .forward_send(session.connection_id, connection_id, request.clone())
2086        },
2087    );
2088    Ok(())
2089}
2090
2091/// forward a project request to the host. These requests should be read only
2092/// as guests are allowed to send them.
2093async fn forward_read_only_project_request<T>(
2094    request: T,
2095    response: Response<T>,
2096    session: Session,
2097) -> Result<()>
2098where
2099    T: EntityMessage + RequestMessage,
2100{
2101    let project_id = ProjectId::from_proto(request.remote_entity_id());
2102    let host_connection_id = session
2103        .db()
2104        .await
2105        .host_for_read_only_project_request(project_id, session.connection_id)
2106        .await?;
2107    let payload = session
2108        .peer
2109        .forward_request(session.connection_id, host_connection_id, request)
2110        .await?;
2111    response.send(payload)?;
2112    Ok(())
2113}
2114
2115async fn forward_find_search_candidates_request(
2116    request: proto::FindSearchCandidates,
2117    response: Response<proto::FindSearchCandidates>,
2118    session: Session,
2119) -> Result<()> {
2120    let project_id = ProjectId::from_proto(request.remote_entity_id());
2121    let host_connection_id = session
2122        .db()
2123        .await
2124        .host_for_read_only_project_request(project_id, session.connection_id)
2125        .await?;
2126    let payload = session
2127        .peer
2128        .forward_request(session.connection_id, host_connection_id, request)
2129        .await?;
2130    response.send(payload)?;
2131    Ok(())
2132}
2133
2134/// forward a project request to the host. These requests are disallowed
2135/// for guests.
2136async fn forward_mutating_project_request<T>(
2137    request: T,
2138    response: Response<T>,
2139    session: Session,
2140) -> Result<()>
2141where
2142    T: EntityMessage + RequestMessage,
2143{
2144    let project_id = ProjectId::from_proto(request.remote_entity_id());
2145
2146    let host_connection_id = session
2147        .db()
2148        .await
2149        .host_for_mutating_project_request(project_id, session.connection_id)
2150        .await?;
2151    let payload = session
2152        .peer
2153        .forward_request(session.connection_id, host_connection_id, request)
2154        .await?;
2155    response.send(payload)?;
2156    Ok(())
2157}
2158
2159/// Notify other participants that a new buffer has been created
2160async fn create_buffer_for_peer(
2161    request: proto::CreateBufferForPeer,
2162    session: Session,
2163) -> Result<()> {
2164    session
2165        .db()
2166        .await
2167        .check_user_is_project_host(
2168            ProjectId::from_proto(request.project_id),
2169            session.connection_id,
2170        )
2171        .await?;
2172    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2173    session
2174        .peer
2175        .forward_send(session.connection_id, peer_id.into(), request)?;
2176    Ok(())
2177}
2178
2179/// Notify other participants that a buffer has been updated. This is
2180/// allowed for guests as long as the update is limited to selections.
2181async fn update_buffer(
2182    request: proto::UpdateBuffer,
2183    response: Response<proto::UpdateBuffer>,
2184    session: Session,
2185) -> Result<()> {
2186    let project_id = ProjectId::from_proto(request.project_id);
2187    let mut capability = Capability::ReadOnly;
2188
2189    for op in request.operations.iter() {
2190        match op.variant {
2191            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2192            Some(_) => capability = Capability::ReadWrite,
2193        }
2194    }
2195
2196    let host = {
2197        let guard = session
2198            .db()
2199            .await
2200            .connections_for_buffer_update(project_id, session.connection_id, capability)
2201            .await?;
2202
2203        let (host, guests) = &*guard;
2204
2205        broadcast(
2206            Some(session.connection_id),
2207            guests.clone(),
2208            |connection_id| {
2209                session
2210                    .peer
2211                    .forward_send(session.connection_id, connection_id, request.clone())
2212            },
2213        );
2214
2215        *host
2216    };
2217
2218    if host != session.connection_id {
2219        session
2220            .peer
2221            .forward_request(session.connection_id, host, request.clone())
2222            .await?;
2223    }
2224
2225    response.send(proto::Ack {})?;
2226    Ok(())
2227}
2228
2229async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2230    let project_id = ProjectId::from_proto(message.project_id);
2231
2232    let operation = message.operation.as_ref().context("invalid operation")?;
2233    let capability = match operation.variant.as_ref() {
2234        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2235            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2236                match buffer_op.variant {
2237                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2238                        Capability::ReadOnly
2239                    }
2240                    _ => Capability::ReadWrite,
2241                }
2242            } else {
2243                Capability::ReadWrite
2244            }
2245        }
2246        Some(_) => Capability::ReadWrite,
2247        None => Capability::ReadOnly,
2248    };
2249
2250    let guard = session
2251        .db()
2252        .await
2253        .connections_for_buffer_update(project_id, session.connection_id, capability)
2254        .await?;
2255
2256    let (host, guests) = &*guard;
2257
2258    broadcast(
2259        Some(session.connection_id),
2260        guests.iter().chain([host]).copied(),
2261        |connection_id| {
2262            session
2263                .peer
2264                .forward_send(session.connection_id, connection_id, message.clone())
2265        },
2266    );
2267
2268    Ok(())
2269}
2270
2271/// Notify other participants that a project has been updated.
2272async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2273    request: T,
2274    session: Session,
2275) -> Result<()> {
2276    let project_id = ProjectId::from_proto(request.remote_entity_id());
2277    let project_connection_ids = session
2278        .db()
2279        .await
2280        .project_connection_ids(project_id, session.connection_id, false)
2281        .await?;
2282
2283    broadcast(
2284        Some(session.connection_id),
2285        project_connection_ids.iter().copied(),
2286        |connection_id| {
2287            session
2288                .peer
2289                .forward_send(session.connection_id, connection_id, request.clone())
2290        },
2291    );
2292    Ok(())
2293}
2294
2295/// Start following another user in a call.
2296async fn follow(
2297    request: proto::Follow,
2298    response: Response<proto::Follow>,
2299    session: Session,
2300) -> Result<()> {
2301    let room_id = RoomId::from_proto(request.room_id);
2302    let project_id = request.project_id.map(ProjectId::from_proto);
2303    let leader_id = request
2304        .leader_id
2305        .ok_or_else(|| anyhow!("invalid leader id"))?
2306        .into();
2307    let follower_id = session.connection_id;
2308
2309    session
2310        .db()
2311        .await
2312        .check_room_participants(room_id, leader_id, session.connection_id)
2313        .await?;
2314
2315    let response_payload = session
2316        .peer
2317        .forward_request(session.connection_id, leader_id, request)
2318        .await?;
2319    response.send(response_payload)?;
2320
2321    if let Some(project_id) = project_id {
2322        let room = session
2323            .db()
2324            .await
2325            .follow(room_id, project_id, leader_id, follower_id)
2326            .await?;
2327        room_updated(&room, &session.peer);
2328    }
2329
2330    Ok(())
2331}
2332
2333/// Stop following another user in a call.
2334async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2335    let room_id = RoomId::from_proto(request.room_id);
2336    let project_id = request.project_id.map(ProjectId::from_proto);
2337    let leader_id = request
2338        .leader_id
2339        .ok_or_else(|| anyhow!("invalid leader id"))?
2340        .into();
2341    let follower_id = session.connection_id;
2342
2343    session
2344        .db()
2345        .await
2346        .check_room_participants(room_id, leader_id, session.connection_id)
2347        .await?;
2348
2349    session
2350        .peer
2351        .forward_send(session.connection_id, leader_id, request)?;
2352
2353    if let Some(project_id) = project_id {
2354        let room = session
2355            .db()
2356            .await
2357            .unfollow(room_id, project_id, leader_id, follower_id)
2358            .await?;
2359        room_updated(&room, &session.peer);
2360    }
2361
2362    Ok(())
2363}
2364
2365/// Notify everyone following you of your current location.
2366async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2367    let room_id = RoomId::from_proto(request.room_id);
2368    let database = session.db.lock().await;
2369
2370    let connection_ids = if let Some(project_id) = request.project_id {
2371        let project_id = ProjectId::from_proto(project_id);
2372        database
2373            .project_connection_ids(project_id, session.connection_id, true)
2374            .await?
2375    } else {
2376        database
2377            .room_connection_ids(room_id, session.connection_id)
2378            .await?
2379    };
2380
2381    // For now, don't send view update messages back to that view's current leader.
2382    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2383        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2384        _ => None,
2385    });
2386
2387    for connection_id in connection_ids.iter().cloned() {
2388        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2389            session
2390                .peer
2391                .forward_send(session.connection_id, connection_id, request.clone())?;
2392        }
2393    }
2394    Ok(())
2395}
2396
2397/// Get public data about users.
2398async fn get_users(
2399    request: proto::GetUsers,
2400    response: Response<proto::GetUsers>,
2401    session: Session,
2402) -> Result<()> {
2403    let user_ids = request
2404        .user_ids
2405        .into_iter()
2406        .map(UserId::from_proto)
2407        .collect();
2408    let users = session
2409        .db()
2410        .await
2411        .get_users_by_ids(user_ids)
2412        .await?
2413        .into_iter()
2414        .map(|user| proto::User {
2415            id: user.id.to_proto(),
2416            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2417            github_login: user.github_login,
2418        })
2419        .collect();
2420    response.send(proto::UsersResponse { users })?;
2421    Ok(())
2422}
2423
2424/// Search for users (to invite) buy Github login
2425async fn fuzzy_search_users(
2426    request: proto::FuzzySearchUsers,
2427    response: Response<proto::FuzzySearchUsers>,
2428    session: Session,
2429) -> Result<()> {
2430    let query = request.query;
2431    let users = match query.len() {
2432        0 => vec![],
2433        1 | 2 => session
2434            .db()
2435            .await
2436            .get_user_by_github_login(&query)
2437            .await?
2438            .into_iter()
2439            .collect(),
2440        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2441    };
2442    let users = users
2443        .into_iter()
2444        .filter(|user| user.id != session.user_id())
2445        .map(|user| proto::User {
2446            id: user.id.to_proto(),
2447            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2448            github_login: user.github_login,
2449        })
2450        .collect();
2451    response.send(proto::UsersResponse { users })?;
2452    Ok(())
2453}
2454
2455/// Send a contact request to another user.
2456async fn request_contact(
2457    request: proto::RequestContact,
2458    response: Response<proto::RequestContact>,
2459    session: Session,
2460) -> Result<()> {
2461    let requester_id = session.user_id();
2462    let responder_id = UserId::from_proto(request.responder_id);
2463    if requester_id == responder_id {
2464        return Err(anyhow!("cannot add yourself as a contact"))?;
2465    }
2466
2467    let notifications = session
2468        .db()
2469        .await
2470        .send_contact_request(requester_id, responder_id)
2471        .await?;
2472
2473    // Update outgoing contact requests of requester
2474    let mut update = proto::UpdateContacts::default();
2475    update.outgoing_requests.push(responder_id.to_proto());
2476    for connection_id in session
2477        .connection_pool()
2478        .await
2479        .user_connection_ids(requester_id)
2480    {
2481        session.peer.send(connection_id, update.clone())?;
2482    }
2483
2484    // Update incoming contact requests of responder
2485    let mut update = proto::UpdateContacts::default();
2486    update
2487        .incoming_requests
2488        .push(proto::IncomingContactRequest {
2489            requester_id: requester_id.to_proto(),
2490        });
2491    let connection_pool = session.connection_pool().await;
2492    for connection_id in connection_pool.user_connection_ids(responder_id) {
2493        session.peer.send(connection_id, update.clone())?;
2494    }
2495
2496    send_notifications(&connection_pool, &session.peer, notifications);
2497
2498    response.send(proto::Ack {})?;
2499    Ok(())
2500}
2501
2502/// Accept or decline a contact request
2503async fn respond_to_contact_request(
2504    request: proto::RespondToContactRequest,
2505    response: Response<proto::RespondToContactRequest>,
2506    session: Session,
2507) -> Result<()> {
2508    let responder_id = session.user_id();
2509    let requester_id = UserId::from_proto(request.requester_id);
2510    let db = session.db().await;
2511    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2512        db.dismiss_contact_notification(responder_id, requester_id)
2513            .await?;
2514    } else {
2515        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2516
2517        let notifications = db
2518            .respond_to_contact_request(responder_id, requester_id, accept)
2519            .await?;
2520        let requester_busy = db.is_user_busy(requester_id).await?;
2521        let responder_busy = db.is_user_busy(responder_id).await?;
2522
2523        let pool = session.connection_pool().await;
2524        // Update responder with new contact
2525        let mut update = proto::UpdateContacts::default();
2526        if accept {
2527            update
2528                .contacts
2529                .push(contact_for_user(requester_id, requester_busy, &pool));
2530        }
2531        update
2532            .remove_incoming_requests
2533            .push(requester_id.to_proto());
2534        for connection_id in pool.user_connection_ids(responder_id) {
2535            session.peer.send(connection_id, update.clone())?;
2536        }
2537
2538        // Update requester with new contact
2539        let mut update = proto::UpdateContacts::default();
2540        if accept {
2541            update
2542                .contacts
2543                .push(contact_for_user(responder_id, responder_busy, &pool));
2544        }
2545        update
2546            .remove_outgoing_requests
2547            .push(responder_id.to_proto());
2548
2549        for connection_id in pool.user_connection_ids(requester_id) {
2550            session.peer.send(connection_id, update.clone())?;
2551        }
2552
2553        send_notifications(&pool, &session.peer, notifications);
2554    }
2555
2556    response.send(proto::Ack {})?;
2557    Ok(())
2558}
2559
2560/// Remove a contact.
2561async fn remove_contact(
2562    request: proto::RemoveContact,
2563    response: Response<proto::RemoveContact>,
2564    session: Session,
2565) -> Result<()> {
2566    let requester_id = session.user_id();
2567    let responder_id = UserId::from_proto(request.user_id);
2568    let db = session.db().await;
2569    let (contact_accepted, deleted_notification_id) =
2570        db.remove_contact(requester_id, responder_id).await?;
2571
2572    let pool = session.connection_pool().await;
2573    // Update outgoing contact requests of requester
2574    let mut update = proto::UpdateContacts::default();
2575    if contact_accepted {
2576        update.remove_contacts.push(responder_id.to_proto());
2577    } else {
2578        update
2579            .remove_outgoing_requests
2580            .push(responder_id.to_proto());
2581    }
2582    for connection_id in pool.user_connection_ids(requester_id) {
2583        session.peer.send(connection_id, update.clone())?;
2584    }
2585
2586    // Update incoming contact requests of responder
2587    let mut update = proto::UpdateContacts::default();
2588    if contact_accepted {
2589        update.remove_contacts.push(requester_id.to_proto());
2590    } else {
2591        update
2592            .remove_incoming_requests
2593            .push(requester_id.to_proto());
2594    }
2595    for connection_id in pool.user_connection_ids(responder_id) {
2596        session.peer.send(connection_id, update.clone())?;
2597        if let Some(notification_id) = deleted_notification_id {
2598            session.peer.send(
2599                connection_id,
2600                proto::DeleteNotification {
2601                    notification_id: notification_id.to_proto(),
2602                },
2603            )?;
2604        }
2605    }
2606
2607    response.send(proto::Ack {})?;
2608    Ok(())
2609}
2610
2611fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2612    version.0.minor() < 139
2613}
2614
2615async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2616    let plan = session.current_plan(&session.db().await).await?;
2617
2618    session
2619        .peer
2620        .send(
2621            session.connection_id,
2622            proto::UpdateUserPlan { plan: plan.into() },
2623        )
2624        .trace_err();
2625
2626    Ok(())
2627}
2628
2629async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2630    subscribe_user_to_channels(session.user_id(), &session).await?;
2631    Ok(())
2632}
2633
2634async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2635    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2636    let mut pool = session.connection_pool().await;
2637    for membership in &channels_for_user.channel_memberships {
2638        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2639    }
2640    session.peer.send(
2641        session.connection_id,
2642        build_update_user_channels(&channels_for_user),
2643    )?;
2644    session.peer.send(
2645        session.connection_id,
2646        build_channels_update(channels_for_user),
2647    )?;
2648    Ok(())
2649}
2650
2651/// Creates a new channel.
2652async fn create_channel(
2653    request: proto::CreateChannel,
2654    response: Response<proto::CreateChannel>,
2655    session: Session,
2656) -> Result<()> {
2657    let db = session.db().await;
2658
2659    let parent_id = request.parent_id.map(ChannelId::from_proto);
2660    let (channel, membership) = db
2661        .create_channel(&request.name, parent_id, session.user_id())
2662        .await?;
2663
2664    let root_id = channel.root_id();
2665    let channel = Channel::from_model(channel);
2666
2667    response.send(proto::CreateChannelResponse {
2668        channel: Some(channel.to_proto()),
2669        parent_id: request.parent_id,
2670    })?;
2671
2672    let mut connection_pool = session.connection_pool().await;
2673    if let Some(membership) = membership {
2674        connection_pool.subscribe_to_channel(
2675            membership.user_id,
2676            membership.channel_id,
2677            membership.role,
2678        );
2679        let update = proto::UpdateUserChannels {
2680            channel_memberships: vec![proto::ChannelMembership {
2681                channel_id: membership.channel_id.to_proto(),
2682                role: membership.role.into(),
2683            }],
2684            ..Default::default()
2685        };
2686        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2687            session.peer.send(connection_id, update.clone())?;
2688        }
2689    }
2690
2691    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2692        if !role.can_see_channel(channel.visibility) {
2693            continue;
2694        }
2695
2696        let update = proto::UpdateChannels {
2697            channels: vec![channel.to_proto()],
2698            ..Default::default()
2699        };
2700        session.peer.send(connection_id, update.clone())?;
2701    }
2702
2703    Ok(())
2704}
2705
2706/// Delete a channel
2707async fn delete_channel(
2708    request: proto::DeleteChannel,
2709    response: Response<proto::DeleteChannel>,
2710    session: Session,
2711) -> Result<()> {
2712    let db = session.db().await;
2713
2714    let channel_id = request.channel_id;
2715    let (root_channel, removed_channels) = db
2716        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2717        .await?;
2718    response.send(proto::Ack {})?;
2719
2720    // Notify members of removed channels
2721    let mut update = proto::UpdateChannels::default();
2722    update
2723        .delete_channels
2724        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2725
2726    let connection_pool = session.connection_pool().await;
2727    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2728        session.peer.send(connection_id, update.clone())?;
2729    }
2730
2731    Ok(())
2732}
2733
2734/// Invite someone to join a channel.
2735async fn invite_channel_member(
2736    request: proto::InviteChannelMember,
2737    response: Response<proto::InviteChannelMember>,
2738    session: Session,
2739) -> Result<()> {
2740    let db = session.db().await;
2741    let channel_id = ChannelId::from_proto(request.channel_id);
2742    let invitee_id = UserId::from_proto(request.user_id);
2743    let InviteMemberResult {
2744        channel,
2745        notifications,
2746    } = db
2747        .invite_channel_member(
2748            channel_id,
2749            invitee_id,
2750            session.user_id(),
2751            request.role().into(),
2752        )
2753        .await?;
2754
2755    let update = proto::UpdateChannels {
2756        channel_invitations: vec![channel.to_proto()],
2757        ..Default::default()
2758    };
2759
2760    let connection_pool = session.connection_pool().await;
2761    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2762        session.peer.send(connection_id, update.clone())?;
2763    }
2764
2765    send_notifications(&connection_pool, &session.peer, notifications);
2766
2767    response.send(proto::Ack {})?;
2768    Ok(())
2769}
2770
2771/// remove someone from a channel
2772async fn remove_channel_member(
2773    request: proto::RemoveChannelMember,
2774    response: Response<proto::RemoveChannelMember>,
2775    session: Session,
2776) -> Result<()> {
2777    let db = session.db().await;
2778    let channel_id = ChannelId::from_proto(request.channel_id);
2779    let member_id = UserId::from_proto(request.user_id);
2780
2781    let RemoveChannelMemberResult {
2782        membership_update,
2783        notification_id,
2784    } = db
2785        .remove_channel_member(channel_id, member_id, session.user_id())
2786        .await?;
2787
2788    let mut connection_pool = session.connection_pool().await;
2789    notify_membership_updated(
2790        &mut connection_pool,
2791        membership_update,
2792        member_id,
2793        &session.peer,
2794    );
2795    for connection_id in connection_pool.user_connection_ids(member_id) {
2796        if let Some(notification_id) = notification_id {
2797            session
2798                .peer
2799                .send(
2800                    connection_id,
2801                    proto::DeleteNotification {
2802                        notification_id: notification_id.to_proto(),
2803                    },
2804                )
2805                .trace_err();
2806        }
2807    }
2808
2809    response.send(proto::Ack {})?;
2810    Ok(())
2811}
2812
2813/// Toggle the channel between public and private.
2814/// Care is taken to maintain the invariant that public channels only descend from public channels,
2815/// (though members-only channels can appear at any point in the hierarchy).
2816async fn set_channel_visibility(
2817    request: proto::SetChannelVisibility,
2818    response: Response<proto::SetChannelVisibility>,
2819    session: Session,
2820) -> Result<()> {
2821    let db = session.db().await;
2822    let channel_id = ChannelId::from_proto(request.channel_id);
2823    let visibility = request.visibility().into();
2824
2825    let channel_model = db
2826        .set_channel_visibility(channel_id, visibility, session.user_id())
2827        .await?;
2828    let root_id = channel_model.root_id();
2829    let channel = Channel::from_model(channel_model);
2830
2831    let mut connection_pool = session.connection_pool().await;
2832    for (user_id, role) in connection_pool
2833        .channel_user_ids(root_id)
2834        .collect::<Vec<_>>()
2835        .into_iter()
2836    {
2837        let update = if role.can_see_channel(channel.visibility) {
2838            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2839            proto::UpdateChannels {
2840                channels: vec![channel.to_proto()],
2841                ..Default::default()
2842            }
2843        } else {
2844            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2845            proto::UpdateChannels {
2846                delete_channels: vec![channel.id.to_proto()],
2847                ..Default::default()
2848            }
2849        };
2850
2851        for connection_id in connection_pool.user_connection_ids(user_id) {
2852            session.peer.send(connection_id, update.clone())?;
2853        }
2854    }
2855
2856    response.send(proto::Ack {})?;
2857    Ok(())
2858}
2859
2860/// Alter the role for a user in the channel.
2861async fn set_channel_member_role(
2862    request: proto::SetChannelMemberRole,
2863    response: Response<proto::SetChannelMemberRole>,
2864    session: Session,
2865) -> Result<()> {
2866    let db = session.db().await;
2867    let channel_id = ChannelId::from_proto(request.channel_id);
2868    let member_id = UserId::from_proto(request.user_id);
2869    let result = db
2870        .set_channel_member_role(
2871            channel_id,
2872            session.user_id(),
2873            member_id,
2874            request.role().into(),
2875        )
2876        .await?;
2877
2878    match result {
2879        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2880            let mut connection_pool = session.connection_pool().await;
2881            notify_membership_updated(
2882                &mut connection_pool,
2883                membership_update,
2884                member_id,
2885                &session.peer,
2886            )
2887        }
2888        db::SetMemberRoleResult::InviteUpdated(channel) => {
2889            let update = proto::UpdateChannels {
2890                channel_invitations: vec![channel.to_proto()],
2891                ..Default::default()
2892            };
2893
2894            for connection_id in session
2895                .connection_pool()
2896                .await
2897                .user_connection_ids(member_id)
2898            {
2899                session.peer.send(connection_id, update.clone())?;
2900            }
2901        }
2902    }
2903
2904    response.send(proto::Ack {})?;
2905    Ok(())
2906}
2907
2908/// Change the name of a channel
2909async fn rename_channel(
2910    request: proto::RenameChannel,
2911    response: Response<proto::RenameChannel>,
2912    session: Session,
2913) -> Result<()> {
2914    let db = session.db().await;
2915    let channel_id = ChannelId::from_proto(request.channel_id);
2916    let channel_model = db
2917        .rename_channel(channel_id, session.user_id(), &request.name)
2918        .await?;
2919    let root_id = channel_model.root_id();
2920    let channel = Channel::from_model(channel_model);
2921
2922    response.send(proto::RenameChannelResponse {
2923        channel: Some(channel.to_proto()),
2924    })?;
2925
2926    let connection_pool = session.connection_pool().await;
2927    let update = proto::UpdateChannels {
2928        channels: vec![channel.to_proto()],
2929        ..Default::default()
2930    };
2931    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2932        if role.can_see_channel(channel.visibility) {
2933            session.peer.send(connection_id, update.clone())?;
2934        }
2935    }
2936
2937    Ok(())
2938}
2939
2940/// Move a channel to a new parent.
2941async fn move_channel(
2942    request: proto::MoveChannel,
2943    response: Response<proto::MoveChannel>,
2944    session: Session,
2945) -> Result<()> {
2946    let channel_id = ChannelId::from_proto(request.channel_id);
2947    let to = ChannelId::from_proto(request.to);
2948
2949    let (root_id, channels) = session
2950        .db()
2951        .await
2952        .move_channel(channel_id, to, session.user_id())
2953        .await?;
2954
2955    let connection_pool = session.connection_pool().await;
2956    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2957        let channels = channels
2958            .iter()
2959            .filter_map(|channel| {
2960                if role.can_see_channel(channel.visibility) {
2961                    Some(channel.to_proto())
2962                } else {
2963                    None
2964                }
2965            })
2966            .collect::<Vec<_>>();
2967        if channels.is_empty() {
2968            continue;
2969        }
2970
2971        let update = proto::UpdateChannels {
2972            channels,
2973            ..Default::default()
2974        };
2975
2976        session.peer.send(connection_id, update.clone())?;
2977    }
2978
2979    response.send(Ack {})?;
2980    Ok(())
2981}
2982
2983/// Get the list of channel members
2984async fn get_channel_members(
2985    request: proto::GetChannelMembers,
2986    response: Response<proto::GetChannelMembers>,
2987    session: Session,
2988) -> Result<()> {
2989    let db = session.db().await;
2990    let channel_id = ChannelId::from_proto(request.channel_id);
2991    let limit = if request.limit == 0 {
2992        u16::MAX as u64
2993    } else {
2994        request.limit
2995    };
2996    let (members, users) = db
2997        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
2998        .await?;
2999    response.send(proto::GetChannelMembersResponse { members, users })?;
3000    Ok(())
3001}
3002
3003/// Accept or decline a channel invitation.
3004async fn respond_to_channel_invite(
3005    request: proto::RespondToChannelInvite,
3006    response: Response<proto::RespondToChannelInvite>,
3007    session: Session,
3008) -> Result<()> {
3009    let db = session.db().await;
3010    let channel_id = ChannelId::from_proto(request.channel_id);
3011    let RespondToChannelInvite {
3012        membership_update,
3013        notifications,
3014    } = db
3015        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3016        .await?;
3017
3018    let mut connection_pool = session.connection_pool().await;
3019    if let Some(membership_update) = membership_update {
3020        notify_membership_updated(
3021            &mut connection_pool,
3022            membership_update,
3023            session.user_id(),
3024            &session.peer,
3025        );
3026    } else {
3027        let update = proto::UpdateChannels {
3028            remove_channel_invitations: vec![channel_id.to_proto()],
3029            ..Default::default()
3030        };
3031
3032        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3033            session.peer.send(connection_id, update.clone())?;
3034        }
3035    };
3036
3037    send_notifications(&connection_pool, &session.peer, notifications);
3038
3039    response.send(proto::Ack {})?;
3040
3041    Ok(())
3042}
3043
3044/// Join the channels' room
3045async fn join_channel(
3046    request: proto::JoinChannel,
3047    response: Response<proto::JoinChannel>,
3048    session: Session,
3049) -> Result<()> {
3050    let channel_id = ChannelId::from_proto(request.channel_id);
3051    join_channel_internal(channel_id, Box::new(response), session).await
3052}
3053
/// Reply target for `join_channel_internal`.
///
/// Both `JoinChannel` and `JoinRoom` requests are answered with a
/// `proto::JoinRoomResponse`. Implementing this trait for both response types
/// lets `join_channel_internal` accept either one.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call so this delegates to `Response`'s own `send`.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call so this delegates to `Response`'s own `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3067
/// Joins the room belonging to `channel_id` on behalf of the session's user,
/// replying through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed via `leave_room_for_session`.
/// 2. The database join yields the joined room, an optional membership
///    update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, a token is minted. Guests get a
///    `guest_token` with `can_publish: false`; every other role gets a
///    `room_token` with `can_publish: true`.
/// 4. The reply is sent, any membership update is passed to
///    `notify_membership_updated`, and `room_updated` is called for the room.
/// 5. After the database and pool guards from the inner block are dropped,
///    `channel_updated` and `update_user_contacts` run.
///
/// # Errors
/// Fails if a database call fails, if leaving the stale room fails, if sending
/// the reply fails, or if the database did not return a channel for the room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then reacquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token
        // fails (the error goes through `trace_err`). In either case the join
        // itself still proceeds.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may not publish; all other roles may.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The database and pool guards from the block above are dropped by now;
    // a fresh pool guard is taken for the channel broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3163
3164/// Start editing the channel notes
3165async fn join_channel_buffer(
3166    request: proto::JoinChannelBuffer,
3167    response: Response<proto::JoinChannelBuffer>,
3168    session: Session,
3169) -> Result<()> {
3170    let db = session.db().await;
3171    let channel_id = ChannelId::from_proto(request.channel_id);
3172
3173    let open_response = db
3174        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3175        .await?;
3176
3177    let collaborators = open_response.collaborators.clone();
3178    response.send(open_response)?;
3179
3180    let update = UpdateChannelBufferCollaborators {
3181        channel_id: channel_id.to_proto(),
3182        collaborators: collaborators.clone(),
3183    };
3184    channel_buffer_updated(
3185        session.connection_id,
3186        collaborators
3187            .iter()
3188            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3189        &update,
3190        &session.peer,
3191    );
3192
3193    Ok(())
3194}
3195
3196/// Edit the channel notes
3197async fn update_channel_buffer(
3198    request: proto::UpdateChannelBuffer,
3199    session: Session,
3200) -> Result<()> {
3201    let db = session.db().await;
3202    let channel_id = ChannelId::from_proto(request.channel_id);
3203
3204    let (collaborators, epoch, version) = db
3205        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3206        .await?;
3207
3208    channel_buffer_updated(
3209        session.connection_id,
3210        collaborators.clone(),
3211        &proto::UpdateChannelBuffer {
3212            channel_id: channel_id.to_proto(),
3213            operations: request.operations,
3214        },
3215        &session.peer,
3216    );
3217
3218    let pool = &*session.connection_pool().await;
3219
3220    let non_collaborators =
3221        pool.channel_connection_ids(channel_id)
3222            .filter_map(|(connection_id, _)| {
3223                if collaborators.contains(&connection_id) {
3224                    None
3225                } else {
3226                    Some(connection_id)
3227                }
3228            });
3229
3230    broadcast(None, non_collaborators, |peer_id| {
3231        session.peer.send(
3232            peer_id,
3233            proto::UpdateChannels {
3234                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3235                    channel_id: channel_id.to_proto(),
3236                    epoch: epoch as u64,
3237                    version: version.clone(),
3238                }],
3239                ..Default::default()
3240            },
3241        )
3242    });
3243
3244    Ok(())
3245}
3246
3247/// Rejoin the channel notes after a connection blip
3248async fn rejoin_channel_buffers(
3249    request: proto::RejoinChannelBuffers,
3250    response: Response<proto::RejoinChannelBuffers>,
3251    session: Session,
3252) -> Result<()> {
3253    let db = session.db().await;
3254    let buffers = db
3255        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3256        .await?;
3257
3258    for rejoined_buffer in &buffers {
3259        let collaborators_to_notify = rejoined_buffer
3260            .buffer
3261            .collaborators
3262            .iter()
3263            .filter_map(|c| Some(c.peer_id?.into()));
3264        channel_buffer_updated(
3265            session.connection_id,
3266            collaborators_to_notify,
3267            &proto::UpdateChannelBufferCollaborators {
3268                channel_id: rejoined_buffer.buffer.channel_id,
3269                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3270            },
3271            &session.peer,
3272        );
3273    }
3274
3275    response.send(proto::RejoinChannelBuffersResponse {
3276        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3277    })?;
3278
3279    Ok(())
3280}
3281
3282/// Stop editing the channel notes
3283async fn leave_channel_buffer(
3284    request: proto::LeaveChannelBuffer,
3285    response: Response<proto::LeaveChannelBuffer>,
3286    session: Session,
3287) -> Result<()> {
3288    let db = session.db().await;
3289    let channel_id = ChannelId::from_proto(request.channel_id);
3290
3291    let left_buffer = db
3292        .leave_channel_buffer(channel_id, session.connection_id)
3293        .await?;
3294
3295    response.send(Ack {})?;
3296
3297    channel_buffer_updated(
3298        session.connection_id,
3299        left_buffer.connections,
3300        &proto::UpdateChannelBufferCollaborators {
3301            channel_id: channel_id.to_proto(),
3302            collaborators: left_buffer.collaborators,
3303        },
3304        &session.peer,
3305    );
3306
3307    Ok(())
3308}
3309
3310fn channel_buffer_updated<T: EnvelopedMessage>(
3311    sender_id: ConnectionId,
3312    collaborators: impl IntoIterator<Item = ConnectionId>,
3313    message: &T,
3314    peer: &Peer,
3315) {
3316    broadcast(Some(sender_id), collaborators, |peer_id| {
3317        peer.send(peer_id, message.clone())
3318    });
3319}
3320
3321fn send_notifications(
3322    connection_pool: &ConnectionPool,
3323    peer: &Peer,
3324    notifications: db::NotificationBatch,
3325) {
3326    for (user_id, notification) in notifications {
3327        for connection_id in connection_pool.user_connection_ids(user_id) {
3328            if let Err(error) = peer.send(
3329                connection_id,
3330                proto::AddNotification {
3331                    notification: Some(notification.clone()),
3332                },
3333            ) {
3334                tracing::error!(
3335                    "failed to send notification to {:?} {}",
3336                    connection_id,
3337                    error
3338                );
3339            }
3340        }
3341    }
3342}
3343
/// Send a message to the channel
///
/// Validation:
/// - the body is trimmed, and must be non-empty and at most
///   `MAX_MESSAGE_LEN` bytes long;
/// - a nonce must be present.
///
/// The message is then stored with a server-side UTC timestamp and delivered:
/// 1. as `ChannelMessageSent` to the participant connections returned by the
///    database (`session.connection_id` is passed as the broadcast's sender);
/// 2. back to the requester in `SendChannelMessageResponse`;
/// 3. as the channel's latest message id to every other subscribed connection
///    that is not a participant;
/// 4. finally, any notifications returned by the database are sent.
///
/// # Errors
/// Returns "message is too long", "message can't be blank", or
/// "nonce can't be blank" for invalid input. Also fails if the database write
/// or the response send fails.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged while the
    // body is trimmed, so any offsets the mentions carry may be shifted by the
    // removed leading whitespace. Confirm against the `ChatMention` schema.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Deliver the full message to the participant connections.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Subscribers who are not participants only learn the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3438
3439/// Delete a channel message
3440async fn remove_channel_message(
3441    request: proto::RemoveChannelMessage,
3442    response: Response<proto::RemoveChannelMessage>,
3443    session: Session,
3444) -> Result<()> {
3445    let channel_id = ChannelId::from_proto(request.channel_id);
3446    let message_id = MessageId::from_proto(request.message_id);
3447    let (connection_ids, existing_notification_ids) = session
3448        .db()
3449        .await
3450        .remove_channel_message(channel_id, message_id, session.user_id())
3451        .await?;
3452
3453    broadcast(
3454        Some(session.connection_id),
3455        connection_ids,
3456        move |connection| {
3457            session.peer.send(connection, request.clone())?;
3458
3459            for notification_id in &existing_notification_ids {
3460                session.peer.send(
3461                    connection,
3462                    proto::DeleteNotification {
3463                        notification_id: (*notification_id).to_proto(),
3464                    },
3465                )?;
3466            }
3467
3468            Ok(())
3469        },
3470    );
3471    response.send(proto::Ack {})?;
3472    Ok(())
3473}
3474
3475async fn update_channel_message(
3476    request: proto::UpdateChannelMessage,
3477    response: Response<proto::UpdateChannelMessage>,
3478    session: Session,
3479) -> Result<()> {
3480    let channel_id = ChannelId::from_proto(request.channel_id);
3481    let message_id = MessageId::from_proto(request.message_id);
3482    let updated_at = OffsetDateTime::now_utc();
3483    let UpdatedChannelMessage {
3484        message_id,
3485        participant_connection_ids,
3486        notifications,
3487        reply_to_message_id,
3488        timestamp,
3489        deleted_mention_notification_ids,
3490        updated_mention_notifications,
3491    } = session
3492        .db()
3493        .await
3494        .update_channel_message(
3495            channel_id,
3496            message_id,
3497            session.user_id(),
3498            request.body.as_str(),
3499            &request.mentions,
3500            updated_at,
3501        )
3502        .await?;
3503
3504    let nonce = request
3505        .nonce
3506        .clone()
3507        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3508
3509    let message = proto::ChannelMessage {
3510        sender_id: session.user_id().to_proto(),
3511        id: message_id.to_proto(),
3512        body: request.body.clone(),
3513        mentions: request.mentions.clone(),
3514        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3515        nonce: Some(nonce),
3516        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3517        edited_at: Some(updated_at.unix_timestamp() as u64),
3518    };
3519
3520    response.send(proto::Ack {})?;
3521
3522    let pool = &*session.connection_pool().await;
3523    broadcast(
3524        Some(session.connection_id),
3525        participant_connection_ids,
3526        |connection| {
3527            session.peer.send(
3528                connection,
3529                proto::ChannelMessageUpdate {
3530                    channel_id: channel_id.to_proto(),
3531                    message: Some(message.clone()),
3532                },
3533            )?;
3534
3535            for notification_id in &deleted_mention_notification_ids {
3536                session.peer.send(
3537                    connection,
3538                    proto::DeleteNotification {
3539                        notification_id: (*notification_id).to_proto(),
3540                    },
3541                )?;
3542            }
3543
3544            for notification in &updated_mention_notifications {
3545                session.peer.send(
3546                    connection,
3547                    proto::UpdateNotification {
3548                        notification: Some(notification.clone()),
3549                    },
3550                )?;
3551            }
3552
3553            Ok(())
3554        },
3555    );
3556
3557    send_notifications(pool, &session.peer, notifications);
3558
3559    Ok(())
3560}
3561
3562/// Mark a channel message as read
3563async fn acknowledge_channel_message(
3564    request: proto::AckChannelMessage,
3565    session: Session,
3566) -> Result<()> {
3567    let channel_id = ChannelId::from_proto(request.channel_id);
3568    let message_id = MessageId::from_proto(request.message_id);
3569    let notifications = session
3570        .db()
3571        .await
3572        .observe_channel_message(channel_id, session.user_id(), message_id)
3573        .await?;
3574    send_notifications(
3575        &*session.connection_pool().await,
3576        &session.peer,
3577        notifications,
3578    );
3579    Ok(())
3580}
3581
3582/// Mark a buffer version as synced
3583async fn acknowledge_buffer_version(
3584    request: proto::AckBufferOperation,
3585    session: Session,
3586) -> Result<()> {
3587    let buffer_id = BufferId::from_proto(request.buffer_id);
3588    session
3589        .db()
3590        .await
3591        .observe_buffer_version(
3592            buffer_id,
3593            session.user_id(),
3594            request.epoch as i32,
3595            &request.version,
3596        )
3597        .await?;
3598    Ok(())
3599}
3600
3601async fn count_language_model_tokens(
3602    request: proto::CountLanguageModelTokens,
3603    response: Response<proto::CountLanguageModelTokens>,
3604    session: Session,
3605    config: &Config,
3606) -> Result<()> {
3607    authorize_access_to_legacy_llm_endpoints(&session).await?;
3608
3609    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3610        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3611        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3612    };
3613
3614    session
3615        .app_state
3616        .rate_limiter
3617        .check(&*rate_limit, session.user_id())
3618        .await?;
3619
3620    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3621        Some(proto::LanguageModelProvider::Google) => {
3622            let api_key = config
3623                .google_ai_api_key
3624                .as_ref()
3625                .context("no Google AI API key configured on the server")?;
3626            google_ai::count_tokens(
3627                session.http_client.as_ref(),
3628                google_ai::API_URL,
3629                api_key,
3630                serde_json::from_str(&request.request)?,
3631            )
3632            .await?
3633        }
3634        _ => return Err(anyhow!("unsupported provider"))?,
3635    };
3636
3637    response.send(proto::CountLanguageModelTokensResponse {
3638        token_count: result.total_tokens as u32,
3639    })?;
3640
3641    Ok(())
3642}
3643
3644struct ZedProCountLanguageModelTokensRateLimit;
3645
3646impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3647    fn capacity(&self) -> usize {
3648        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3649            .ok()
3650            .and_then(|v| v.parse().ok())
3651            .unwrap_or(600) // Picked arbitrarily
3652    }
3653
3654    fn refill_duration(&self) -> chrono::Duration {
3655        chrono::Duration::hours(1)
3656    }
3657
3658    fn db_name(&self) -> &'static str {
3659        "zed-pro:count-language-model-tokens"
3660    }
3661}
3662
3663struct FreeCountLanguageModelTokensRateLimit;
3664
3665impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3666    fn capacity(&self) -> usize {
3667        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3668            .ok()
3669            .and_then(|v| v.parse().ok())
3670            .unwrap_or(600 / 10) // Picked arbitrarily
3671    }
3672
3673    fn refill_duration(&self) -> chrono::Duration {
3674        chrono::Duration::hours(1)
3675    }
3676
3677    fn db_name(&self) -> &'static str {
3678        "free:count-language-model-tokens"
3679    }
3680}
3681
3682struct ZedProComputeEmbeddingsRateLimit;
3683
3684impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3685    fn capacity(&self) -> usize {
3686        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3687            .ok()
3688            .and_then(|v| v.parse().ok())
3689            .unwrap_or(5000) // Picked arbitrarily
3690    }
3691
3692    fn refill_duration(&self) -> chrono::Duration {
3693        chrono::Duration::hours(1)
3694    }
3695
3696    fn db_name(&self) -> &'static str {
3697        "zed-pro:compute-embeddings"
3698    }
3699}
3700
3701struct FreeComputeEmbeddingsRateLimit;
3702
3703impl RateLimit for FreeComputeEmbeddingsRateLimit {
3704    fn capacity(&self) -> usize {
3705        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3706            .ok()
3707            .and_then(|v| v.parse().ok())
3708            .unwrap_or(5000 / 10) // Picked arbitrarily
3709    }
3710
3711    fn refill_duration(&self) -> chrono::Duration {
3712        chrono::Duration::hours(1)
3713    }
3714
3715    fn db_name(&self) -> &'static str {
3716        "free:compute-embeddings"
3717    }
3718}
3719
3720async fn compute_embeddings(
3721    request: proto::ComputeEmbeddings,
3722    response: Response<proto::ComputeEmbeddings>,
3723    session: Session,
3724    api_key: Option<Arc<str>>,
3725) -> Result<()> {
3726    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3727    authorize_access_to_legacy_llm_endpoints(&session).await?;
3728
3729    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3730        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3731        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3732    };
3733
3734    session
3735        .app_state
3736        .rate_limiter
3737        .check(&*rate_limit, session.user_id())
3738        .await?;
3739
3740    let embeddings = match request.model.as_str() {
3741        "openai/text-embedding-3-small" => {
3742            open_ai::embed(
3743                session.http_client.as_ref(),
3744                OPEN_AI_API_URL,
3745                &api_key,
3746                OpenAiEmbeddingModel::TextEmbedding3Small,
3747                request.texts.iter().map(|text| text.as_str()),
3748            )
3749            .await?
3750        }
3751        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3752    };
3753
3754    let embeddings = request
3755        .texts
3756        .iter()
3757        .map(|text| {
3758            let mut hasher = sha2::Sha256::new();
3759            hasher.update(text.as_bytes());
3760            let result = hasher.finalize();
3761            result.to_vec()
3762        })
3763        .zip(
3764            embeddings
3765                .data
3766                .into_iter()
3767                .map(|embedding| embedding.embedding),
3768        )
3769        .collect::<HashMap<_, _>>();
3770
3771    let db = session.db().await;
3772    db.save_embeddings(&request.model, &embeddings)
3773        .await
3774        .context("failed to save embeddings")
3775        .trace_err();
3776
3777    response.send(proto::ComputeEmbeddingsResponse {
3778        embeddings: embeddings
3779            .into_iter()
3780            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3781            .collect(),
3782    })?;
3783    Ok(())
3784}
3785
3786async fn get_cached_embeddings(
3787    request: proto::GetCachedEmbeddings,
3788    response: Response<proto::GetCachedEmbeddings>,
3789    session: Session,
3790) -> Result<()> {
3791    authorize_access_to_legacy_llm_endpoints(&session).await?;
3792
3793    let db = session.db().await;
3794    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3795
3796    response.send(proto::GetCachedEmbeddingsResponse {
3797        embeddings: embeddings
3798            .into_iter()
3799            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3800            .collect(),
3801    })?;
3802    Ok(())
3803}
3804
3805/// This is leftover from before the LLM service.
3806///
3807/// The endpoints protected by this check will be moved there eventually.
3808async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3809    if session.is_staff() {
3810        Ok(())
3811    } else {
3812        Err(anyhow!("permission denied"))?
3813    }
3814}
3815
3816/// Get a Supermaven API key for the user
3817async fn get_supermaven_api_key(
3818    _request: proto::GetSupermavenApiKey,
3819    response: Response<proto::GetSupermavenApiKey>,
3820    session: Session,
3821) -> Result<()> {
3822    let user_id: String = session.user_id().to_string();
3823    if !session.is_staff() {
3824        return Err(anyhow!("supermaven not enabled for this account"))?;
3825    }
3826
3827    let email = session
3828        .email()
3829        .ok_or_else(|| anyhow!("user must have an email"))?;
3830
3831    let supermaven_admin_api = session
3832        .supermaven_client
3833        .as_ref()
3834        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3835
3836    let result = supermaven_admin_api
3837        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3838        .await?;
3839
3840    response.send(proto::GetSupermavenApiKeyResponse {
3841        api_key: result.api_key,
3842    })?;
3843
3844    Ok(())
3845}
3846
3847/// Start receiving chat updates for a channel
3848async fn join_channel_chat(
3849    request: proto::JoinChannelChat,
3850    response: Response<proto::JoinChannelChat>,
3851    session: Session,
3852) -> Result<()> {
3853    let channel_id = ChannelId::from_proto(request.channel_id);
3854
3855    let db = session.db().await;
3856    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3857        .await?;
3858    let messages = db
3859        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3860        .await?;
3861    response.send(proto::JoinChannelChatResponse {
3862        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3863        messages,
3864    })?;
3865    Ok(())
3866}
3867
3868/// Stop receiving chat updates for a channel
3869async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3870    let channel_id = ChannelId::from_proto(request.channel_id);
3871    session
3872        .db()
3873        .await
3874        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3875        .await?;
3876    Ok(())
3877}
3878
3879/// Retrieve the chat history for a channel
3880async fn get_channel_messages(
3881    request: proto::GetChannelMessages,
3882    response: Response<proto::GetChannelMessages>,
3883    session: Session,
3884) -> Result<()> {
3885    let channel_id = ChannelId::from_proto(request.channel_id);
3886    let messages = session
3887        .db()
3888        .await
3889        .get_channel_messages(
3890            channel_id,
3891            session.user_id(),
3892            MESSAGE_COUNT_PER_PAGE,
3893            Some(MessageId::from_proto(request.before_message_id)),
3894        )
3895        .await?;
3896    response.send(proto::GetChannelMessagesResponse {
3897        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3898        messages,
3899    })?;
3900    Ok(())
3901}
3902
3903/// Retrieve specific chat messages
3904async fn get_channel_messages_by_id(
3905    request: proto::GetChannelMessagesById,
3906    response: Response<proto::GetChannelMessagesById>,
3907    session: Session,
3908) -> Result<()> {
3909    let message_ids = request
3910        .message_ids
3911        .iter()
3912        .map(|id| MessageId::from_proto(*id))
3913        .collect::<Vec<_>>();
3914    let messages = session
3915        .db()
3916        .await
3917        .get_channel_messages_by_id(session.user_id(), &message_ids)
3918        .await?;
3919    response.send(proto::GetChannelMessagesResponse {
3920        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3921        messages,
3922    })?;
3923    Ok(())
3924}
3925
3926/// Retrieve the current users notifications
3927async fn get_notifications(
3928    request: proto::GetNotifications,
3929    response: Response<proto::GetNotifications>,
3930    session: Session,
3931) -> Result<()> {
3932    let notifications = session
3933        .db()
3934        .await
3935        .get_notifications(
3936            session.user_id(),
3937            NOTIFICATION_COUNT_PER_PAGE,
3938            request.before_id.map(db::NotificationId::from_proto),
3939        )
3940        .await?;
3941    response.send(proto::GetNotificationsResponse {
3942        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3943        notifications,
3944    })?;
3945    Ok(())
3946}
3947
3948/// Mark notifications as read
3949async fn mark_notification_as_read(
3950    request: proto::MarkNotificationRead,
3951    response: Response<proto::MarkNotificationRead>,
3952    session: Session,
3953) -> Result<()> {
3954    let database = &session.db().await;
3955    let notifications = database
3956        .mark_notification_as_read_by_id(
3957            session.user_id(),
3958            NotificationId::from_proto(request.notification_id),
3959        )
3960        .await?;
3961    send_notifications(
3962        &*session.connection_pool().await,
3963        &session.peer,
3964        notifications,
3965    );
3966    response.send(proto::Ack {})?;
3967    Ok(())
3968}
3969
3970/// Get the current users information
3971async fn get_private_user_info(
3972    _request: proto::GetPrivateUserInfo,
3973    response: Response<proto::GetPrivateUserInfo>,
3974    session: Session,
3975) -> Result<()> {
3976    let db = session.db().await;
3977
3978    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3979    let user = db
3980        .get_user_by_id(session.user_id())
3981        .await?
3982        .ok_or_else(|| anyhow!("user not found"))?;
3983    let flags = db.get_user_flags(session.user_id()).await?;
3984
3985    response.send(proto::GetPrivateUserInfoResponse {
3986        metrics_id,
3987        staff: user.admin,
3988        flags,
3989        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3990    })?;
3991    Ok(())
3992}
3993
3994/// Accept the terms of service (tos) on behalf of the current user
3995async fn accept_terms_of_service(
3996    _request: proto::AcceptTermsOfService,
3997    response: Response<proto::AcceptTermsOfService>,
3998    session: Session,
3999) -> Result<()> {
4000    let db = session.db().await;
4001
4002    let accepted_tos_at = Utc::now();
4003    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4004        .await?;
4005
4006    response.send(proto::AcceptTermsOfServiceResponse {
4007        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4008    })?;
4009    Ok(())
4010}
4011
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures age from the earlier of `user.created_at` and
/// `user.github_user_created_at`. Users with an LLM subscription, or with the
/// `bypass-account-age-check` feature flag, skip this check.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4014
4015async fn get_llm_api_token(
4016    _request: proto::GetLlmToken,
4017    response: Response<proto::GetLlmToken>,
4018    session: Session,
4019) -> Result<()> {
4020    let db = session.db().await;
4021
4022    let flags = db.get_user_flags(session.user_id()).await?;
4023    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4024    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4025
4026    if !session.is_staff() && !has_language_models_feature_flag {
4027        Err(anyhow!("permission denied"))?
4028    }
4029
4030    let user_id = session.user_id();
4031    let user = db
4032        .get_user_by_id(user_id)
4033        .await?
4034        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4035
4036    if user.accepted_tos_at.is_none() {
4037        Err(anyhow!("terms of service not accepted"))?
4038    }
4039
4040    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4041
4042    let bypass_account_age_check =
4043        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4044    if !bypass_account_age_check {
4045        let mut account_created_at = user.created_at;
4046        if let Some(github_created_at) = user.github_user_created_at {
4047            account_created_at = account_created_at.min(github_created_at);
4048        }
4049        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4050            Err(anyhow!("account too young"))?
4051        }
4052    }
4053
4054    let billing_preferences = db.get_billing_preferences(user.id).await?;
4055
4056    let token = LlmTokenClaims::create(
4057        &user,
4058        session.is_staff(),
4059        billing_preferences,
4060        has_llm_closed_beta_feature_flag,
4061        has_llm_subscription,
4062        session.current_plan(&db).await?,
4063        session.system_id.clone(),
4064        &session.app_state.config,
4065    )?;
4066    response.send(proto::GetLlmTokenResponse { token })?;
4067    Ok(())
4068}
4069
4070fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4071    let message = match message {
4072        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4073        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4074        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4075        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4076        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4077            code: frame.code.into(),
4078            reason: frame.reason,
4079        })),
4080        // We should never receive a frame while reading the message, according
4081        // to the `tungstenite` maintainers:
4082        //
4083        // > It cannot occur when you read messages from the WebSocket, but it
4084        // > can be used when you want to send the raw frames (e.g. you want to
4085        // > send the frames to the WebSocket without composing the full message first).
4086        // >
4087        // > — https://github.com/snapview/tungstenite-rs/issues/268
4088        TungsteniteMessage::Frame(_) => {
4089            bail!("received an unexpected frame while reading the message")
4090        }
4091    };
4092
4093    Ok(message)
4094}
4095
4096fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4097    match message {
4098        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4099        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4100        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4101        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4102        AxumMessage::Close(frame) => {
4103            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4104                code: frame.code.into(),
4105                reason: frame.reason,
4106            }))
4107        }
4108    }
4109}
4110
4111fn notify_membership_updated(
4112    connection_pool: &mut ConnectionPool,
4113    result: MembershipUpdated,
4114    user_id: UserId,
4115    peer: &Peer,
4116) {
4117    for membership in &result.new_channels.channel_memberships {
4118        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4119    }
4120    for channel_id in &result.removed_channels {
4121        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4122    }
4123
4124    let user_channels_update = proto::UpdateUserChannels {
4125        channel_memberships: result
4126            .new_channels
4127            .channel_memberships
4128            .iter()
4129            .map(|cm| proto::ChannelMembership {
4130                channel_id: cm.channel_id.to_proto(),
4131                role: cm.role.into(),
4132            })
4133            .collect(),
4134        ..Default::default()
4135    };
4136
4137    let mut update = build_channels_update(result.new_channels);
4138    update.delete_channels = result
4139        .removed_channels
4140        .into_iter()
4141        .map(|id| id.to_proto())
4142        .collect();
4143    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4144
4145    for connection_id in connection_pool.user_connection_ids(user_id) {
4146        peer.send(connection_id, user_channels_update.clone())
4147            .trace_err();
4148        peer.send(connection_id, update.clone()).trace_err();
4149    }
4150}
4151
4152fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4153    proto::UpdateUserChannels {
4154        channel_memberships: channels
4155            .channel_memberships
4156            .iter()
4157            .map(|m| proto::ChannelMembership {
4158                channel_id: m.channel_id.to_proto(),
4159                role: m.role.into(),
4160            })
4161            .collect(),
4162        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4163        observed_channel_message_id: channels.observed_channel_messages.clone(),
4164    }
4165}
4166
4167fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4168    let mut update = proto::UpdateChannels::default();
4169
4170    for channel in channels.channels {
4171        update.channels.push(channel.to_proto());
4172    }
4173
4174    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4175    update.latest_channel_message_ids = channels.latest_channel_messages;
4176
4177    for (channel_id, participants) in channels.channel_participants {
4178        update
4179            .channel_participants
4180            .push(proto::ChannelParticipants {
4181                channel_id: channel_id.to_proto(),
4182                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4183            });
4184    }
4185
4186    for channel in channels.invited_channels {
4187        update.channel_invitations.push(channel.to_proto());
4188    }
4189
4190    update
4191}
4192
4193fn build_initial_contacts_update(
4194    contacts: Vec<db::Contact>,
4195    pool: &ConnectionPool,
4196) -> proto::UpdateContacts {
4197    let mut update = proto::UpdateContacts::default();
4198
4199    for contact in contacts {
4200        match contact {
4201            db::Contact::Accepted { user_id, busy } => {
4202                update.contacts.push(contact_for_user(user_id, busy, pool));
4203            }
4204            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4205            db::Contact::Incoming { user_id } => {
4206                update
4207                    .incoming_requests
4208                    .push(proto::IncomingContactRequest {
4209                        requester_id: user_id.to_proto(),
4210                    })
4211            }
4212        }
4213    }
4214
4215    update
4216}
4217
4218fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4219    proto::Contact {
4220        user_id: user_id.to_proto(),
4221        online: pool.is_user_online(user_id),
4222        busy,
4223    }
4224}
4225
4226fn room_updated(room: &proto::Room, peer: &Peer) {
4227    broadcast(
4228        None,
4229        room.participants
4230            .iter()
4231            .filter_map(|participant| Some(participant.peer_id?.into())),
4232        |peer_id| {
4233            peer.send(
4234                peer_id,
4235                proto::RoomUpdated {
4236                    room: Some(room.clone()),
4237                },
4238            )
4239        },
4240    );
4241}
4242
4243fn channel_updated(
4244    channel: &db::channel::Model,
4245    room: &proto::Room,
4246    peer: &Peer,
4247    pool: &ConnectionPool,
4248) {
4249    let participants = room
4250        .participants
4251        .iter()
4252        .map(|p| p.user_id)
4253        .collect::<Vec<_>>();
4254
4255    broadcast(
4256        None,
4257        pool.channel_connection_ids(channel.root_id())
4258            .filter_map(|(channel_id, role)| {
4259                role.can_see_channel(channel.visibility)
4260                    .then_some(channel_id)
4261            }),
4262        |peer_id| {
4263            peer.send(
4264                peer_id,
4265                proto::UpdateChannels {
4266                    channel_participants: vec![proto::ChannelParticipants {
4267                        channel_id: channel.id.to_proto(),
4268                        participant_user_ids: participants.clone(),
4269                    }],
4270                    ..Default::default()
4271                },
4272            )
4273        },
4274    );
4275}
4276
4277async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4278    let db = session.db().await;
4279
4280    let contacts = db.get_contacts(user_id).await?;
4281    let busy = db.is_user_busy(user_id).await?;
4282
4283    let pool = session.connection_pool().await;
4284    let updated_contact = contact_for_user(user_id, busy, &pool);
4285    for contact in contacts {
4286        if let db::Contact::Accepted {
4287            user_id: contact_user_id,
4288            ..
4289        } = contact
4290        {
4291            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4292                session
4293                    .peer
4294                    .send(
4295                        contact_conn_id,
4296                        proto::UpdateContacts {
4297                            contacts: vec![updated_contact.clone()],
4298                            remove_contacts: Default::default(),
4299                            incoming_requests: Default::default(),
4300                            remove_incoming_requests: Default::default(),
4301                            outgoing_requests: Default::default(),
4302                            remove_outgoing_requests: Default::default(),
4303                        },
4304                    )
4305                    .trace_err();
4306            }
4307        }
4308    }
4309    Ok(())
4310}
4311
/// Removes `connection_id` from whatever room it is in and notifies everyone affected.
///
/// Returns `Ok(())` immediately if the database reports that the connection was
/// not in a room (`leave_room` returned `None`). Otherwise it:
/// - tells collaborators of each project the user left (via `project_left`);
/// - broadcasts the updated room, and the channel if the room belongs to one;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for the leaving user and the canceled callees;
/// - removes the participant from LiveKit, deleting the LiveKit room if the
///   database marked the room as deleted.
///
/// Errors from `leave_room` or `update_user_contacts` are propagated; the first
/// failing contact update aborts the function before the LiveKit cleanup runs.
/// Message-send and LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // A set, so a user appearing more than once (e.g. leaver and callee) is updated once.
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized bindings: assigned inside the `if let` below so the
    // extracted values outlive `left_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The temporary DB handle from `session.db().await` stays alive for the body of this `if let`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is moved out of `left_room.room` *before* the room
        // itself is taken, so the `room` broadcast below (and passed to `channel_updated`)
        // carries a defaulted `livekit_room`. Confirm recipients don't rely on that field.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Only rooms attached to a channel have a channel to refresh.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool again (presumably to avoid holding it across
    // the awaits below).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit is optional: skipped entirely when no client is configured.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4385
4386async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4387    let left_channel_buffers = session
4388        .db()
4389        .await
4390        .leave_channel_buffers(session.connection_id)
4391        .await?;
4392
4393    for left_buffer in left_channel_buffers {
4394        channel_buffer_updated(
4395            session.connection_id,
4396            left_buffer.connections,
4397            &proto::UpdateChannelBufferCollaborators {
4398                channel_id: left_buffer.channel_id.to_proto(),
4399                collaborators: left_buffer.collaborators,
4400            },
4401            &session.peer,
4402        );
4403    }
4404
4405    Ok(())
4406}
4407
4408fn project_left(project: &db::LeftProject, session: &Session) {
4409    for connection_id in &project.connection_ids {
4410        if project.should_unshare {
4411            session
4412                .peer
4413                .send(
4414                    *connection_id,
4415                    proto::UnshareProject {
4416                        project_id: project.id.to_proto(),
4417                    },
4418                )
4419                .trace_err();
4420        } else {
4421            session
4422                .peer
4423                .send(
4424                    *connection_id,
4425                    proto::RemoveProjectCollaborator {
4426                        project_id: project.id.to_proto(),
4427                        peer_id: Some(session.connection_id.into()),
4428                    },
4429                )
4430                .trace_err();
4431        }
4432    }
4433}
4434
4435pub trait ResultExt {
4436    type Ok;
4437
4438    fn trace_err(self) -> Option<Self::Ok>;
4439}
4440
4441impl<T, E> ResultExt for Result<T, E>
4442where
4443    E: std::fmt::Debug,
4444{
4445    type Ok = T;
4446
4447    #[track_caller]
4448    fn trace_err(self) -> Option<T> {
4449        match self {
4450            Ok(value) => Some(value),
4451            Err(error) => {
4452                tracing::error!("{:?}", error);
4453                None
4454            }
4455        }
4456    }
4457}