rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
/// Reply handle passed to handlers registered with `Server::add_request_handler`.
///
/// `responded` is shared with the registering code, which reads it after the handler returns
/// to detect handlers that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` as the reply to the request identified by `receipt`, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns an error if `Peer::respond` fails.
    fn send(self, payload: R::Response) -> Result<()> {
        // Set before sending, so the flag reads `true` even if `respond` fails below.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state handed to every message handler.
///
/// One is built for each accepted connection in `Server::handle_connection`.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind a per-session async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the owning `Server`; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): the `allow` suggests this field is not read anywhere; confirm it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Never read (leading underscore); presumably kept so the session owns an executor handle — confirm.
    _executor: Executor,
}
 142
 143impl Session {
 144    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 145        #[cfg(test)]
 146        tokio::task::yield_now().await;
 147        let guard = self.db.lock().await;
 148        #[cfg(test)]
 149        tokio::task::yield_now().await;
 150        guard
 151    }
 152
 153    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.connection_pool.lock();
 157        ConnectionPoolGuard {
 158            guard,
 159            _not_send: PhantomData,
 160        }
 161    }
 162
 163    fn is_staff(&self) -> bool {
 164        match &self.principal {
 165            Principal::User(user) => user.admin,
 166            Principal::Impersonated { .. } => true,
 167        }
 168    }
 169
 170    pub async fn has_llm_subscription(
 171        &self,
 172        db: &MutexGuard<'_, DbHandle>,
 173    ) -> anyhow::Result<bool> {
 174        if self.is_staff() {
 175            return Ok(true);
 176        }
 177
 178        let user_id = self.user_id();
 179
 180        Ok(db.has_active_billing_subscription(user_id).await?)
 181    }
 182
 183    pub async fn current_plan(
 184        &self,
 185        _db: &MutexGuard<'_, DbHandle>,
 186    ) -> anyhow::Result<proto::Plan> {
 187        if self.is_staff() {
 188            Ok(proto::Plan::ZedPro)
 189        } else {
 190            Ok(proto::Plan::Free)
 191        }
 192    }
 193
 194    fn user_id(&self) -> UserId {
 195        match &self.principal {
 196            Principal::User(user) => user.id,
 197            Principal::Impersonated { user, .. } => user.id,
 198        }
 199    }
 200
 201    pub fn email(&self) -> Option<String> {
 202        match &self.principal {
 203            Principal::User(user) => user.email_address.clone(),
 204            Principal::Impersonated { user, .. } => user.email_address.clone(),
 205        }
 206    }
 207}
 208
 209impl Debug for Session {
 210    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 211        let mut result = f.debug_struct("Session");
 212        match &self.principal {
 213            Principal::User(user) => {
 214                result.field("user", &user.github_login);
 215            }
 216            Principal::Impersonated { user, admin } => {
 217                result.field("user", &user.github_login);
 218                result.field("impersonator", &admin.github_login);
 219            }
 220        }
 221        result.field("connection_id", &self.connection_id).finish()
 222    }
 223}
 224
 225struct DbHandle(Arc<Database>);
 226
 227impl Deref for DbHandle {
 228    type Target = Database;
 229
 230    fn deref(&self) -> &Self::Target {
 231        self.0.as_ref()
 232    }
 233}
 234
/// The RPC server.
///
/// It owns the `Peer`, the shared connection pool, and the table of message handlers registered
/// in `Server::new`.
pub struct Server {
    /// Behind a mutex so it can be changed after construction (e.g. by the test-only `reset`).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`.
    ///
    /// `handle_connection` subscribes to it and refuses to serve a connection once it is `true`.
    teardown: watch::Sender<bool>,
}
 243
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. As a result, it cannot be held
/// across an `.await` inside a future that is required to be `Send` (as handler futures are).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 248
/// Serializable view of the server's peer and connection pool.
///
/// The connection-pool lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized via `serialize_deref`, i.e. the guard's deref target (presumably the pool) is
    // written rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 255
 256pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 257where
 258    S: Serializer,
 259    T: Deref<Target = U>,
 260    U: Serialize,
 261{
 262    Serialize::serialize(value.deref(), serializer)
 263}
 264
 265impl Server {
    /// Creates a server with the given `id` and registers every RPC message handler.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently reinterprets/truncates if the id is signed or wider
            // than 32 bits — confirm server ids always fit.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped immediately and
            // `handle_connection` creates receivers later via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests proxied through the `forward_*` helpers defined elsewhere in
            // this file.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            // Account, LLM access and acknowledgements.
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Handlers that need configuration capture their own clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
 413
    /// Schedules cleanup of resources left behind by stale servers, then returns immediately.
    ///
    /// A detached task first waits `CLEANUP_TIMEOUT`. It then asks the database for stale room
    /// and channel-buffer ids, and for each one:
    /// - clears the stale collaborators/participants relative to this server;
    /// - notifies the affected connections and contacts;
    /// - deletes LiveKit rooms that ended up empty.
    ///
    /// Finally it deletes the stale server records. Fallible steps go through `trace_err`
    /// (presumably logging the error) and are skipped on failure.
    ///
    /// Always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: send the refreshed collaborator list to every remaining
                    // connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when the room could not be refreshed: nothing to notify,
                        // and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only tear down the LiveKit room once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their accepted
                        // contacts' connections. Database queries run before the pool is locked.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 559
 560    pub fn teardown(&self) {
 561        self.peer.teardown();
 562        self.connection_pool.lock().reset();
 563        let _ = self.teardown.send(true);
 564    }
 565
 566    #[cfg(test)]
 567    pub fn reset(&self, id: ServerId) {
 568        self.teardown();
 569        *self.id.lock() = id;
 570        self.peer.reset(id.0 as u32);
 571        let _ = self.teardown.send(false);
 572    }
 573
 574    #[cfg(test)]
 575    pub fn id(&self) -> ServerId {
 576        *self.id.lock()
 577    }
 578
 579    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 580    where
 581        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 582        Fut: 'static + Send + Future<Output = Result<()>>,
 583        M: EnvelopedMessage,
 584    {
 585        let prev_handler = self.handlers.insert(
 586            TypeId::of::<M>(),
 587            Box::new(move |envelope, session| {
 588                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 589                let received_at = envelope.received_at;
 590                tracing::info!("message received");
 591                let start_time = Instant::now();
 592                let future = (handler)(*envelope, session);
 593                async move {
 594                    let result = future.await;
 595                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 596                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 597                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 598                    let payload_type = M::NAME;
 599
 600                    match result {
 601                        Err(error) => {
 602                            tracing::error!(
 603                                ?error,
 604                                total_duration_ms,
 605                                processing_duration_ms,
 606                                queue_duration_ms,
 607                                payload_type,
 608                                "error handling message"
 609                            )
 610                        }
 611                        Ok(()) => tracing::info!(
 612                            total_duration_ms,
 613                            processing_duration_ms,
 614                            queue_duration_ms,
 615                            "finished handling message"
 616                        ),
 617                    }
 618                }
 619                .boxed()
 620            }),
 621        );
 622        if prev_handler.is_some() {
 623            panic!("registered a handler for the same message twice");
 624        }
 625        self
 626    }
 627
 628    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 629    where
 630        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 631        Fut: 'static + Send + Future<Output = Result<()>>,
 632        M: EnvelopedMessage,
 633    {
 634        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 635        self
 636    }
 637
 638    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 639    where
 640        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 641        Fut: Send + Future<Output = Result<()>>,
 642        M: RequestMessage,
 643    {
 644        let handler = Arc::new(handler);
 645        self.add_handler(move |envelope, session| {
 646            let receipt = envelope.receipt();
 647            let handler = handler.clone();
 648            async move {
 649                let peer = session.peer.clone();
 650                let responded = Arc::new(AtomicBool::default());
 651                let response = Response {
 652                    peer: peer.clone(),
 653                    responded: responded.clone(),
 654                    receipt,
 655                };
 656                match (handler)(envelope.payload, response, session).await {
 657                    Ok(()) => {
 658                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 659                            Ok(())
 660                        } else {
 661                            Err(anyhow!("handler did not send a response"))?
 662                        }
 663                    }
 664                    Err(error) => {
 665                        let proto_err = match &error {
 666                            Error::Internal(err) => err.to_proto(),
 667                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 668                        };
 669                        peer.respond_with_error(receipt, proto_err)?;
 670                        Err(error)
 671                    }
 672                }
 673            }
 674        })
 675    }
 676
    /// Drives a single client connection from registration to sign-out.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the new connection id on the span;
    /// 3. builds the per-connection [`Session`] (HTTP client, optional Supermaven client, DB
    ///    handle, connection pool, ...);
    /// 4. sends the initial client state via `send_initial_client_update`, which also forwards
    ///    the new [`ConnectionId`] through `send_connection_id` when a sender is supplied;
    /// 5. dispatches incoming messages to the registered handlers until the I/O future
    ///    finishes, the incoming stream ends, or `teardown` changes;
    /// 6. after the I/O-finished / stream-ended exits, runs `connection_lost`.
    ///
    /// A `teardown` change during step 5 returns without calling `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown signalled in between is not missed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    // NOTE(review): this return (and the one after the initial client update
                    // below) happens after `add_connection` but skips `peer.disconnect` and
                    // `connection_lost`; confirm the peer connection is cleaned up elsewhere.
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): `send_initial_client_update` may already have added this connection
            // to the connection pool before failing; that entry is not removed on this path.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection at 256: a
            // permit is acquired *before* reading the next message and released only when that
            // message's handler finishes, so a saturated connection stops reading `incoming_rx`.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins over I/O completion, which wins over driving handlers,
                // which wins over reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while constructing the handler future; the
                            // future itself is instrumented with the same span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are driven
                                // by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any still-pending foreground handlers before running the sign-out cleanup.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 824
 825    async fn send_initial_client_update(
 826        &self,
 827        connection_id: ConnectionId,
 828        principal: &Principal,
 829        zed_version: ZedVersion,
 830        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 831        session: &Session,
 832    ) -> Result<()> {
 833        self.peer.send(
 834            connection_id,
 835            proto::Hello {
 836                peer_id: Some(connection_id.into()),
 837            },
 838        )?;
 839        tracing::info!("sent hello message");
 840        if let Some(send_connection_id) = send_connection_id.take() {
 841            let _ = send_connection_id.send(connection_id);
 842        }
 843
 844        match principal {
 845            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 846                if !user.connected_once {
 847                    self.peer.send(connection_id, proto::ShowContacts {})?;
 848                    self.app_state
 849                        .db
 850                        .set_user_connected_once(user.id, true)
 851                        .await?;
 852                }
 853
 854                update_user_plan(user.id, session).await?;
 855
 856                let contacts = self.app_state.db.get_contacts(user.id).await?;
 857
 858                {
 859                    let mut pool = self.connection_pool.lock();
 860                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 861                    self.peer.send(
 862                        connection_id,
 863                        build_initial_contacts_update(contacts, &pool),
 864                    )?;
 865                }
 866
 867                if should_auto_subscribe_to_channels(zed_version) {
 868                    subscribe_user_to_channels(user.id, session).await?;
 869                }
 870
 871                if let Some(incoming_call) =
 872                    self.app_state.db.incoming_call_for_user(user.id).await?
 873                {
 874                    self.peer.send(connection_id, incoming_call)?;
 875                }
 876
 877                update_user_contacts(user.id, session).await?;
 878            }
 879        }
 880
 881        Ok(())
 882    }
 883
 884    pub async fn invite_code_redeemed(
 885        self: &Arc<Self>,
 886        inviter_id: UserId,
 887        invitee_id: UserId,
 888    ) -> Result<()> {
 889        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 890            if let Some(code) = &user.invite_code {
 891                let pool = self.connection_pool.lock();
 892                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 893                for connection_id in pool.user_connection_ids(inviter_id) {
 894                    self.peer.send(
 895                        connection_id,
 896                        proto::UpdateContacts {
 897                            contacts: vec![invitee_contact.clone()],
 898                            ..Default::default()
 899                        },
 900                    )?;
 901                    self.peer.send(
 902                        connection_id,
 903                        proto::UpdateInviteInfo {
 904                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 905                            count: user.invite_count as u32,
 906                        },
 907                    )?;
 908                }
 909            }
 910        }
 911        Ok(())
 912    }
 913
 914    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 915        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 916            if let Some(invite_code) = &user.invite_code {
 917                let pool = self.connection_pool.lock();
 918                for connection_id in pool.user_connection_ids(user_id) {
 919                    self.peer.send(
 920                        connection_id,
 921                        proto::UpdateInviteInfo {
 922                            url: format!(
 923                                "{}{}",
 924                                self.app_state.config.invite_link_prefix, invite_code
 925                            ),
 926                            count: user.invite_count as u32,
 927                        },
 928                    )?;
 929                }
 930            }
 931        }
 932        Ok(())
 933    }
 934
 935    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 936        let pool = self.connection_pool.lock();
 937        for connection_id in pool.user_connection_ids(user_id) {
 938            self.peer
 939                .send(connection_id, proto::RefreshLlmToken {})
 940                .trace_err();
 941        }
 942    }
 943
 944    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 945        ServerSnapshot {
 946            connection_pool: ConnectionPoolGuard {
 947                guard: self.connection_pool.lock(),
 948                _not_send: PhantomData,
 949            },
 950            peer: &self.peer,
 951        }
 952    }
 953}
 954
 955impl<'a> Deref for ConnectionPoolGuard<'a> {
 956    type Target = ConnectionPool;
 957
 958    fn deref(&self) -> &Self::Target {
 959        &self.guard
 960    }
 961}
 962
 963impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 964    fn deref_mut(&mut self) -> &mut Self::Target {
 965        &mut self.guard
 966    }
 967}
 968
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, validates the pool's invariants whenever this guard is dropped.
    ///
    /// `check_invariants` is presumably a `ConnectionPool` method reached through `Deref`; the
    /// call is compiled only under `cfg(test)`, so this is a no-op in other builds.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 975
 976fn broadcast<F>(
 977    sender_id: Option<ConnectionId>,
 978    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 979    mut f: F,
 980) where
 981    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 982{
 983    for receiver_id in receiver_ids {
 984        if Some(receiver_id) != sender_id {
 985            if let Err(error) = f(receiver_id) {
 986                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 987            }
 988        }
 989    }
 990}
 991
 992pub struct ProtocolVersion(u32);
 993
 994impl Header for ProtocolVersion {
 995    fn name() -> &'static HeaderName {
 996        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
 997        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
 998    }
 999
1000    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1001    where
1002        Self: Sized,
1003        I: Iterator<Item = &'i axum::http::HeaderValue>,
1004    {
1005        let version = values
1006            .next()
1007            .ok_or_else(axum::headers::Error::invalid)?
1008            .to_str()
1009            .map_err(|_| axum::headers::Error::invalid())?
1010            .parse()
1011            .map_err(|_| axum::headers::Error::invalid())?;
1012        Ok(Self(version))
1013    }
1014
1015    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1016        values.extend([self.0.to_string().parse().unwrap()]);
1017    }
1018}
1019
1020pub struct AppVersionHeader(SemanticVersion);
1021impl Header for AppVersionHeader {
1022    fn name() -> &'static HeaderName {
1023        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1024        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1025    }
1026
1027    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1028    where
1029        Self: Sized,
1030        I: Iterator<Item = &'i axum::http::HeaderValue>,
1031    {
1032        let version = values
1033            .next()
1034            .ok_or_else(axum::headers::Error::invalid)?
1035            .to_str()
1036            .map_err(|_| axum::headers::Error::invalid())?
1037            .parse()
1038            .map_err(|_| axum::headers::Error::invalid())?;
1039        Ok(Self(version))
1040    }
1041
1042    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1043        values.extend([self.0.to_string().parse().unwrap()]);
1044    }
1045}
1046
/// Builds the RPC router.
///
/// - `/rpc` upgrades to the collaboration WebSocket. It is wrapped by a layer stack that adds
///   the `AppState` extension and runs `auth::validate_header`. The extension layer is listed
///   first, which makes it the outer layer in a `ServiceBuilder`, presumably so that
///   `validate_header` can read `AppState`.
/// - `/metrics` is registered *after* that layer. axum's `Router::layer` only wraps routes that
///   already exist, so `/metrics` does not go through `validate_header`.
///
/// The final `Extension(server)` layer wraps both routes, so both handlers can extract the
/// `Server`.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1058
1059pub async fn handle_websocket_request(
1060    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1061    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1062    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1063    Extension(server): Extension<Arc<Server>>,
1064    Extension(principal): Extension<Principal>,
1065    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1066    ws: WebSocketUpgrade,
1067) -> axum::response::Response {
1068    if protocol_version != rpc::PROTOCOL_VERSION {
1069        return (
1070            StatusCode::UPGRADE_REQUIRED,
1071            "client must be upgraded".to_string(),
1072        )
1073            .into_response();
1074    }
1075
1076    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1077        return (
1078            StatusCode::UPGRADE_REQUIRED,
1079            "no version header found".to_string(),
1080        )
1081            .into_response();
1082    };
1083
1084    if !version.can_collaborate() {
1085        return (
1086            StatusCode::UPGRADE_REQUIRED,
1087            "client must be upgraded".to_string(),
1088        )
1089            .into_response();
1090    }
1091
1092    let socket_address = socket_address.to_string();
1093    ws.on_upgrade(move |socket| {
1094        let socket = socket
1095            .map_ok(to_tungstenite_message)
1096            .err_into()
1097            .with(|message| async move { to_axum_message(message) });
1098        let connection = Connection::new(Box::pin(socket));
1099        async move {
1100            server
1101                .handle_connection(
1102                    connection,
1103                    socket_address,
1104                    principal,
1105                    version,
1106                    country_code_header.map(|header| header.to_string()),
1107                    None,
1108                    Executor::Production,
1109                )
1110                .await;
1111        }
1112    })
1113}
1114
1115pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1116    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1117    let connections_metric = CONNECTIONS_METRIC
1118        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1119
1120    let connections = server
1121        .connection_pool
1122        .lock()
1123        .connections()
1124        .filter(|connection| !connection.admin)
1125        .count();
1126    connections_metric.set(connections as _);
1127
1128    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1129    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1130        register_int_gauge!(
1131            "shared_projects",
1132            "number of open projects with one or more guests"
1133        )
1134        .unwrap()
1135    });
1136
1137    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1138    shared_projects_metric.set(shared_projects as _);
1139
1140    let encoder = prometheus::TextEncoder::new();
1141    let metric_families = prometheus::gather();
1142    let encoded_metrics = encoder
1143        .encode_to_string(&metric_families)
1144        .map_err(|err| anyhow!("{}", err))?;
1145    Ok(encoded_metrics)
1146}
1147
/// Cleans up after a connection ends.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (an error here is returned);
/// - calls the database's `connection_lost` (its error only goes through `trace_err`).
///
/// It then waits for `RECONNECT_TIMEOUT`, presumably to give the client a chance to reconnect
/// before its resources are released. If the timeout fires first:
/// - the session's room and channel-buffer memberships are left;
/// - if the user has no other online connection, `decline_call(None, ..)` is issued and any
///   returned room is broadcast;
/// - the user's contacts are updated.
///
/// If `teardown` changes first, that deferred cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased so the timeout branch is polled before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not come back in time: release everything it still holds.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1194
1195/// Acknowledges a ping from a client, used to keep the connection alive.
1196async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1197    response.send(proto::Ack {})?;
1198    Ok(())
1199}
1200
1201/// Creates a new room for calling (outside of channels)
1202async fn create_room(
1203    _request: proto::CreateRoom,
1204    response: Response<proto::CreateRoom>,
1205    session: Session,
1206) -> Result<()> {
1207    let live_kit_room = nanoid::nanoid!(30);
1208
1209    let live_kit_connection_info = util::maybe!(async {
1210        let live_kit = session.app_state.live_kit_client.as_ref();
1211        let live_kit = live_kit?;
1212        let user_id = session.user_id().to_string();
1213
1214        let token = live_kit
1215            .room_token(&live_kit_room, &user_id.to_string())
1216            .trace_err()?;
1217
1218        Some(proto::LiveKitConnectionInfo {
1219            server_url: live_kit.url().into(),
1220            token,
1221            can_publish: true,
1222        })
1223    })
1224    .await;
1225
1226    let room = session
1227        .db()
1228        .await
1229        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1230        .await?;
1231
1232    response.send(proto::CreateRoomResponse {
1233        room: Some(room.clone()),
1234        live_kit_connection_info,
1235    })?;
1236
1237    update_user_contacts(session.user_id(), &session).await?;
1238    Ok(())
1239}
1240
1241/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1242async fn join_room(
1243    request: proto::JoinRoom,
1244    response: Response<proto::JoinRoom>,
1245    session: Session,
1246) -> Result<()> {
1247    let room_id = RoomId::from_proto(request.id);
1248
1249    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1250
1251    if let Some(channel_id) = channel_id {
1252        return join_channel_internal(channel_id, Box::new(response), session).await;
1253    }
1254
1255    let joined_room = {
1256        let room = session
1257            .db()
1258            .await
1259            .join_room(room_id, session.user_id(), session.connection_id)
1260            .await?;
1261        room_updated(&room.room, &session.peer);
1262        room.into_inner()
1263    };
1264
1265    for connection_id in session
1266        .connection_pool()
1267        .await
1268        .user_connection_ids(session.user_id())
1269    {
1270        session
1271            .peer
1272            .send(
1273                connection_id,
1274                proto::CallCanceled {
1275                    room_id: room_id.to_proto(),
1276                },
1277            )
1278            .trace_err();
1279    }
1280
1281    let live_kit_connection_info =
1282        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1283            live_kit
1284                .room_token(
1285                    &joined_room.room.live_kit_room,
1286                    &session.user_id().to_string(),
1287                )
1288                .trace_err()
1289                .map(|token| proto::LiveKitConnectionInfo {
1290                    server_url: live_kit.url().into(),
1291                    token,
1292                    can_publish: true,
1293                })
1294        } else {
1295            None
1296        };
1297
1298    response.send(proto::JoinRoomResponse {
1299        room: Some(joined_room.room),
1300        channel_id: None,
1301        live_kit_connection_info,
1302    })?;
1303
1304    update_user_contacts(session.user_id(), &session).await?;
1305    Ok(())
1306}
1307
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Flow:
/// 1. Rejoin the room in the database with the session's current connection id.
/// 2. Reply with the room, the reshared projects (presumably projects this user hosts) with
///    their collaborators, and the rejoined projects.
/// 3. Broadcast the room update.
/// 4. For each reshared project, send every collaborator `UpdateProjectCollaborator`
///    (old peer id -> current connection). Then forward `UpdateProject` with the project's
///    worktrees to every collaborator except this connection.
/// 5. Re-stream state for the rejoined projects via `notify_rejoined_projects`.
/// 6. Broadcast a channel update if the room belongs to a channel.
/// 7. Update the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so they outlive the database result, which is only consumed
    // (`into_inner`) at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators about the peer id change (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees on behalf of this connection, skipping itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1399
1400fn notify_rejoined_projects(
1401    rejoined_projects: &mut Vec<RejoinedProject>,
1402    session: &Session,
1403) -> Result<()> {
1404    for project in rejoined_projects.iter() {
1405        for collaborator in &project.collaborators {
1406            session
1407                .peer
1408                .send(
1409                    collaborator.connection_id,
1410                    proto::UpdateProjectCollaborator {
1411                        project_id: project.id.to_proto(),
1412                        old_peer_id: Some(project.old_connection_id.into()),
1413                        new_peer_id: Some(session.connection_id.into()),
1414                    },
1415                )
1416                .trace_err();
1417        }
1418    }
1419
1420    for project in rejoined_projects {
1421        for worktree in mem::take(&mut project.worktrees) {
1422            // Stream this worktree's entries.
1423            let message = proto::UpdateWorktree {
1424                project_id: project.id.to_proto(),
1425                worktree_id: worktree.id,
1426                abs_path: worktree.abs_path.clone(),
1427                root_name: worktree.root_name,
1428                updated_entries: worktree.updated_entries,
1429                removed_entries: worktree.removed_entries,
1430                scan_id: worktree.scan_id,
1431                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1432                updated_repositories: worktree.updated_repositories,
1433                removed_repositories: worktree.removed_repositories,
1434            };
1435            for update in proto::split_worktree_update(message) {
1436                session.peer.send(session.connection_id, update.clone())?;
1437            }
1438
1439            // Stream this worktree's diagnostics.
1440            for summary in worktree.diagnostic_summaries {
1441                session.peer.send(
1442                    session.connection_id,
1443                    proto::UpdateDiagnosticSummary {
1444                        project_id: project.id.to_proto(),
1445                        worktree_id: worktree.id,
1446                        summary: Some(summary),
1447                    },
1448                )?;
1449            }
1450
1451            for settings_file in worktree.settings_files {
1452                session.peer.send(
1453                    session.connection_id,
1454                    proto::UpdateWorktreeSettings {
1455                        project_id: project.id.to_proto(),
1456                        worktree_id: worktree.id,
1457                        path: settings_file.path,
1458                        content: Some(settings_file.content),
1459                        kind: Some(settings_file.kind.to_proto().into()),
1460                    },
1461                )?;
1462            }
1463        }
1464
1465        for language_server in &project.language_servers {
1466            session.peer.send(
1467                session.connection_id,
1468                proto::UpdateLanguageServer {
1469                    project_id: project.id.to_proto(),
1470                    language_server_id: language_server.id,
1471                    variant: Some(
1472                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1473                            proto::LspDiskBasedDiagnosticsUpdated {},
1474                        ),
1475                    ),
1476                },
1477            )?;
1478        }
1479    }
1480    Ok(())
1481}
1482
1483/// leave room disconnects from the room.
1484async fn leave_room(
1485    _: proto::LeaveRoom,
1486    response: Response<proto::LeaveRoom>,
1487    session: Session,
1488) -> Result<()> {
1489    leave_room_for_session(&session, session.connection_id).await?;
1490    response.send(proto::Ack {})?;
1491    Ok(())
1492}
1493
1494/// Updates the permissions of someone else in the room.
1495async fn set_room_participant_role(
1496    request: proto::SetRoomParticipantRole,
1497    response: Response<proto::SetRoomParticipantRole>,
1498    session: Session,
1499) -> Result<()> {
1500    let user_id = UserId::from_proto(request.user_id);
1501    let role = ChannelRole::from(request.role());
1502
1503    let (live_kit_room, can_publish) = {
1504        let room = session
1505            .db()
1506            .await
1507            .set_room_participant_role(
1508                session.user_id(),
1509                RoomId::from_proto(request.room_id),
1510                user_id,
1511                role,
1512            )
1513            .await?;
1514
1515        let live_kit_room = room.live_kit_room.clone();
1516        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1517        room_updated(&room, &session.peer);
1518        (live_kit_room, can_publish)
1519    };
1520
1521    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1522        live_kit
1523            .update_participant(
1524                live_kit_room.clone(),
1525                request.user_id.to_string(),
1526                live_kit_server::proto::ParticipantPermission {
1527                    can_subscribe: true,
1528                    can_publish,
1529                    can_publish_data: can_publish,
1530                    hidden: false,
1531                    recorder: false,
1532                },
1533            )
1534            .await
1535            .trace_err();
1536    }
1537
1538    response.send(proto::Ack {})?;
1539    Ok(())
1540}
1541
1542/// Call someone else into the current room
1543async fn call(
1544    request: proto::Call,
1545    response: Response<proto::Call>,
1546    session: Session,
1547) -> Result<()> {
1548    let room_id = RoomId::from_proto(request.room_id);
1549    let calling_user_id = session.user_id();
1550    let calling_connection_id = session.connection_id;
1551    let called_user_id = UserId::from_proto(request.called_user_id);
1552    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1553    if !session
1554        .db()
1555        .await
1556        .has_contact(calling_user_id, called_user_id)
1557        .await?
1558    {
1559        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1560    }
1561
1562    let incoming_call = {
1563        let (room, incoming_call) = &mut *session
1564            .db()
1565            .await
1566            .call(
1567                room_id,
1568                calling_user_id,
1569                calling_connection_id,
1570                called_user_id,
1571                initial_project_id,
1572            )
1573            .await?;
1574        room_updated(room, &session.peer);
1575        mem::take(incoming_call)
1576    };
1577    update_user_contacts(called_user_id, &session).await?;
1578
1579    let mut calls = session
1580        .connection_pool()
1581        .await
1582        .user_connection_ids(called_user_id)
1583        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1584        .collect::<FuturesUnordered<_>>();
1585
1586    while let Some(call_response) = calls.next().await {
1587        match call_response.as_ref() {
1588            Ok(_) => {
1589                response.send(proto::Ack {})?;
1590                return Ok(());
1591            }
1592            Err(_) => {
1593                call_response.trace_err();
1594            }
1595        }
1596    }
1597
1598    {
1599        let room = session
1600            .db()
1601            .await
1602            .call_failed(room_id, called_user_id)
1603            .await?;
1604        room_updated(&room, &session.peer);
1605    }
1606    update_user_contacts(called_user_id, &session).await?;
1607
1608    Err(anyhow!("failed to ring user"))?
1609}
1610
1611/// Cancel an outgoing call.
1612async fn cancel_call(
1613    request: proto::CancelCall,
1614    response: Response<proto::CancelCall>,
1615    session: Session,
1616) -> Result<()> {
1617    let called_user_id = UserId::from_proto(request.called_user_id);
1618    let room_id = RoomId::from_proto(request.room_id);
1619    {
1620        let room = session
1621            .db()
1622            .await
1623            .cancel_call(room_id, session.connection_id, called_user_id)
1624            .await?;
1625        room_updated(&room, &session.peer);
1626    }
1627
1628    for connection_id in session
1629        .connection_pool()
1630        .await
1631        .user_connection_ids(called_user_id)
1632    {
1633        session
1634            .peer
1635            .send(
1636                connection_id,
1637                proto::CallCanceled {
1638                    room_id: room_id.to_proto(),
1639                },
1640            )
1641            .trace_err();
1642    }
1643    response.send(proto::Ack {})?;
1644
1645    update_user_contacts(called_user_id, &session).await?;
1646    Ok(())
1647}
1648
1649/// Decline an incoming call.
1650async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1651    let room_id = RoomId::from_proto(message.room_id);
1652    {
1653        let room = session
1654            .db()
1655            .await
1656            .decline_call(Some(room_id), session.user_id())
1657            .await?
1658            .ok_or_else(|| anyhow!("failed to decline call"))?;
1659        room_updated(&room, &session.peer);
1660    }
1661
1662    for connection_id in session
1663        .connection_pool()
1664        .await
1665        .user_connection_ids(session.user_id())
1666    {
1667        session
1668            .peer
1669            .send(
1670                connection_id,
1671                proto::CallCanceled {
1672                    room_id: room_id.to_proto(),
1673                },
1674            )
1675            .trace_err();
1676    }
1677    update_user_contacts(session.user_id(), &session).await?;
1678    Ok(())
1679}
1680
1681/// Updates other participants in the room with your current location.
1682async fn update_participant_location(
1683    request: proto::UpdateParticipantLocation,
1684    response: Response<proto::UpdateParticipantLocation>,
1685    session: Session,
1686) -> Result<()> {
1687    let room_id = RoomId::from_proto(request.room_id);
1688    let location = request
1689        .location
1690        .ok_or_else(|| anyhow!("invalid location"))?;
1691
1692    let db = session.db().await;
1693    let room = db
1694        .update_room_participant_location(room_id, session.connection_id, location)
1695        .await?;
1696
1697    room_updated(&room, &session.peer);
1698    response.send(proto::Ack {})?;
1699    Ok(())
1700}
1701
1702/// Share a project into the room.
1703async fn share_project(
1704    request: proto::ShareProject,
1705    response: Response<proto::ShareProject>,
1706    session: Session,
1707) -> Result<()> {
1708    let (project_id, room) = &*session
1709        .db()
1710        .await
1711        .share_project(
1712            RoomId::from_proto(request.room_id),
1713            session.connection_id,
1714            &request.worktrees,
1715            request.is_ssh_project,
1716        )
1717        .await?;
1718    response.send(proto::ShareProjectResponse {
1719        project_id: project_id.to_proto(),
1720    })?;
1721    room_updated(room, &session.peer);
1722
1723    Ok(())
1724}
1725
1726/// Unshare a project from the room.
1727async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1728    let project_id = ProjectId::from_proto(message.project_id);
1729    unshare_project_internal(project_id, session.connection_id, &session).await
1730}
1731
1732async fn unshare_project_internal(
1733    project_id: ProjectId,
1734    connection_id: ConnectionId,
1735    session: &Session,
1736) -> Result<()> {
1737    let delete = {
1738        let room_guard = session
1739            .db()
1740            .await
1741            .unshare_project(project_id, connection_id)
1742            .await?;
1743
1744        let (delete, room, guest_connection_ids) = &*room_guard;
1745
1746        let message = proto::UnshareProject {
1747            project_id: project_id.to_proto(),
1748        };
1749
1750        broadcast(
1751            Some(connection_id),
1752            guest_connection_ids.iter().copied(),
1753            |conn_id| session.peer.send(conn_id, message.clone()),
1754        );
1755        if let Some(room) = room {
1756            room_updated(room, &session.peer);
1757        }
1758
1759        *delete
1760    };
1761
1762    if delete {
1763        let db = session.db().await;
1764        db.delete_project(project_id).await?;
1765    }
1766
1767    Ok(())
1768}
1769
1770/// Join someone elses shared project.
1771async fn join_project(
1772    request: proto::JoinProject,
1773    response: Response<proto::JoinProject>,
1774    session: Session,
1775) -> Result<()> {
1776    let project_id = ProjectId::from_proto(request.project_id);
1777
1778    tracing::info!(%project_id, "join project");
1779
1780    let db = session.db().await;
1781    let (project, replica_id) = &mut *db
1782        .join_project(project_id, session.connection_id, session.user_id())
1783        .await?;
1784    drop(db);
1785    tracing::info!(%project_id, "join remote project");
1786    join_project_internal(response, session, project, replica_id)
1787}
1788
1789trait JoinProjectInternalResponse {
1790    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1791}
1792impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1793    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1794        Response::<proto::JoinProject>::send(self, result)
1795    }
1796}
1797
1798fn join_project_internal(
1799    response: impl JoinProjectInternalResponse,
1800    session: Session,
1801    project: &mut Project,
1802    replica_id: &ReplicaId,
1803) -> Result<()> {
1804    let collaborators = project
1805        .collaborators
1806        .iter()
1807        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1808        .map(|collaborator| collaborator.to_proto())
1809        .collect::<Vec<_>>();
1810    let project_id = project.id;
1811    let guest_user_id = session.user_id();
1812
1813    let worktrees = project
1814        .worktrees
1815        .iter()
1816        .map(|(id, worktree)| proto::WorktreeMetadata {
1817            id: *id,
1818            root_name: worktree.root_name.clone(),
1819            visible: worktree.visible,
1820            abs_path: worktree.abs_path.clone(),
1821        })
1822        .collect::<Vec<_>>();
1823
1824    let add_project_collaborator = proto::AddProjectCollaborator {
1825        project_id: project_id.to_proto(),
1826        collaborator: Some(proto::Collaborator {
1827            peer_id: Some(session.connection_id.into()),
1828            replica_id: replica_id.0 as u32,
1829            user_id: guest_user_id.to_proto(),
1830            is_host: false,
1831        }),
1832    };
1833
1834    for collaborator in &collaborators {
1835        session
1836            .peer
1837            .send(
1838                collaborator.peer_id.unwrap().into(),
1839                add_project_collaborator.clone(),
1840            )
1841            .trace_err();
1842    }
1843
1844    // First, we send the metadata associated with each worktree.
1845    response.send(proto::JoinProjectResponse {
1846        project_id: project.id.0 as u64,
1847        worktrees: worktrees.clone(),
1848        replica_id: replica_id.0 as u32,
1849        collaborators: collaborators.clone(),
1850        language_servers: project.language_servers.clone(),
1851        role: project.role.into(),
1852    })?;
1853
1854    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1855        // Stream this worktree's entries.
1856        let message = proto::UpdateWorktree {
1857            project_id: project_id.to_proto(),
1858            worktree_id,
1859            abs_path: worktree.abs_path.clone(),
1860            root_name: worktree.root_name,
1861            updated_entries: worktree.entries,
1862            removed_entries: Default::default(),
1863            scan_id: worktree.scan_id,
1864            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1865            updated_repositories: worktree.repository_entries.into_values().collect(),
1866            removed_repositories: Default::default(),
1867        };
1868        for update in proto::split_worktree_update(message) {
1869            session.peer.send(session.connection_id, update.clone())?;
1870        }
1871
1872        // Stream this worktree's diagnostics.
1873        for summary in worktree.diagnostic_summaries {
1874            session.peer.send(
1875                session.connection_id,
1876                proto::UpdateDiagnosticSummary {
1877                    project_id: project_id.to_proto(),
1878                    worktree_id: worktree.id,
1879                    summary: Some(summary),
1880                },
1881            )?;
1882        }
1883
1884        for settings_file in worktree.settings_files {
1885            session.peer.send(
1886                session.connection_id,
1887                proto::UpdateWorktreeSettings {
1888                    project_id: project_id.to_proto(),
1889                    worktree_id: worktree.id,
1890                    path: settings_file.path,
1891                    content: Some(settings_file.content),
1892                    kind: Some(settings_file.kind.to_proto() as i32),
1893                },
1894            )?;
1895        }
1896    }
1897
1898    for language_server in &project.language_servers {
1899        session.peer.send(
1900            session.connection_id,
1901            proto::UpdateLanguageServer {
1902                project_id: project_id.to_proto(),
1903                language_server_id: language_server.id,
1904                variant: Some(
1905                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1906                        proto::LspDiskBasedDiagnosticsUpdated {},
1907                    ),
1908                ),
1909            },
1910        )?;
1911    }
1912
1913    Ok(())
1914}
1915
1916/// Leave someone elses shared project.
1917async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1918    let sender_id = session.connection_id;
1919    let project_id = ProjectId::from_proto(request.project_id);
1920    let db = session.db().await;
1921
1922    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1923    tracing::info!(
1924        %project_id,
1925        "leave project"
1926    );
1927
1928    project_left(project, &session);
1929    if let Some(room) = room {
1930        room_updated(room, &session.peer);
1931    }
1932
1933    Ok(())
1934}
1935
1936/// Updates other participants with changes to the project
1937async fn update_project(
1938    request: proto::UpdateProject,
1939    response: Response<proto::UpdateProject>,
1940    session: Session,
1941) -> Result<()> {
1942    let project_id = ProjectId::from_proto(request.project_id);
1943    let (room, guest_connection_ids) = &*session
1944        .db()
1945        .await
1946        .update_project(project_id, session.connection_id, &request.worktrees)
1947        .await?;
1948    broadcast(
1949        Some(session.connection_id),
1950        guest_connection_ids.iter().copied(),
1951        |connection_id| {
1952            session
1953                .peer
1954                .forward_send(session.connection_id, connection_id, request.clone())
1955        },
1956    );
1957    if let Some(room) = room {
1958        room_updated(room, &session.peer);
1959    }
1960    response.send(proto::Ack {})?;
1961
1962    Ok(())
1963}
1964
1965/// Updates other participants with changes to the worktree
1966async fn update_worktree(
1967    request: proto::UpdateWorktree,
1968    response: Response<proto::UpdateWorktree>,
1969    session: Session,
1970) -> Result<()> {
1971    let guest_connection_ids = session
1972        .db()
1973        .await
1974        .update_worktree(&request, session.connection_id)
1975        .await?;
1976
1977    broadcast(
1978        Some(session.connection_id),
1979        guest_connection_ids.iter().copied(),
1980        |connection_id| {
1981            session
1982                .peer
1983                .forward_send(session.connection_id, connection_id, request.clone())
1984        },
1985    );
1986    response.send(proto::Ack {})?;
1987    Ok(())
1988}
1989
1990/// Updates other participants with changes to the diagnostics
1991async fn update_diagnostic_summary(
1992    message: proto::UpdateDiagnosticSummary,
1993    session: Session,
1994) -> Result<()> {
1995    let guest_connection_ids = session
1996        .db()
1997        .await
1998        .update_diagnostic_summary(&message, session.connection_id)
1999        .await?;
2000
2001    broadcast(
2002        Some(session.connection_id),
2003        guest_connection_ids.iter().copied(),
2004        |connection_id| {
2005            session
2006                .peer
2007                .forward_send(session.connection_id, connection_id, message.clone())
2008        },
2009    );
2010
2011    Ok(())
2012}
2013
2014/// Updates other participants with changes to the worktree settings
2015async fn update_worktree_settings(
2016    message: proto::UpdateWorktreeSettings,
2017    session: Session,
2018) -> Result<()> {
2019    let guest_connection_ids = session
2020        .db()
2021        .await
2022        .update_worktree_settings(&message, session.connection_id)
2023        .await?;
2024
2025    broadcast(
2026        Some(session.connection_id),
2027        guest_connection_ids.iter().copied(),
2028        |connection_id| {
2029            session
2030                .peer
2031                .forward_send(session.connection_id, connection_id, message.clone())
2032        },
2033    );
2034
2035    Ok(())
2036}
2037
2038/// Notify other participants that a  language server has started.
2039async fn start_language_server(
2040    request: proto::StartLanguageServer,
2041    session: Session,
2042) -> Result<()> {
2043    let guest_connection_ids = session
2044        .db()
2045        .await
2046        .start_language_server(&request, session.connection_id)
2047        .await?;
2048
2049    broadcast(
2050        Some(session.connection_id),
2051        guest_connection_ids.iter().copied(),
2052        |connection_id| {
2053            session
2054                .peer
2055                .forward_send(session.connection_id, connection_id, request.clone())
2056        },
2057    );
2058    Ok(())
2059}
2060
2061/// Notify other participants that a language server has changed.
2062async fn update_language_server(
2063    request: proto::UpdateLanguageServer,
2064    session: Session,
2065) -> Result<()> {
2066    let project_id = ProjectId::from_proto(request.project_id);
2067    let project_connection_ids = session
2068        .db()
2069        .await
2070        .project_connection_ids(project_id, session.connection_id, true)
2071        .await?;
2072    broadcast(
2073        Some(session.connection_id),
2074        project_connection_ids.iter().copied(),
2075        |connection_id| {
2076            session
2077                .peer
2078                .forward_send(session.connection_id, connection_id, request.clone())
2079        },
2080    );
2081    Ok(())
2082}
2083
2084/// forward a project request to the host. These requests should be read only
2085/// as guests are allowed to send them.
2086async fn forward_read_only_project_request<T>(
2087    request: T,
2088    response: Response<T>,
2089    session: Session,
2090) -> Result<()>
2091where
2092    T: EntityMessage + RequestMessage,
2093{
2094    let project_id = ProjectId::from_proto(request.remote_entity_id());
2095    let host_connection_id = session
2096        .db()
2097        .await
2098        .host_for_read_only_project_request(project_id, session.connection_id)
2099        .await?;
2100    let payload = session
2101        .peer
2102        .forward_request(session.connection_id, host_connection_id, request)
2103        .await?;
2104    response.send(payload)?;
2105    Ok(())
2106}
2107
2108async fn forward_find_search_candidates_request(
2109    request: proto::FindSearchCandidates,
2110    response: Response<proto::FindSearchCandidates>,
2111    session: Session,
2112) -> Result<()> {
2113    let project_id = ProjectId::from_proto(request.remote_entity_id());
2114    let host_connection_id = session
2115        .db()
2116        .await
2117        .host_for_read_only_project_request(project_id, session.connection_id)
2118        .await?;
2119    let payload = session
2120        .peer
2121        .forward_request(session.connection_id, host_connection_id, request)
2122        .await?;
2123    response.send(payload)?;
2124    Ok(())
2125}
2126
2127/// forward a project request to the host. These requests are disallowed
2128/// for guests.
2129async fn forward_mutating_project_request<T>(
2130    request: T,
2131    response: Response<T>,
2132    session: Session,
2133) -> Result<()>
2134where
2135    T: EntityMessage + RequestMessage,
2136{
2137    let project_id = ProjectId::from_proto(request.remote_entity_id());
2138
2139    let host_connection_id = session
2140        .db()
2141        .await
2142        .host_for_mutating_project_request(project_id, session.connection_id)
2143        .await?;
2144    let payload = session
2145        .peer
2146        .forward_request(session.connection_id, host_connection_id, request)
2147        .await?;
2148    response.send(payload)?;
2149    Ok(())
2150}
2151
2152/// Notify other participants that a new buffer has been created
2153async fn create_buffer_for_peer(
2154    request: proto::CreateBufferForPeer,
2155    session: Session,
2156) -> Result<()> {
2157    session
2158        .db()
2159        .await
2160        .check_user_is_project_host(
2161            ProjectId::from_proto(request.project_id),
2162            session.connection_id,
2163        )
2164        .await?;
2165    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2166    session
2167        .peer
2168        .forward_send(session.connection_id, peer_id.into(), request)?;
2169    Ok(())
2170}
2171
2172/// Notify other participants that a buffer has been updated. This is
2173/// allowed for guests as long as the update is limited to selections.
2174async fn update_buffer(
2175    request: proto::UpdateBuffer,
2176    response: Response<proto::UpdateBuffer>,
2177    session: Session,
2178) -> Result<()> {
2179    let project_id = ProjectId::from_proto(request.project_id);
2180    let mut capability = Capability::ReadOnly;
2181
2182    for op in request.operations.iter() {
2183        match op.variant {
2184            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2185            Some(_) => capability = Capability::ReadWrite,
2186        }
2187    }
2188
2189    let host = {
2190        let guard = session
2191            .db()
2192            .await
2193            .connections_for_buffer_update(project_id, session.connection_id, capability)
2194            .await?;
2195
2196        let (host, guests) = &*guard;
2197
2198        broadcast(
2199            Some(session.connection_id),
2200            guests.clone(),
2201            |connection_id| {
2202                session
2203                    .peer
2204                    .forward_send(session.connection_id, connection_id, request.clone())
2205            },
2206        );
2207
2208        *host
2209    };
2210
2211    if host != session.connection_id {
2212        session
2213            .peer
2214            .forward_request(session.connection_id, host, request.clone())
2215            .await?;
2216    }
2217
2218    response.send(proto::Ack {})?;
2219    Ok(())
2220}
2221
2222async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2223    let project_id = ProjectId::from_proto(message.project_id);
2224
2225    let operation = message.operation.as_ref().context("invalid operation")?;
2226    let capability = match operation.variant.as_ref() {
2227        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2228            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2229                match buffer_op.variant {
2230                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2231                        Capability::ReadOnly
2232                    }
2233                    _ => Capability::ReadWrite,
2234                }
2235            } else {
2236                Capability::ReadWrite
2237            }
2238        }
2239        Some(_) => Capability::ReadWrite,
2240        None => Capability::ReadOnly,
2241    };
2242
2243    let guard = session
2244        .db()
2245        .await
2246        .connections_for_buffer_update(project_id, session.connection_id, capability)
2247        .await?;
2248
2249    let (host, guests) = &*guard;
2250
2251    broadcast(
2252        Some(session.connection_id),
2253        guests.iter().chain([host]).copied(),
2254        |connection_id| {
2255            session
2256                .peer
2257                .forward_send(session.connection_id, connection_id, message.clone())
2258        },
2259    );
2260
2261    Ok(())
2262}
2263
2264/// Notify other participants that a project has been updated.
2265async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2266    request: T,
2267    session: Session,
2268) -> Result<()> {
2269    let project_id = ProjectId::from_proto(request.remote_entity_id());
2270    let project_connection_ids = session
2271        .db()
2272        .await
2273        .project_connection_ids(project_id, session.connection_id, false)
2274        .await?;
2275
2276    broadcast(
2277        Some(session.connection_id),
2278        project_connection_ids.iter().copied(),
2279        |connection_id| {
2280            session
2281                .peer
2282                .forward_send(session.connection_id, connection_id, request.clone())
2283        },
2284    );
2285    Ok(())
2286}
2287
2288/// Start following another user in a call.
2289async fn follow(
2290    request: proto::Follow,
2291    response: Response<proto::Follow>,
2292    session: Session,
2293) -> Result<()> {
2294    let room_id = RoomId::from_proto(request.room_id);
2295    let project_id = request.project_id.map(ProjectId::from_proto);
2296    let leader_id = request
2297        .leader_id
2298        .ok_or_else(|| anyhow!("invalid leader id"))?
2299        .into();
2300    let follower_id = session.connection_id;
2301
2302    session
2303        .db()
2304        .await
2305        .check_room_participants(room_id, leader_id, session.connection_id)
2306        .await?;
2307
2308    let response_payload = session
2309        .peer
2310        .forward_request(session.connection_id, leader_id, request)
2311        .await?;
2312    response.send(response_payload)?;
2313
2314    if let Some(project_id) = project_id {
2315        let room = session
2316            .db()
2317            .await
2318            .follow(room_id, project_id, leader_id, follower_id)
2319            .await?;
2320        room_updated(&room, &session.peer);
2321    }
2322
2323    Ok(())
2324}
2325
2326/// Stop following another user in a call.
2327async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2328    let room_id = RoomId::from_proto(request.room_id);
2329    let project_id = request.project_id.map(ProjectId::from_proto);
2330    let leader_id = request
2331        .leader_id
2332        .ok_or_else(|| anyhow!("invalid leader id"))?
2333        .into();
2334    let follower_id = session.connection_id;
2335
2336    session
2337        .db()
2338        .await
2339        .check_room_participants(room_id, leader_id, session.connection_id)
2340        .await?;
2341
2342    session
2343        .peer
2344        .forward_send(session.connection_id, leader_id, request)?;
2345
2346    if let Some(project_id) = project_id {
2347        let room = session
2348            .db()
2349            .await
2350            .unfollow(room_id, project_id, leader_id, follower_id)
2351            .await?;
2352        room_updated(&room, &session.peer);
2353    }
2354
2355    Ok(())
2356}
2357
2358/// Notify everyone following you of your current location.
2359async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2360    let room_id = RoomId::from_proto(request.room_id);
2361    let database = session.db.lock().await;
2362
2363    let connection_ids = if let Some(project_id) = request.project_id {
2364        let project_id = ProjectId::from_proto(project_id);
2365        database
2366            .project_connection_ids(project_id, session.connection_id, true)
2367            .await?
2368    } else {
2369        database
2370            .room_connection_ids(room_id, session.connection_id)
2371            .await?
2372    };
2373
2374    // For now, don't send view update messages back to that view's current leader.
2375    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2376        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2377        _ => None,
2378    });
2379
2380    for connection_id in connection_ids.iter().cloned() {
2381        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2382            session
2383                .peer
2384                .forward_send(session.connection_id, connection_id, request.clone())?;
2385        }
2386    }
2387    Ok(())
2388}
2389
2390/// Get public data about users.
2391async fn get_users(
2392    request: proto::GetUsers,
2393    response: Response<proto::GetUsers>,
2394    session: Session,
2395) -> Result<()> {
2396    let user_ids = request
2397        .user_ids
2398        .into_iter()
2399        .map(UserId::from_proto)
2400        .collect();
2401    let users = session
2402        .db()
2403        .await
2404        .get_users_by_ids(user_ids)
2405        .await?
2406        .into_iter()
2407        .map(|user| proto::User {
2408            id: user.id.to_proto(),
2409            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2410            github_login: user.github_login,
2411        })
2412        .collect();
2413    response.send(proto::UsersResponse { users })?;
2414    Ok(())
2415}
2416
2417/// Search for users (to invite) buy Github login
2418async fn fuzzy_search_users(
2419    request: proto::FuzzySearchUsers,
2420    response: Response<proto::FuzzySearchUsers>,
2421    session: Session,
2422) -> Result<()> {
2423    let query = request.query;
2424    let users = match query.len() {
2425        0 => vec![],
2426        1 | 2 => session
2427            .db()
2428            .await
2429            .get_user_by_github_login(&query)
2430            .await?
2431            .into_iter()
2432            .collect(),
2433        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2434    };
2435    let users = users
2436        .into_iter()
2437        .filter(|user| user.id != session.user_id())
2438        .map(|user| proto::User {
2439            id: user.id.to_proto(),
2440            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2441            github_login: user.github_login,
2442        })
2443        .collect();
2444    response.send(proto::UsersResponse { users })?;
2445    Ok(())
2446}
2447
2448/// Send a contact request to another user.
2449async fn request_contact(
2450    request: proto::RequestContact,
2451    response: Response<proto::RequestContact>,
2452    session: Session,
2453) -> Result<()> {
2454    let requester_id = session.user_id();
2455    let responder_id = UserId::from_proto(request.responder_id);
2456    if requester_id == responder_id {
2457        return Err(anyhow!("cannot add yourself as a contact"))?;
2458    }
2459
2460    let notifications = session
2461        .db()
2462        .await
2463        .send_contact_request(requester_id, responder_id)
2464        .await?;
2465
2466    // Update outgoing contact requests of requester
2467    let mut update = proto::UpdateContacts::default();
2468    update.outgoing_requests.push(responder_id.to_proto());
2469    for connection_id in session
2470        .connection_pool()
2471        .await
2472        .user_connection_ids(requester_id)
2473    {
2474        session.peer.send(connection_id, update.clone())?;
2475    }
2476
2477    // Update incoming contact requests of responder
2478    let mut update = proto::UpdateContacts::default();
2479    update
2480        .incoming_requests
2481        .push(proto::IncomingContactRequest {
2482            requester_id: requester_id.to_proto(),
2483        });
2484    let connection_pool = session.connection_pool().await;
2485    for connection_id in connection_pool.user_connection_ids(responder_id) {
2486        session.peer.send(connection_id, update.clone())?;
2487    }
2488
2489    send_notifications(&connection_pool, &session.peer, notifications);
2490
2491    response.send(proto::Ack {})?;
2492    Ok(())
2493}
2494
2495/// Accept or decline a contact request
2496async fn respond_to_contact_request(
2497    request: proto::RespondToContactRequest,
2498    response: Response<proto::RespondToContactRequest>,
2499    session: Session,
2500) -> Result<()> {
2501    let responder_id = session.user_id();
2502    let requester_id = UserId::from_proto(request.requester_id);
2503    let db = session.db().await;
2504    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2505        db.dismiss_contact_notification(responder_id, requester_id)
2506            .await?;
2507    } else {
2508        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2509
2510        let notifications = db
2511            .respond_to_contact_request(responder_id, requester_id, accept)
2512            .await?;
2513        let requester_busy = db.is_user_busy(requester_id).await?;
2514        let responder_busy = db.is_user_busy(responder_id).await?;
2515
2516        let pool = session.connection_pool().await;
2517        // Update responder with new contact
2518        let mut update = proto::UpdateContacts::default();
2519        if accept {
2520            update
2521                .contacts
2522                .push(contact_for_user(requester_id, requester_busy, &pool));
2523        }
2524        update
2525            .remove_incoming_requests
2526            .push(requester_id.to_proto());
2527        for connection_id in pool.user_connection_ids(responder_id) {
2528            session.peer.send(connection_id, update.clone())?;
2529        }
2530
2531        // Update requester with new contact
2532        let mut update = proto::UpdateContacts::default();
2533        if accept {
2534            update
2535                .contacts
2536                .push(contact_for_user(responder_id, responder_busy, &pool));
2537        }
2538        update
2539            .remove_outgoing_requests
2540            .push(responder_id.to_proto());
2541
2542        for connection_id in pool.user_connection_ids(requester_id) {
2543            session.peer.send(connection_id, update.clone())?;
2544        }
2545
2546        send_notifications(&pool, &session.peer, notifications);
2547    }
2548
2549    response.send(proto::Ack {})?;
2550    Ok(())
2551}
2552
2553/// Remove a contact.
2554async fn remove_contact(
2555    request: proto::RemoveContact,
2556    response: Response<proto::RemoveContact>,
2557    session: Session,
2558) -> Result<()> {
2559    let requester_id = session.user_id();
2560    let responder_id = UserId::from_proto(request.user_id);
2561    let db = session.db().await;
2562    let (contact_accepted, deleted_notification_id) =
2563        db.remove_contact(requester_id, responder_id).await?;
2564
2565    let pool = session.connection_pool().await;
2566    // Update outgoing contact requests of requester
2567    let mut update = proto::UpdateContacts::default();
2568    if contact_accepted {
2569        update.remove_contacts.push(responder_id.to_proto());
2570    } else {
2571        update
2572            .remove_outgoing_requests
2573            .push(responder_id.to_proto());
2574    }
2575    for connection_id in pool.user_connection_ids(requester_id) {
2576        session.peer.send(connection_id, update.clone())?;
2577    }
2578
2579    // Update incoming contact requests of responder
2580    let mut update = proto::UpdateContacts::default();
2581    if contact_accepted {
2582        update.remove_contacts.push(requester_id.to_proto());
2583    } else {
2584        update
2585            .remove_incoming_requests
2586            .push(requester_id.to_proto());
2587    }
2588    for connection_id in pool.user_connection_ids(responder_id) {
2589        session.peer.send(connection_id, update.clone())?;
2590        if let Some(notification_id) = deleted_notification_id {
2591            session.peer.send(
2592                connection_id,
2593                proto::DeleteNotification {
2594                    notification_id: notification_id.to_proto(),
2595                },
2596            )?;
2597        }
2598    }
2599
2600    response.send(proto::Ack {})?;
2601    Ok(())
2602}
2603
2604fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2605    version.0.minor() < 139
2606}
2607
2608async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2609    let plan = session.current_plan(&session.db().await).await?;
2610
2611    session
2612        .peer
2613        .send(
2614            session.connection_id,
2615            proto::UpdateUserPlan { plan: plan.into() },
2616        )
2617        .trace_err();
2618
2619    Ok(())
2620}
2621
2622async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2623    subscribe_user_to_channels(session.user_id(), &session).await?;
2624    Ok(())
2625}
2626
2627async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2628    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2629    let mut pool = session.connection_pool().await;
2630    for membership in &channels_for_user.channel_memberships {
2631        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2632    }
2633    session.peer.send(
2634        session.connection_id,
2635        build_update_user_channels(&channels_for_user),
2636    )?;
2637    session.peer.send(
2638        session.connection_id,
2639        build_channels_update(channels_for_user),
2640    )?;
2641    Ok(())
2642}
2643
2644/// Creates a new channel.
2645async fn create_channel(
2646    request: proto::CreateChannel,
2647    response: Response<proto::CreateChannel>,
2648    session: Session,
2649) -> Result<()> {
2650    let db = session.db().await;
2651
2652    let parent_id = request.parent_id.map(ChannelId::from_proto);
2653    let (channel, membership) = db
2654        .create_channel(&request.name, parent_id, session.user_id())
2655        .await?;
2656
2657    let root_id = channel.root_id();
2658    let channel = Channel::from_model(channel);
2659
2660    response.send(proto::CreateChannelResponse {
2661        channel: Some(channel.to_proto()),
2662        parent_id: request.parent_id,
2663    })?;
2664
2665    let mut connection_pool = session.connection_pool().await;
2666    if let Some(membership) = membership {
2667        connection_pool.subscribe_to_channel(
2668            membership.user_id,
2669            membership.channel_id,
2670            membership.role,
2671        );
2672        let update = proto::UpdateUserChannels {
2673            channel_memberships: vec![proto::ChannelMembership {
2674                channel_id: membership.channel_id.to_proto(),
2675                role: membership.role.into(),
2676            }],
2677            ..Default::default()
2678        };
2679        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2680            session.peer.send(connection_id, update.clone())?;
2681        }
2682    }
2683
2684    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2685        if !role.can_see_channel(channel.visibility) {
2686            continue;
2687        }
2688
2689        let update = proto::UpdateChannels {
2690            channels: vec![channel.to_proto()],
2691            ..Default::default()
2692        };
2693        session.peer.send(connection_id, update.clone())?;
2694    }
2695
2696    Ok(())
2697}
2698
2699/// Delete a channel
2700async fn delete_channel(
2701    request: proto::DeleteChannel,
2702    response: Response<proto::DeleteChannel>,
2703    session: Session,
2704) -> Result<()> {
2705    let db = session.db().await;
2706
2707    let channel_id = request.channel_id;
2708    let (root_channel, removed_channels) = db
2709        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2710        .await?;
2711    response.send(proto::Ack {})?;
2712
2713    // Notify members of removed channels
2714    let mut update = proto::UpdateChannels::default();
2715    update
2716        .delete_channels
2717        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2718
2719    let connection_pool = session.connection_pool().await;
2720    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2721        session.peer.send(connection_id, update.clone())?;
2722    }
2723
2724    Ok(())
2725}
2726
2727/// Invite someone to join a channel.
2728async fn invite_channel_member(
2729    request: proto::InviteChannelMember,
2730    response: Response<proto::InviteChannelMember>,
2731    session: Session,
2732) -> Result<()> {
2733    let db = session.db().await;
2734    let channel_id = ChannelId::from_proto(request.channel_id);
2735    let invitee_id = UserId::from_proto(request.user_id);
2736    let InviteMemberResult {
2737        channel,
2738        notifications,
2739    } = db
2740        .invite_channel_member(
2741            channel_id,
2742            invitee_id,
2743            session.user_id(),
2744            request.role().into(),
2745        )
2746        .await?;
2747
2748    let update = proto::UpdateChannels {
2749        channel_invitations: vec![channel.to_proto()],
2750        ..Default::default()
2751    };
2752
2753    let connection_pool = session.connection_pool().await;
2754    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2755        session.peer.send(connection_id, update.clone())?;
2756    }
2757
2758    send_notifications(&connection_pool, &session.peer, notifications);
2759
2760    response.send(proto::Ack {})?;
2761    Ok(())
2762}
2763
2764/// remove someone from a channel
2765async fn remove_channel_member(
2766    request: proto::RemoveChannelMember,
2767    response: Response<proto::RemoveChannelMember>,
2768    session: Session,
2769) -> Result<()> {
2770    let db = session.db().await;
2771    let channel_id = ChannelId::from_proto(request.channel_id);
2772    let member_id = UserId::from_proto(request.user_id);
2773
2774    let RemoveChannelMemberResult {
2775        membership_update,
2776        notification_id,
2777    } = db
2778        .remove_channel_member(channel_id, member_id, session.user_id())
2779        .await?;
2780
2781    let mut connection_pool = session.connection_pool().await;
2782    notify_membership_updated(
2783        &mut connection_pool,
2784        membership_update,
2785        member_id,
2786        &session.peer,
2787    );
2788    for connection_id in connection_pool.user_connection_ids(member_id) {
2789        if let Some(notification_id) = notification_id {
2790            session
2791                .peer
2792                .send(
2793                    connection_id,
2794                    proto::DeleteNotification {
2795                        notification_id: notification_id.to_proto(),
2796                    },
2797                )
2798                .trace_err();
2799        }
2800    }
2801
2802    response.send(proto::Ack {})?;
2803    Ok(())
2804}
2805
2806/// Toggle the channel between public and private.
2807/// Care is taken to maintain the invariant that public channels only descend from public channels,
2808/// (though members-only channels can appear at any point in the hierarchy).
2809async fn set_channel_visibility(
2810    request: proto::SetChannelVisibility,
2811    response: Response<proto::SetChannelVisibility>,
2812    session: Session,
2813) -> Result<()> {
2814    let db = session.db().await;
2815    let channel_id = ChannelId::from_proto(request.channel_id);
2816    let visibility = request.visibility().into();
2817
2818    let channel_model = db
2819        .set_channel_visibility(channel_id, visibility, session.user_id())
2820        .await?;
2821    let root_id = channel_model.root_id();
2822    let channel = Channel::from_model(channel_model);
2823
2824    let mut connection_pool = session.connection_pool().await;
2825    for (user_id, role) in connection_pool
2826        .channel_user_ids(root_id)
2827        .collect::<Vec<_>>()
2828        .into_iter()
2829    {
2830        let update = if role.can_see_channel(channel.visibility) {
2831            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2832            proto::UpdateChannels {
2833                channels: vec![channel.to_proto()],
2834                ..Default::default()
2835            }
2836        } else {
2837            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2838            proto::UpdateChannels {
2839                delete_channels: vec![channel.id.to_proto()],
2840                ..Default::default()
2841            }
2842        };
2843
2844        for connection_id in connection_pool.user_connection_ids(user_id) {
2845            session.peer.send(connection_id, update.clone())?;
2846        }
2847    }
2848
2849    response.send(proto::Ack {})?;
2850    Ok(())
2851}
2852
2853/// Alter the role for a user in the channel.
2854async fn set_channel_member_role(
2855    request: proto::SetChannelMemberRole,
2856    response: Response<proto::SetChannelMemberRole>,
2857    session: Session,
2858) -> Result<()> {
2859    let db = session.db().await;
2860    let channel_id = ChannelId::from_proto(request.channel_id);
2861    let member_id = UserId::from_proto(request.user_id);
2862    let result = db
2863        .set_channel_member_role(
2864            channel_id,
2865            session.user_id(),
2866            member_id,
2867            request.role().into(),
2868        )
2869        .await?;
2870
2871    match result {
2872        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2873            let mut connection_pool = session.connection_pool().await;
2874            notify_membership_updated(
2875                &mut connection_pool,
2876                membership_update,
2877                member_id,
2878                &session.peer,
2879            )
2880        }
2881        db::SetMemberRoleResult::InviteUpdated(channel) => {
2882            let update = proto::UpdateChannels {
2883                channel_invitations: vec![channel.to_proto()],
2884                ..Default::default()
2885            };
2886
2887            for connection_id in session
2888                .connection_pool()
2889                .await
2890                .user_connection_ids(member_id)
2891            {
2892                session.peer.send(connection_id, update.clone())?;
2893            }
2894        }
2895    }
2896
2897    response.send(proto::Ack {})?;
2898    Ok(())
2899}
2900
2901/// Change the name of a channel
2902async fn rename_channel(
2903    request: proto::RenameChannel,
2904    response: Response<proto::RenameChannel>,
2905    session: Session,
2906) -> Result<()> {
2907    let db = session.db().await;
2908    let channel_id = ChannelId::from_proto(request.channel_id);
2909    let channel_model = db
2910        .rename_channel(channel_id, session.user_id(), &request.name)
2911        .await?;
2912    let root_id = channel_model.root_id();
2913    let channel = Channel::from_model(channel_model);
2914
2915    response.send(proto::RenameChannelResponse {
2916        channel: Some(channel.to_proto()),
2917    })?;
2918
2919    let connection_pool = session.connection_pool().await;
2920    let update = proto::UpdateChannels {
2921        channels: vec![channel.to_proto()],
2922        ..Default::default()
2923    };
2924    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2925        if role.can_see_channel(channel.visibility) {
2926            session.peer.send(connection_id, update.clone())?;
2927        }
2928    }
2929
2930    Ok(())
2931}
2932
2933/// Move a channel to a new parent.
2934async fn move_channel(
2935    request: proto::MoveChannel,
2936    response: Response<proto::MoveChannel>,
2937    session: Session,
2938) -> Result<()> {
2939    let channel_id = ChannelId::from_proto(request.channel_id);
2940    let to = ChannelId::from_proto(request.to);
2941
2942    let (root_id, channels) = session
2943        .db()
2944        .await
2945        .move_channel(channel_id, to, session.user_id())
2946        .await?;
2947
2948    let connection_pool = session.connection_pool().await;
2949    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2950        let channels = channels
2951            .iter()
2952            .filter_map(|channel| {
2953                if role.can_see_channel(channel.visibility) {
2954                    Some(channel.to_proto())
2955                } else {
2956                    None
2957                }
2958            })
2959            .collect::<Vec<_>>();
2960        if channels.is_empty() {
2961            continue;
2962        }
2963
2964        let update = proto::UpdateChannels {
2965            channels,
2966            ..Default::default()
2967        };
2968
2969        session.peer.send(connection_id, update.clone())?;
2970    }
2971
2972    response.send(Ack {})?;
2973    Ok(())
2974}
2975
2976/// Get the list of channel members
2977async fn get_channel_members(
2978    request: proto::GetChannelMembers,
2979    response: Response<proto::GetChannelMembers>,
2980    session: Session,
2981) -> Result<()> {
2982    let db = session.db().await;
2983    let channel_id = ChannelId::from_proto(request.channel_id);
2984    let limit = if request.limit == 0 {
2985        u16::MAX as u64
2986    } else {
2987        request.limit
2988    };
2989    let (members, users) = db
2990        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
2991        .await?;
2992    response.send(proto::GetChannelMembersResponse { members, users })?;
2993    Ok(())
2994}
2995
2996/// Accept or decline a channel invitation.
2997async fn respond_to_channel_invite(
2998    request: proto::RespondToChannelInvite,
2999    response: Response<proto::RespondToChannelInvite>,
3000    session: Session,
3001) -> Result<()> {
3002    let db = session.db().await;
3003    let channel_id = ChannelId::from_proto(request.channel_id);
3004    let RespondToChannelInvite {
3005        membership_update,
3006        notifications,
3007    } = db
3008        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3009        .await?;
3010
3011    let mut connection_pool = session.connection_pool().await;
3012    if let Some(membership_update) = membership_update {
3013        notify_membership_updated(
3014            &mut connection_pool,
3015            membership_update,
3016            session.user_id(),
3017            &session.peer,
3018        );
3019    } else {
3020        let update = proto::UpdateChannels {
3021            remove_channel_invitations: vec![channel_id.to_proto()],
3022            ..Default::default()
3023        };
3024
3025        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3026            session.peer.send(connection_id, update.clone())?;
3027        }
3028    };
3029
3030    send_notifications(&connection_pool, &session.peer, notifications);
3031
3032    response.send(proto::Ack {})?;
3033
3034    Ok(())
3035}
3036
3037/// Join the channels' room
3038async fn join_channel(
3039    request: proto::JoinChannel,
3040    response: Response<proto::JoinChannel>,
3041    session: Session,
3042) -> Result<()> {
3043    let channel_id = ChannelId::from_proto(request.channel_id);
3044    join_channel_internal(channel_id, Box::new(response), session).await
3045}
3046
/// A response handle that `join_channel_internal` can reply to with a
/// `JoinRoomResponse`.
///
/// Implemented for the `JoinChannel` and `JoinRoom` response handles, so the
/// same join logic can answer either request.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path targets `Response`'s own `send`, not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path targets `Response`'s own `send`, not this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3060
/// Shared implementation for joining a channel's room.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, that room
///    is left first. The database handle is dropped around
///    `leave_room_for_session` and re-acquired afterwards.
/// 2. The channel's room is joined in the database, which yields the joined
///    room, an optional membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, a token is minted. Guests get a guest
///    token with `can_publish = false`; every other role gets a room token with
///    `can_publish = true`.
/// 4. The requester receives a `JoinRoomResponse`, any membership update is
///    propagated, and `room_updated` is called for the room.
/// 5. After the database and pool guards from the inner block are dropped,
///    `channel_updated` is called and the user's contacts are refreshed.
///
/// # Errors
/// Propagates database and send errors, and fails with "channel not returned"
/// if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The inner block owns the database and connection-pool guards, which are dropped
    // before the channel and contact broadcasts below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails; the
        // `trace_err()?` logs the token error and yields no connection info, so the
        // join itself still succeeds.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3156
3157/// Start editing the channel notes
3158async fn join_channel_buffer(
3159    request: proto::JoinChannelBuffer,
3160    response: Response<proto::JoinChannelBuffer>,
3161    session: Session,
3162) -> Result<()> {
3163    let db = session.db().await;
3164    let channel_id = ChannelId::from_proto(request.channel_id);
3165
3166    let open_response = db
3167        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3168        .await?;
3169
3170    let collaborators = open_response.collaborators.clone();
3171    response.send(open_response)?;
3172
3173    let update = UpdateChannelBufferCollaborators {
3174        channel_id: channel_id.to_proto(),
3175        collaborators: collaborators.clone(),
3176    };
3177    channel_buffer_updated(
3178        session.connection_id,
3179        collaborators
3180            .iter()
3181            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3182        &update,
3183        &session.peer,
3184    );
3185
3186    Ok(())
3187}
3188
3189/// Edit the channel notes
3190async fn update_channel_buffer(
3191    request: proto::UpdateChannelBuffer,
3192    session: Session,
3193) -> Result<()> {
3194    let db = session.db().await;
3195    let channel_id = ChannelId::from_proto(request.channel_id);
3196
3197    let (collaborators, epoch, version) = db
3198        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3199        .await?;
3200
3201    channel_buffer_updated(
3202        session.connection_id,
3203        collaborators.clone(),
3204        &proto::UpdateChannelBuffer {
3205            channel_id: channel_id.to_proto(),
3206            operations: request.operations,
3207        },
3208        &session.peer,
3209    );
3210
3211    let pool = &*session.connection_pool().await;
3212
3213    let non_collaborators =
3214        pool.channel_connection_ids(channel_id)
3215            .filter_map(|(connection_id, _)| {
3216                if collaborators.contains(&connection_id) {
3217                    None
3218                } else {
3219                    Some(connection_id)
3220                }
3221            });
3222
3223    broadcast(None, non_collaborators, |peer_id| {
3224        session.peer.send(
3225            peer_id,
3226            proto::UpdateChannels {
3227                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3228                    channel_id: channel_id.to_proto(),
3229                    epoch: epoch as u64,
3230                    version: version.clone(),
3231                }],
3232                ..Default::default()
3233            },
3234        )
3235    });
3236
3237    Ok(())
3238}
3239
3240/// Rejoin the channel notes after a connection blip
3241async fn rejoin_channel_buffers(
3242    request: proto::RejoinChannelBuffers,
3243    response: Response<proto::RejoinChannelBuffers>,
3244    session: Session,
3245) -> Result<()> {
3246    let db = session.db().await;
3247    let buffers = db
3248        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3249        .await?;
3250
3251    for rejoined_buffer in &buffers {
3252        let collaborators_to_notify = rejoined_buffer
3253            .buffer
3254            .collaborators
3255            .iter()
3256            .filter_map(|c| Some(c.peer_id?.into()));
3257        channel_buffer_updated(
3258            session.connection_id,
3259            collaborators_to_notify,
3260            &proto::UpdateChannelBufferCollaborators {
3261                channel_id: rejoined_buffer.buffer.channel_id,
3262                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3263            },
3264            &session.peer,
3265        );
3266    }
3267
3268    response.send(proto::RejoinChannelBuffersResponse {
3269        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3270    })?;
3271
3272    Ok(())
3273}
3274
3275/// Stop editing the channel notes
3276async fn leave_channel_buffer(
3277    request: proto::LeaveChannelBuffer,
3278    response: Response<proto::LeaveChannelBuffer>,
3279    session: Session,
3280) -> Result<()> {
3281    let db = session.db().await;
3282    let channel_id = ChannelId::from_proto(request.channel_id);
3283
3284    let left_buffer = db
3285        .leave_channel_buffer(channel_id, session.connection_id)
3286        .await?;
3287
3288    response.send(Ack {})?;
3289
3290    channel_buffer_updated(
3291        session.connection_id,
3292        left_buffer.connections,
3293        &proto::UpdateChannelBufferCollaborators {
3294            channel_id: channel_id.to_proto(),
3295            collaborators: left_buffer.collaborators,
3296        },
3297        &session.peer,
3298    );
3299
3300    Ok(())
3301}
3302
3303fn channel_buffer_updated<T: EnvelopedMessage>(
3304    sender_id: ConnectionId,
3305    collaborators: impl IntoIterator<Item = ConnectionId>,
3306    message: &T,
3307    peer: &Peer,
3308) {
3309    broadcast(Some(sender_id), collaborators, |peer_id| {
3310        peer.send(peer_id, message.clone())
3311    });
3312}
3313
3314fn send_notifications(
3315    connection_pool: &ConnectionPool,
3316    peer: &Peer,
3317    notifications: db::NotificationBatch,
3318) {
3319    for (user_id, notification) in notifications {
3320        for connection_id in connection_pool.user_connection_ids(user_id) {
3321            if let Err(error) = peer.send(
3322                connection_id,
3323                proto::AddNotification {
3324                    notification: Some(notification.clone()),
3325                },
3326            ) {
3327                tracing::error!(
3328                    "failed to send notification to {:?} {}",
3329                    connection_id,
3330                    error
3331                );
3332            }
3333        }
3334    }
3335}
3336
3337/// Send a message to the channel
3338async fn send_channel_message(
3339    request: proto::SendChannelMessage,
3340    response: Response<proto::SendChannelMessage>,
3341    session: Session,
3342) -> Result<()> {
3343    // Validate the message body.
3344    let body = request.body.trim().to_string();
3345    if body.len() > MAX_MESSAGE_LEN {
3346        return Err(anyhow!("message is too long"))?;
3347    }
3348    if body.is_empty() {
3349        return Err(anyhow!("message can't be blank"))?;
3350    }
3351
3352    // TODO: adjust mentions if body is trimmed
3353
3354    let timestamp = OffsetDateTime::now_utc();
3355    let nonce = request
3356        .nonce
3357        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3358
3359    let channel_id = ChannelId::from_proto(request.channel_id);
3360    let CreatedChannelMessage {
3361        message_id,
3362        participant_connection_ids,
3363        notifications,
3364    } = session
3365        .db()
3366        .await
3367        .create_channel_message(
3368            channel_id,
3369            session.user_id(),
3370            &body,
3371            &request.mentions,
3372            timestamp,
3373            nonce.clone().into(),
3374            request.reply_to_message_id.map(MessageId::from_proto),
3375        )
3376        .await?;
3377
3378    let message = proto::ChannelMessage {
3379        sender_id: session.user_id().to_proto(),
3380        id: message_id.to_proto(),
3381        body,
3382        mentions: request.mentions,
3383        timestamp: timestamp.unix_timestamp() as u64,
3384        nonce: Some(nonce),
3385        reply_to_message_id: request.reply_to_message_id,
3386        edited_at: None,
3387    };
3388    broadcast(
3389        Some(session.connection_id),
3390        participant_connection_ids.clone(),
3391        |connection| {
3392            session.peer.send(
3393                connection,
3394                proto::ChannelMessageSent {
3395                    channel_id: channel_id.to_proto(),
3396                    message: Some(message.clone()),
3397                },
3398            )
3399        },
3400    );
3401    response.send(proto::SendChannelMessageResponse {
3402        message: Some(message),
3403    })?;
3404
3405    let pool = &*session.connection_pool().await;
3406    let non_participants =
3407        pool.channel_connection_ids(channel_id)
3408            .filter_map(|(connection_id, _)| {
3409                if participant_connection_ids.contains(&connection_id) {
3410                    None
3411                } else {
3412                    Some(connection_id)
3413                }
3414            });
3415    broadcast(None, non_participants, |peer_id| {
3416        session.peer.send(
3417            peer_id,
3418            proto::UpdateChannels {
3419                latest_channel_message_ids: vec![proto::ChannelMessageId {
3420                    channel_id: channel_id.to_proto(),
3421                    message_id: message_id.to_proto(),
3422                }],
3423                ..Default::default()
3424            },
3425        )
3426    });
3427    send_notifications(pool, &session.peer, notifications);
3428
3429    Ok(())
3430}
3431
3432/// Delete a channel message
3433async fn remove_channel_message(
3434    request: proto::RemoveChannelMessage,
3435    response: Response<proto::RemoveChannelMessage>,
3436    session: Session,
3437) -> Result<()> {
3438    let channel_id = ChannelId::from_proto(request.channel_id);
3439    let message_id = MessageId::from_proto(request.message_id);
3440    let (connection_ids, existing_notification_ids) = session
3441        .db()
3442        .await
3443        .remove_channel_message(channel_id, message_id, session.user_id())
3444        .await?;
3445
3446    broadcast(
3447        Some(session.connection_id),
3448        connection_ids,
3449        move |connection| {
3450            session.peer.send(connection, request.clone())?;
3451
3452            for notification_id in &existing_notification_ids {
3453                session.peer.send(
3454                    connection,
3455                    proto::DeleteNotification {
3456                        notification_id: (*notification_id).to_proto(),
3457                    },
3458                )?;
3459            }
3460
3461            Ok(())
3462        },
3463    );
3464    response.send(proto::Ack {})?;
3465    Ok(())
3466}
3467
3468async fn update_channel_message(
3469    request: proto::UpdateChannelMessage,
3470    response: Response<proto::UpdateChannelMessage>,
3471    session: Session,
3472) -> Result<()> {
3473    let channel_id = ChannelId::from_proto(request.channel_id);
3474    let message_id = MessageId::from_proto(request.message_id);
3475    let updated_at = OffsetDateTime::now_utc();
3476    let UpdatedChannelMessage {
3477        message_id,
3478        participant_connection_ids,
3479        notifications,
3480        reply_to_message_id,
3481        timestamp,
3482        deleted_mention_notification_ids,
3483        updated_mention_notifications,
3484    } = session
3485        .db()
3486        .await
3487        .update_channel_message(
3488            channel_id,
3489            message_id,
3490            session.user_id(),
3491            request.body.as_str(),
3492            &request.mentions,
3493            updated_at,
3494        )
3495        .await?;
3496
3497    let nonce = request
3498        .nonce
3499        .clone()
3500        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3501
3502    let message = proto::ChannelMessage {
3503        sender_id: session.user_id().to_proto(),
3504        id: message_id.to_proto(),
3505        body: request.body.clone(),
3506        mentions: request.mentions.clone(),
3507        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3508        nonce: Some(nonce),
3509        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3510        edited_at: Some(updated_at.unix_timestamp() as u64),
3511    };
3512
3513    response.send(proto::Ack {})?;
3514
3515    let pool = &*session.connection_pool().await;
3516    broadcast(
3517        Some(session.connection_id),
3518        participant_connection_ids,
3519        |connection| {
3520            session.peer.send(
3521                connection,
3522                proto::ChannelMessageUpdate {
3523                    channel_id: channel_id.to_proto(),
3524                    message: Some(message.clone()),
3525                },
3526            )?;
3527
3528            for notification_id in &deleted_mention_notification_ids {
3529                session.peer.send(
3530                    connection,
3531                    proto::DeleteNotification {
3532                        notification_id: (*notification_id).to_proto(),
3533                    },
3534                )?;
3535            }
3536
3537            for notification in &updated_mention_notifications {
3538                session.peer.send(
3539                    connection,
3540                    proto::UpdateNotification {
3541                        notification: Some(notification.clone()),
3542                    },
3543                )?;
3544            }
3545
3546            Ok(())
3547        },
3548    );
3549
3550    send_notifications(pool, &session.peer, notifications);
3551
3552    Ok(())
3553}
3554
3555/// Mark a channel message as read
3556async fn acknowledge_channel_message(
3557    request: proto::AckChannelMessage,
3558    session: Session,
3559) -> Result<()> {
3560    let channel_id = ChannelId::from_proto(request.channel_id);
3561    let message_id = MessageId::from_proto(request.message_id);
3562    let notifications = session
3563        .db()
3564        .await
3565        .observe_channel_message(channel_id, session.user_id(), message_id)
3566        .await?;
3567    send_notifications(
3568        &*session.connection_pool().await,
3569        &session.peer,
3570        notifications,
3571    );
3572    Ok(())
3573}
3574
3575/// Mark a buffer version as synced
3576async fn acknowledge_buffer_version(
3577    request: proto::AckBufferOperation,
3578    session: Session,
3579) -> Result<()> {
3580    let buffer_id = BufferId::from_proto(request.buffer_id);
3581    session
3582        .db()
3583        .await
3584        .observe_buffer_version(
3585            buffer_id,
3586            session.user_id(),
3587            request.epoch as i32,
3588            &request.version,
3589        )
3590        .await?;
3591    Ok(())
3592}
3593
3594async fn count_language_model_tokens(
3595    request: proto::CountLanguageModelTokens,
3596    response: Response<proto::CountLanguageModelTokens>,
3597    session: Session,
3598    config: &Config,
3599) -> Result<()> {
3600    authorize_access_to_legacy_llm_endpoints(&session).await?;
3601
3602    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3603        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3604        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3605    };
3606
3607    session
3608        .app_state
3609        .rate_limiter
3610        .check(&*rate_limit, session.user_id())
3611        .await?;
3612
3613    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3614        Some(proto::LanguageModelProvider::Google) => {
3615            let api_key = config
3616                .google_ai_api_key
3617                .as_ref()
3618                .context("no Google AI API key configured on the server")?;
3619            google_ai::count_tokens(
3620                session.http_client.as_ref(),
3621                google_ai::API_URL,
3622                api_key,
3623                serde_json::from_str(&request.request)?,
3624            )
3625            .await?
3626        }
3627        _ => return Err(anyhow!("unsupported provider"))?,
3628    };
3629
3630    response.send(proto::CountLanguageModelTokensResponse {
3631        token_count: result.total_tokens as u32,
3632    })?;
3633
3634    Ok(())
3635}
3636
3637struct ZedProCountLanguageModelTokensRateLimit;
3638
3639impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3640    fn capacity(&self) -> usize {
3641        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3642            .ok()
3643            .and_then(|v| v.parse().ok())
3644            .unwrap_or(600) // Picked arbitrarily
3645    }
3646
3647    fn refill_duration(&self) -> chrono::Duration {
3648        chrono::Duration::hours(1)
3649    }
3650
3651    fn db_name(&self) -> &'static str {
3652        "zed-pro:count-language-model-tokens"
3653    }
3654}
3655
3656struct FreeCountLanguageModelTokensRateLimit;
3657
3658impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3659    fn capacity(&self) -> usize {
3660        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3661            .ok()
3662            .and_then(|v| v.parse().ok())
3663            .unwrap_or(600 / 10) // Picked arbitrarily
3664    }
3665
3666    fn refill_duration(&self) -> chrono::Duration {
3667        chrono::Duration::hours(1)
3668    }
3669
3670    fn db_name(&self) -> &'static str {
3671        "free:count-language-model-tokens"
3672    }
3673}
3674
3675struct ZedProComputeEmbeddingsRateLimit;
3676
3677impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3678    fn capacity(&self) -> usize {
3679        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3680            .ok()
3681            .and_then(|v| v.parse().ok())
3682            .unwrap_or(5000) // Picked arbitrarily
3683    }
3684
3685    fn refill_duration(&self) -> chrono::Duration {
3686        chrono::Duration::hours(1)
3687    }
3688
3689    fn db_name(&self) -> &'static str {
3690        "zed-pro:compute-embeddings"
3691    }
3692}
3693
3694struct FreeComputeEmbeddingsRateLimit;
3695
3696impl RateLimit for FreeComputeEmbeddingsRateLimit {
3697    fn capacity(&self) -> usize {
3698        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3699            .ok()
3700            .and_then(|v| v.parse().ok())
3701            .unwrap_or(5000 / 10) // Picked arbitrarily
3702    }
3703
3704    fn refill_duration(&self) -> chrono::Duration {
3705        chrono::Duration::hours(1)
3706    }
3707
3708    fn db_name(&self) -> &'static str {
3709        "free:compute-embeddings"
3710    }
3711}
3712
3713async fn compute_embeddings(
3714    request: proto::ComputeEmbeddings,
3715    response: Response<proto::ComputeEmbeddings>,
3716    session: Session,
3717    api_key: Option<Arc<str>>,
3718) -> Result<()> {
3719    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3720    authorize_access_to_legacy_llm_endpoints(&session).await?;
3721
3722    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3723        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3724        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3725    };
3726
3727    session
3728        .app_state
3729        .rate_limiter
3730        .check(&*rate_limit, session.user_id())
3731        .await?;
3732
3733    let embeddings = match request.model.as_str() {
3734        "openai/text-embedding-3-small" => {
3735            open_ai::embed(
3736                session.http_client.as_ref(),
3737                OPEN_AI_API_URL,
3738                &api_key,
3739                OpenAiEmbeddingModel::TextEmbedding3Small,
3740                request.texts.iter().map(|text| text.as_str()),
3741            )
3742            .await?
3743        }
3744        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3745    };
3746
3747    let embeddings = request
3748        .texts
3749        .iter()
3750        .map(|text| {
3751            let mut hasher = sha2::Sha256::new();
3752            hasher.update(text.as_bytes());
3753            let result = hasher.finalize();
3754            result.to_vec()
3755        })
3756        .zip(
3757            embeddings
3758                .data
3759                .into_iter()
3760                .map(|embedding| embedding.embedding),
3761        )
3762        .collect::<HashMap<_, _>>();
3763
3764    let db = session.db().await;
3765    db.save_embeddings(&request.model, &embeddings)
3766        .await
3767        .context("failed to save embeddings")
3768        .trace_err();
3769
3770    response.send(proto::ComputeEmbeddingsResponse {
3771        embeddings: embeddings
3772            .into_iter()
3773            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3774            .collect(),
3775    })?;
3776    Ok(())
3777}
3778
3779async fn get_cached_embeddings(
3780    request: proto::GetCachedEmbeddings,
3781    response: Response<proto::GetCachedEmbeddings>,
3782    session: Session,
3783) -> Result<()> {
3784    authorize_access_to_legacy_llm_endpoints(&session).await?;
3785
3786    let db = session.db().await;
3787    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3788
3789    response.send(proto::GetCachedEmbeddingsResponse {
3790        embeddings: embeddings
3791            .into_iter()
3792            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3793            .collect(),
3794    })?;
3795    Ok(())
3796}
3797
3798/// This is leftover from before the LLM service.
3799///
3800/// The endpoints protected by this check will be moved there eventually.
3801async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3802    if session.is_staff() {
3803        Ok(())
3804    } else {
3805        Err(anyhow!("permission denied"))?
3806    }
3807}
3808
3809/// Get a Supermaven API key for the user
3810async fn get_supermaven_api_key(
3811    _request: proto::GetSupermavenApiKey,
3812    response: Response<proto::GetSupermavenApiKey>,
3813    session: Session,
3814) -> Result<()> {
3815    let user_id: String = session.user_id().to_string();
3816    if !session.is_staff() {
3817        return Err(anyhow!("supermaven not enabled for this account"))?;
3818    }
3819
3820    let email = session
3821        .email()
3822        .ok_or_else(|| anyhow!("user must have an email"))?;
3823
3824    let supermaven_admin_api = session
3825        .supermaven_client
3826        .as_ref()
3827        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3828
3829    let result = supermaven_admin_api
3830        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3831        .await?;
3832
3833    response.send(proto::GetSupermavenApiKeyResponse {
3834        api_key: result.api_key,
3835    })?;
3836
3837    Ok(())
3838}
3839
3840/// Start receiving chat updates for a channel
3841async fn join_channel_chat(
3842    request: proto::JoinChannelChat,
3843    response: Response<proto::JoinChannelChat>,
3844    session: Session,
3845) -> Result<()> {
3846    let channel_id = ChannelId::from_proto(request.channel_id);
3847
3848    let db = session.db().await;
3849    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3850        .await?;
3851    let messages = db
3852        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3853        .await?;
3854    response.send(proto::JoinChannelChatResponse {
3855        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3856        messages,
3857    })?;
3858    Ok(())
3859}
3860
3861/// Stop receiving chat updates for a channel
3862async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3863    let channel_id = ChannelId::from_proto(request.channel_id);
3864    session
3865        .db()
3866        .await
3867        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3868        .await?;
3869    Ok(())
3870}
3871
3872/// Retrieve the chat history for a channel
3873async fn get_channel_messages(
3874    request: proto::GetChannelMessages,
3875    response: Response<proto::GetChannelMessages>,
3876    session: Session,
3877) -> Result<()> {
3878    let channel_id = ChannelId::from_proto(request.channel_id);
3879    let messages = session
3880        .db()
3881        .await
3882        .get_channel_messages(
3883            channel_id,
3884            session.user_id(),
3885            MESSAGE_COUNT_PER_PAGE,
3886            Some(MessageId::from_proto(request.before_message_id)),
3887        )
3888        .await?;
3889    response.send(proto::GetChannelMessagesResponse {
3890        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3891        messages,
3892    })?;
3893    Ok(())
3894}
3895
3896/// Retrieve specific chat messages
3897async fn get_channel_messages_by_id(
3898    request: proto::GetChannelMessagesById,
3899    response: Response<proto::GetChannelMessagesById>,
3900    session: Session,
3901) -> Result<()> {
3902    let message_ids = request
3903        .message_ids
3904        .iter()
3905        .map(|id| MessageId::from_proto(*id))
3906        .collect::<Vec<_>>();
3907    let messages = session
3908        .db()
3909        .await
3910        .get_channel_messages_by_id(session.user_id(), &message_ids)
3911        .await?;
3912    response.send(proto::GetChannelMessagesResponse {
3913        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914        messages,
3915    })?;
3916    Ok(())
3917}
3918
3919/// Retrieve the current users notifications
3920async fn get_notifications(
3921    request: proto::GetNotifications,
3922    response: Response<proto::GetNotifications>,
3923    session: Session,
3924) -> Result<()> {
3925    let notifications = session
3926        .db()
3927        .await
3928        .get_notifications(
3929            session.user_id(),
3930            NOTIFICATION_COUNT_PER_PAGE,
3931            request.before_id.map(db::NotificationId::from_proto),
3932        )
3933        .await?;
3934    response.send(proto::GetNotificationsResponse {
3935        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3936        notifications,
3937    })?;
3938    Ok(())
3939}
3940
3941/// Mark notifications as read
3942async fn mark_notification_as_read(
3943    request: proto::MarkNotificationRead,
3944    response: Response<proto::MarkNotificationRead>,
3945    session: Session,
3946) -> Result<()> {
3947    let database = &session.db().await;
3948    let notifications = database
3949        .mark_notification_as_read_by_id(
3950            session.user_id(),
3951            NotificationId::from_proto(request.notification_id),
3952        )
3953        .await?;
3954    send_notifications(
3955        &*session.connection_pool().await,
3956        &session.peer,
3957        notifications,
3958    );
3959    response.send(proto::Ack {})?;
3960    Ok(())
3961}
3962
3963/// Get the current users information
3964async fn get_private_user_info(
3965    _request: proto::GetPrivateUserInfo,
3966    response: Response<proto::GetPrivateUserInfo>,
3967    session: Session,
3968) -> Result<()> {
3969    let db = session.db().await;
3970
3971    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3972    let user = db
3973        .get_user_by_id(session.user_id())
3974        .await?
3975        .ok_or_else(|| anyhow!("user not found"))?;
3976    let flags = db.get_user_flags(session.user_id()).await?;
3977
3978    response.send(proto::GetPrivateUserInfoResponse {
3979        metrics_id,
3980        staff: user.admin,
3981        flags,
3982        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3983    })?;
3984    Ok(())
3985}
3986
3987/// Accept the terms of service (tos) on behalf of the current user
3988async fn accept_terms_of_service(
3989    _request: proto::AcceptTermsOfService,
3990    response: Response<proto::AcceptTermsOfService>,
3991    session: Session,
3992) -> Result<()> {
3993    let db = session.db().await;
3994
3995    let accepted_tos_at = Utc::now();
3996    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3997        .await?;
3998
3999    response.send(proto::AcceptTermsOfServiceResponse {
4000        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4001    })?;
4002    Ok(())
4003}
4004
4005/// The minimum account age an account must have in order to use the LLM service.
4006const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4007
4008async fn get_llm_api_token(
4009    _request: proto::GetLlmToken,
4010    response: Response<proto::GetLlmToken>,
4011    session: Session,
4012) -> Result<()> {
4013    let db = session.db().await;
4014
4015    let flags = db.get_user_flags(session.user_id()).await?;
4016    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4017    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4018
4019    if !session.is_staff() && !has_language_models_feature_flag {
4020        Err(anyhow!("permission denied"))?
4021    }
4022
4023    let user_id = session.user_id();
4024    let user = db
4025        .get_user_by_id(user_id)
4026        .await?
4027        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4028
4029    if user.accepted_tos_at.is_none() {
4030        Err(anyhow!("terms of service not accepted"))?
4031    }
4032
4033    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4034    if !has_llm_subscription {
4035        let mut account_created_at = user.created_at;
4036        if let Some(github_created_at) = user.github_user_created_at {
4037            account_created_at = account_created_at.min(github_created_at);
4038        }
4039        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4040            Err(anyhow!("account too young"))?
4041        }
4042    }
4043
4044    let billing_preferences = db.get_billing_preferences(user.id).await?;
4045
4046    let token = LlmTokenClaims::create(
4047        &user,
4048        session.is_staff(),
4049        billing_preferences,
4050        has_llm_closed_beta_feature_flag,
4051        has_llm_subscription,
4052        session.current_plan(&db).await?,
4053        &session.app_state.config,
4054    )?;
4055    response.send(proto::GetLlmTokenResponse { token })?;
4056    Ok(())
4057}
4058
4059fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4060    let message = match message {
4061        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4062        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4063        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4064        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4065        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4066            code: frame.code.into(),
4067            reason: frame.reason,
4068        })),
4069        // We should never receive a frame while reading the message, according
4070        // to the `tungstenite` maintainers:
4071        //
4072        // > It cannot occur when you read messages from the WebSocket, but it
4073        // > can be used when you want to send the raw frames (e.g. you want to
4074        // > send the frames to the WebSocket without composing the full message first).
4075        // >
4076        // > — https://github.com/snapview/tungstenite-rs/issues/268
4077        TungsteniteMessage::Frame(_) => {
4078            bail!("received an unexpected frame while reading the message")
4079        }
4080    };
4081
4082    Ok(message)
4083}
4084
4085fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4086    match message {
4087        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4088        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4089        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4090        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4091        AxumMessage::Close(frame) => {
4092            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4093                code: frame.code.into(),
4094                reason: frame.reason,
4095            }))
4096        }
4097    }
4098}
4099
4100fn notify_membership_updated(
4101    connection_pool: &mut ConnectionPool,
4102    result: MembershipUpdated,
4103    user_id: UserId,
4104    peer: &Peer,
4105) {
4106    for membership in &result.new_channels.channel_memberships {
4107        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4108    }
4109    for channel_id in &result.removed_channels {
4110        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4111    }
4112
4113    let user_channels_update = proto::UpdateUserChannels {
4114        channel_memberships: result
4115            .new_channels
4116            .channel_memberships
4117            .iter()
4118            .map(|cm| proto::ChannelMembership {
4119                channel_id: cm.channel_id.to_proto(),
4120                role: cm.role.into(),
4121            })
4122            .collect(),
4123        ..Default::default()
4124    };
4125
4126    let mut update = build_channels_update(result.new_channels);
4127    update.delete_channels = result
4128        .removed_channels
4129        .into_iter()
4130        .map(|id| id.to_proto())
4131        .collect();
4132    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4133
4134    for connection_id in connection_pool.user_connection_ids(user_id) {
4135        peer.send(connection_id, user_channels_update.clone())
4136            .trace_err();
4137        peer.send(connection_id, update.clone()).trace_err();
4138    }
4139}
4140
4141fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4142    proto::UpdateUserChannels {
4143        channel_memberships: channels
4144            .channel_memberships
4145            .iter()
4146            .map(|m| proto::ChannelMembership {
4147                channel_id: m.channel_id.to_proto(),
4148                role: m.role.into(),
4149            })
4150            .collect(),
4151        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4152        observed_channel_message_id: channels.observed_channel_messages.clone(),
4153    }
4154}
4155
4156fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4157    let mut update = proto::UpdateChannels::default();
4158
4159    for channel in channels.channels {
4160        update.channels.push(channel.to_proto());
4161    }
4162
4163    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4164    update.latest_channel_message_ids = channels.latest_channel_messages;
4165
4166    for (channel_id, participants) in channels.channel_participants {
4167        update
4168            .channel_participants
4169            .push(proto::ChannelParticipants {
4170                channel_id: channel_id.to_proto(),
4171                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4172            });
4173    }
4174
4175    for channel in channels.invited_channels {
4176        update.channel_invitations.push(channel.to_proto());
4177    }
4178
4179    update
4180}
4181
4182fn build_initial_contacts_update(
4183    contacts: Vec<db::Contact>,
4184    pool: &ConnectionPool,
4185) -> proto::UpdateContacts {
4186    let mut update = proto::UpdateContacts::default();
4187
4188    for contact in contacts {
4189        match contact {
4190            db::Contact::Accepted { user_id, busy } => {
4191                update.contacts.push(contact_for_user(user_id, busy, pool));
4192            }
4193            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4194            db::Contact::Incoming { user_id } => {
4195                update
4196                    .incoming_requests
4197                    .push(proto::IncomingContactRequest {
4198                        requester_id: user_id.to_proto(),
4199                    })
4200            }
4201        }
4202    }
4203
4204    update
4205}
4206
4207fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4208    proto::Contact {
4209        user_id: user_id.to_proto(),
4210        online: pool.is_user_online(user_id),
4211        busy,
4212    }
4213}
4214
4215fn room_updated(room: &proto::Room, peer: &Peer) {
4216    broadcast(
4217        None,
4218        room.participants
4219            .iter()
4220            .filter_map(|participant| Some(participant.peer_id?.into())),
4221        |peer_id| {
4222            peer.send(
4223                peer_id,
4224                proto::RoomUpdated {
4225                    room: Some(room.clone()),
4226                },
4227            )
4228        },
4229    );
4230}
4231
4232fn channel_updated(
4233    channel: &db::channel::Model,
4234    room: &proto::Room,
4235    peer: &Peer,
4236    pool: &ConnectionPool,
4237) {
4238    let participants = room
4239        .participants
4240        .iter()
4241        .map(|p| p.user_id)
4242        .collect::<Vec<_>>();
4243
4244    broadcast(
4245        None,
4246        pool.channel_connection_ids(channel.root_id())
4247            .filter_map(|(channel_id, role)| {
4248                role.can_see_channel(channel.visibility)
4249                    .then_some(channel_id)
4250            }),
4251        |peer_id| {
4252            peer.send(
4253                peer_id,
4254                proto::UpdateChannels {
4255                    channel_participants: vec![proto::ChannelParticipants {
4256                        channel_id: channel.id.to_proto(),
4257                        participant_user_ids: participants.clone(),
4258                    }],
4259                    ..Default::default()
4260                },
4261            )
4262        },
4263    );
4264}
4265
4266async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4267    let db = session.db().await;
4268
4269    let contacts = db.get_contacts(user_id).await?;
4270    let busy = db.is_user_busy(user_id).await?;
4271
4272    let pool = session.connection_pool().await;
4273    let updated_contact = contact_for_user(user_id, busy, &pool);
4274    for contact in contacts {
4275        if let db::Contact::Accepted {
4276            user_id: contact_user_id,
4277            ..
4278        } = contact
4279        {
4280            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4281                session
4282                    .peer
4283                    .send(
4284                        contact_conn_id,
4285                        proto::UpdateContacts {
4286                            contacts: vec![updated_contact.clone()],
4287                            remove_contacts: Default::default(),
4288                            incoming_requests: Default::default(),
4289                            remove_incoming_requests: Default::default(),
4290                            outgoing_requests: Default::default(),
4291                            remove_outgoing_requests: Default::default(),
4292                        },
4293                    )
4294                    .trace_err();
4295            }
4296        }
4297    }
4298    Ok(())
4299}
4300
4301async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4302    let mut contacts_to_update = HashSet::default();
4303
4304    let room_id;
4305    let canceled_calls_to_user_ids;
4306    let live_kit_room;
4307    let delete_live_kit_room;
4308    let room;
4309    let channel;
4310
4311    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4312        contacts_to_update.insert(session.user_id());
4313
4314        for project in left_room.left_projects.values() {
4315            project_left(project, session);
4316        }
4317
4318        room_id = RoomId::from_proto(left_room.room.id);
4319        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4320        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4321        delete_live_kit_room = left_room.deleted;
4322        room = mem::take(&mut left_room.room);
4323        channel = mem::take(&mut left_room.channel);
4324
4325        room_updated(&room, &session.peer);
4326    } else {
4327        return Ok(());
4328    }
4329
4330    if let Some(channel) = channel {
4331        channel_updated(
4332            &channel,
4333            &room,
4334            &session.peer,
4335            &*session.connection_pool().await,
4336        );
4337    }
4338
4339    {
4340        let pool = session.connection_pool().await;
4341        for canceled_user_id in canceled_calls_to_user_ids {
4342            for connection_id in pool.user_connection_ids(canceled_user_id) {
4343                session
4344                    .peer
4345                    .send(
4346                        connection_id,
4347                        proto::CallCanceled {
4348                            room_id: room_id.to_proto(),
4349                        },
4350                    )
4351                    .trace_err();
4352            }
4353            contacts_to_update.insert(canceled_user_id);
4354        }
4355    }
4356
4357    for contact_user_id in contacts_to_update {
4358        update_user_contacts(contact_user_id, session).await?;
4359    }
4360
4361    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
4362        live_kit
4363            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4364            .await
4365            .trace_err();
4366
4367        if delete_live_kit_room {
4368            live_kit.delete_room(live_kit_room).await.trace_err();
4369        }
4370    }
4371
4372    Ok(())
4373}
4374
4375async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4376    let left_channel_buffers = session
4377        .db()
4378        .await
4379        .leave_channel_buffers(session.connection_id)
4380        .await?;
4381
4382    for left_buffer in left_channel_buffers {
4383        channel_buffer_updated(
4384            session.connection_id,
4385            left_buffer.connections,
4386            &proto::UpdateChannelBufferCollaborators {
4387                channel_id: left_buffer.channel_id.to_proto(),
4388                collaborators: left_buffer.collaborators,
4389            },
4390            &session.peer,
4391        );
4392    }
4393
4394    Ok(())
4395}
4396
4397fn project_left(project: &db::LeftProject, session: &Session) {
4398    for connection_id in &project.connection_ids {
4399        if project.should_unshare {
4400            session
4401                .peer
4402                .send(
4403                    *connection_id,
4404                    proto::UnshareProject {
4405                        project_id: project.id.to_proto(),
4406                    },
4407                )
4408                .trace_err();
4409        } else {
4410            session
4411                .peer
4412                .send(
4413                    *connection_id,
4414                    proto::RemoveProjectCollaborator {
4415                        project_id: project.id.to_proto(),
4416                        peer_id: Some(session.connection_id.into()),
4417                    },
4418                )
4419                .trace_err();
4420        }
4421    }
4422}
4423
4424pub trait ResultExt {
4425    type Ok;
4426
4427    fn trace_err(self) -> Option<Self::Ok>;
4428}
4429
4430impl<T, E> ResultExt for Result<T, E>
4431where
4432    E: std::fmt::Debug,
4433{
4434    type Ok = T;
4435
4436    #[track_caller]
4437    fn trace_err(self) -> Option<T> {
4438        match self {
4439            Ok(value) => Some(value),
4440            Err(error) => {
4441                tracing::error!("{:?}", error);
4442                None
4443            }
4444        }
4445    }
4446}