rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    http_client: Arc<dyn HttpClient>,
 137    /// The GeoIP country code for the user.
 138    #[allow(unused)]
 139    geoip_country_code: Option<String>,
 140    system_id: Option<String>,
 141    _executor: Executor,
 142}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
 236pub struct Server {
 237    id: parking_lot::Mutex<ServerId>,
 238    peer: Arc<Peer>,
 239    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 240    app_state: Arc<AppState>,
 241    handlers: HashMap<TypeId, MessageHandler>,
 242    teardown: watch::Sender<bool>,
 243}
 244
 245pub(crate) struct ConnectionPoolGuard<'a> {
 246    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 247    _not_send: PhantomData<Rc<()>>,
 248}
 249
 250#[derive(Serialize)]
 251pub struct ServerSnapshot<'a> {
 252    peer: &'a Peer,
 253    #[serde(serialize_with = "serialize_deref")]
 254    connection_pool: ConnectionPoolGuard<'a>,
 255}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
 312            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 313            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 318            .add_request_handler(
 319                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 320            )
 321            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 322            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 323            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 324            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 325            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 326            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 327            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 328            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 329            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 330            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 332            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 333            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 334            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 335            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 336            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 337            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 338            .add_message_handler(create_buffer_for_peer)
 339            .add_request_handler(update_buffer)
 340            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 341            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 342            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 343            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 344            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 345            .add_request_handler(get_users)
 346            .add_request_handler(fuzzy_search_users)
 347            .add_request_handler(request_contact)
 348            .add_request_handler(remove_contact)
 349            .add_request_handler(respond_to_contact_request)
 350            .add_message_handler(subscribe_to_channels)
 351            .add_request_handler(create_channel)
 352            .add_request_handler(delete_channel)
 353            .add_request_handler(invite_channel_member)
 354            .add_request_handler(remove_channel_member)
 355            .add_request_handler(set_channel_member_role)
 356            .add_request_handler(set_channel_visibility)
 357            .add_request_handler(rename_channel)
 358            .add_request_handler(join_channel_buffer)
 359            .add_request_handler(leave_channel_buffer)
 360            .add_message_handler(update_channel_buffer)
 361            .add_request_handler(rejoin_channel_buffers)
 362            .add_request_handler(get_channel_members)
 363            .add_request_handler(respond_to_channel_invite)
 364            .add_request_handler(join_channel)
 365            .add_request_handler(join_channel_chat)
 366            .add_message_handler(leave_channel_chat)
 367            .add_request_handler(send_channel_message)
 368            .add_request_handler(remove_channel_message)
 369            .add_request_handler(update_channel_message)
 370            .add_request_handler(get_channel_messages)
 371            .add_request_handler(get_channel_messages_by_id)
 372            .add_request_handler(get_notifications)
 373            .add_request_handler(mark_notification_as_read)
 374            .add_request_handler(move_channel)
 375            .add_request_handler(follow)
 376            .add_message_handler(unfollow)
 377            .add_message_handler(update_followers)
 378            .add_request_handler(get_private_user_info)
 379            .add_request_handler(get_llm_api_token)
 380            .add_request_handler(accept_terms_of_service)
 381            .add_message_handler(acknowledge_channel_message)
 382            .add_message_handler(acknowledge_buffer_version)
 383            .add_request_handler(get_supermaven_api_key)
 384            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 385            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 386            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 387            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 388            .add_message_handler(update_context)
 389            .add_request_handler({
 390                let app_state = app_state.clone();
 391                move |request, response, session| {
 392                    let app_state = app_state.clone();
 393                    async move {
 394                        count_language_model_tokens(request, response, session, &app_state.config)
 395                            .await
 396                    }
 397                }
 398            })
 399            .add_request_handler(get_cached_embeddings)
 400            .add_request_handler({
 401                let app_state = app_state.clone();
 402                move |request, response, session| {
 403                    compute_embeddings(
 404                        request,
 405                        response,
 406                        session,
 407                        app_state.config.openai_api_key.clone(),
 408                    )
 409                }
 410            });
 411
 412        Arc::new(server)
 413    }
 414
 415    pub async fn start(&self) -> Result<()> {
 416        let server_id = *self.id.lock();
 417        let app_state = self.app_state.clone();
 418        let peer = self.peer.clone();
 419        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 420        let pool = self.connection_pool.clone();
 421        let live_kit_client = self.app_state.live_kit_client.clone();
 422
 423        let span = info_span!("start server");
 424        self.app_state.executor.spawn_detached(
 425            async move {
 426                tracing::info!("waiting for cleanup timeout");
 427                timeout.await;
 428                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 429                if let Some((room_ids, channel_ids)) = app_state
 430                    .db
 431                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 432                    .await
 433                    .trace_err()
 434                {
 435                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 436                    tracing::info!(
 437                        stale_channel_buffer_count = channel_ids.len(),
 438                        "retrieved stale channel buffers"
 439                    );
 440
 441                    for channel_id in channel_ids {
 442                        if let Some(refreshed_channel_buffer) = app_state
 443                            .db
 444                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 445                            .await
 446                            .trace_err()
 447                        {
 448                            for connection_id in refreshed_channel_buffer.connection_ids {
 449                                peer.send(
 450                                    connection_id,
 451                                    proto::UpdateChannelBufferCollaborators {
 452                                        channel_id: channel_id.to_proto(),
 453                                        collaborators: refreshed_channel_buffer
 454                                            .collaborators
 455                                            .clone(),
 456                                    },
 457                                )
 458                                .trace_err();
 459                            }
 460                        }
 461                    }
 462
 463                    for room_id in room_ids {
 464                        let mut contacts_to_update = HashSet::default();
 465                        let mut canceled_calls_to_user_ids = Vec::new();
 466                        let mut live_kit_room = String::new();
 467                        let mut delete_live_kit_room = false;
 468
 469                        if let Some(mut refreshed_room) = app_state
 470                            .db
 471                            .clear_stale_room_participants(room_id, server_id)
 472                            .await
 473                            .trace_err()
 474                        {
 475                            tracing::info!(
 476                                room_id = room_id.0,
 477                                new_participant_count = refreshed_room.room.participants.len(),
 478                                "refreshed room"
 479                            );
 480                            room_updated(&refreshed_room.room, &peer);
 481                            if let Some(channel) = refreshed_room.channel.as_ref() {
 482                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 483                            }
 484                            contacts_to_update
 485                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 486                            contacts_to_update
 487                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 488                            canceled_calls_to_user_ids =
 489                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 490                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 491                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 492                        }
 493
 494                        {
 495                            let pool = pool.lock();
 496                            for canceled_user_id in canceled_calls_to_user_ids {
 497                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 498                                    peer.send(
 499                                        connection_id,
 500                                        proto::CallCanceled {
 501                                            room_id: room_id.to_proto(),
 502                                        },
 503                                    )
 504                                    .trace_err();
 505                                }
 506                            }
 507                        }
 508
 509                        for user_id in contacts_to_update {
 510                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 511                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 512                            if let Some((busy, contacts)) = busy.zip(contacts) {
 513                                let pool = pool.lock();
 514                                let updated_contact = contact_for_user(user_id, busy, &pool);
 515                                for contact in contacts {
 516                                    if let db::Contact::Accepted {
 517                                        user_id: contact_user_id,
 518                                        ..
 519                                    } = contact
 520                                    {
 521                                        for contact_conn_id in
 522                                            pool.user_connection_ids(contact_user_id)
 523                                        {
 524                                            peer.send(
 525                                                contact_conn_id,
 526                                                proto::UpdateContacts {
 527                                                    contacts: vec![updated_contact.clone()],
 528                                                    remove_contacts: Default::default(),
 529                                                    incoming_requests: Default::default(),
 530                                                    remove_incoming_requests: Default::default(),
 531                                                    outgoing_requests: Default::default(),
 532                                                    remove_outgoing_requests: Default::default(),
 533                                                },
 534                                            )
 535                                            .trace_err();
 536                                        }
 537                                    }
 538                                }
 539                            }
 540                        }
 541
 542                        if let Some(live_kit) = live_kit_client.as_ref() {
 543                            if delete_live_kit_room {
 544                                live_kit.delete_room(live_kit_room).await.trace_err();
 545                            }
 546                        }
 547                    }
 548                }
 549
 550                app_state
 551                    .db
 552                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 553                    .await
 554                    .trace_err();
 555            }
 556            .instrument(span),
 557        );
 558        Ok(())
 559    }
 560
 561    pub fn teardown(&self) {
 562        self.peer.teardown();
 563        self.connection_pool.lock().reset();
 564        let _ = self.teardown.send(true);
 565    }
 566
 567    #[cfg(test)]
 568    pub fn reset(&self, id: ServerId) {
 569        self.teardown();
 570        *self.id.lock() = id;
 571        self.peer.reset(id.0 as u32);
 572        let _ = self.teardown.send(false);
 573    }
 574
 575    #[cfg(test)]
 576    pub fn id(&self) -> ServerId {
 577        *self.id.lock()
 578    }
 579
 580    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 581    where
 582        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 583        Fut: 'static + Send + Future<Output = Result<()>>,
 584        M: EnvelopedMessage,
 585    {
 586        let prev_handler = self.handlers.insert(
 587            TypeId::of::<M>(),
 588            Box::new(move |envelope, session| {
 589                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 590                let received_at = envelope.received_at;
 591                tracing::info!("message received");
 592                let start_time = Instant::now();
 593                let future = (handler)(*envelope, session);
 594                async move {
 595                    let result = future.await;
 596                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 597                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 598                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 599                    let payload_type = M::NAME;
 600
 601                    match result {
 602                        Err(error) => {
 603                            tracing::error!(
 604                                ?error,
 605                                total_duration_ms,
 606                                processing_duration_ms,
 607                                queue_duration_ms,
 608                                payload_type,
 609                                "error handling message"
 610                            )
 611                        }
 612                        Ok(()) => tracing::info!(
 613                            total_duration_ms,
 614                            processing_duration_ms,
 615                            queue_duration_ms,
 616                            "finished handling message"
 617                        ),
 618                    }
 619                }
 620                .boxed()
 621            }),
 622        );
 623        if prev_handler.is_some() {
 624            panic!("registered a handler for the same message twice");
 625        }
 626        self
 627    }
 628
 629    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 630    where
 631        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 632        Fut: 'static + Send + Future<Output = Result<()>>,
 633        M: EnvelopedMessage,
 634    {
 635        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 636        self
 637    }
 638
 639    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 640    where
 641        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 642        Fut: Send + Future<Output = Result<()>>,
 643        M: RequestMessage,
 644    {
 645        let handler = Arc::new(handler);
 646        self.add_handler(move |envelope, session| {
 647            let receipt = envelope.receipt();
 648            let handler = handler.clone();
 649            async move {
 650                let peer = session.peer.clone();
 651                let responded = Arc::new(AtomicBool::default());
 652                let response = Response {
 653                    peer: peer.clone(),
 654                    responded: responded.clone(),
 655                    receipt,
 656                };
 657                match (handler)(envelope.payload, response, session).await {
 658                    Ok(()) => {
 659                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 660                            Ok(())
 661                        } else {
 662                            Err(anyhow!("handler did not send a response"))?
 663                        }
 664                    }
 665                    Err(error) => {
 666                        let proto_err = match &error {
 667                            Error::Internal(err) => err.to_proto(),
 668                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 669                        };
 670                        peer.respond_with_error(receipt, proto_err)?;
 671                        Err(error)
 672                    }
 673                }
 674            }
 675        })
 676    }
 677
    /// Builds the future that drives one client connection from registration to sign-out.
    ///
    /// Synchronously, this creates the "handle connection" tracing span (recording the
    /// principal and, when present, the GeoIP country code) and subscribes to the server's
    /// teardown signal. The returned future then:
    ///
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the new connection id on the span;
    /// 3. builds an HTTP client, an optional Supermaven client and the [`Session`];
    /// 4. calls `send_initial_client_update`, returning early if it fails;
    /// 5. loops over I/O, incoming messages and teardown until the connection ends;
    /// 6. awaits `connection_lost` to sign the session out.
    ///
    /// `address` is only used as a tracing field. `send_connection_id` is forwarded to
    /// `send_initial_client_update`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared as `field::Empty` are filled in later via `record`/`update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    // NOTE(review): this return skips `connection_lost`; confirm the connection
                    // registered with the peer above is cleaned up elsewhere.
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only created when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                // NOTE(review): this return also skips `connection_lost`.
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256: a permit is acquired before
            // the next message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in the order written: teardown, I/O,
                // running foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown returns without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages are spawned on the executor; foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 827
 828    async fn send_initial_client_update(
 829        &self,
 830        connection_id: ConnectionId,
 831        principal: &Principal,
 832        zed_version: ZedVersion,
 833        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 834        session: &Session,
 835    ) -> Result<()> {
 836        self.peer.send(
 837            connection_id,
 838            proto::Hello {
 839                peer_id: Some(connection_id.into()),
 840            },
 841        )?;
 842        tracing::info!("sent hello message");
 843        if let Some(send_connection_id) = send_connection_id.take() {
 844            let _ = send_connection_id.send(connection_id);
 845        }
 846
 847        match principal {
 848            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 849                if !user.connected_once {
 850                    self.peer.send(connection_id, proto::ShowContacts {})?;
 851                    self.app_state
 852                        .db
 853                        .set_user_connected_once(user.id, true)
 854                        .await?;
 855                }
 856
 857                update_user_plan(user.id, session).await?;
 858
 859                let contacts = self.app_state.db.get_contacts(user.id).await?;
 860
 861                {
 862                    let mut pool = self.connection_pool.lock();
 863                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 864                    self.peer.send(
 865                        connection_id,
 866                        build_initial_contacts_update(contacts, &pool),
 867                    )?;
 868                }
 869
 870                if should_auto_subscribe_to_channels(zed_version) {
 871                    subscribe_user_to_channels(user.id, session).await?;
 872                }
 873
 874                if let Some(incoming_call) =
 875                    self.app_state.db.incoming_call_for_user(user.id).await?
 876                {
 877                    self.peer.send(connection_id, incoming_call)?;
 878                }
 879
 880                update_user_contacts(user.id, session).await?;
 881            }
 882        }
 883
 884        Ok(())
 885    }
 886
 887    pub async fn invite_code_redeemed(
 888        self: &Arc<Self>,
 889        inviter_id: UserId,
 890        invitee_id: UserId,
 891    ) -> Result<()> {
 892        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 893            if let Some(code) = &user.invite_code {
 894                let pool = self.connection_pool.lock();
 895                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 896                for connection_id in pool.user_connection_ids(inviter_id) {
 897                    self.peer.send(
 898                        connection_id,
 899                        proto::UpdateContacts {
 900                            contacts: vec![invitee_contact.clone()],
 901                            ..Default::default()
 902                        },
 903                    )?;
 904                    self.peer.send(
 905                        connection_id,
 906                        proto::UpdateInviteInfo {
 907                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 908                            count: user.invite_count as u32,
 909                        },
 910                    )?;
 911                }
 912            }
 913        }
 914        Ok(())
 915    }
 916
 917    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 918        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 919            if let Some(invite_code) = &user.invite_code {
 920                let pool = self.connection_pool.lock();
 921                for connection_id in pool.user_connection_ids(user_id) {
 922                    self.peer.send(
 923                        connection_id,
 924                        proto::UpdateInviteInfo {
 925                            url: format!(
 926                                "{}{}",
 927                                self.app_state.config.invite_link_prefix, invite_code
 928                            ),
 929                            count: user.invite_count as u32,
 930                        },
 931                    )?;
 932                }
 933            }
 934        }
 935        Ok(())
 936    }
 937
 938    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 939        let pool = self.connection_pool.lock();
 940        for connection_id in pool.user_connection_ids(user_id) {
 941            self.peer
 942                .send(connection_id, proto::RefreshLlmToken {})
 943                .trace_err();
 944        }
 945    }
 946
 947    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 948        ServerSnapshot {
 949            connection_pool: ConnectionPoolGuard {
 950                guard: self.connection_pool.lock(),
 951                _not_send: PhantomData,
 952            },
 953            peer: &self.peer,
 954        }
 955    }
 956}
 957
 958impl<'a> Deref for ConnectionPoolGuard<'a> {
 959    type Target = ConnectionPool;
 960
 961    fn deref(&self) -> &Self::Target {
 962        &self.guard
 963    }
 964}
 965
 966impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 967    fn deref_mut(&mut self) -> &mut Self::Target {
 968        &mut self.guard
 969    }
 970}
 971
/// Checks the pool's invariants whenever a guard is released, in test builds only.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // The attribute removes this call outside of test builds, where `drop` does nothing.
        #[cfg(test)]
        self.check_invariants();
    }
}
 978
 979fn broadcast<F>(
 980    sender_id: Option<ConnectionId>,
 981    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 982    mut f: F,
 983) where
 984    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 985{
 986    for receiver_id in receiver_ids {
 987        if Some(receiver_id) != sender_id {
 988            if let Err(error) = f(receiver_id) {
 989                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 990            }
 991        }
 992    }
 993}
 994
 995pub struct ProtocolVersion(u32);
 996
 997impl Header for ProtocolVersion {
 998    fn name() -> &'static HeaderName {
 999        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1000        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1001    }
1002
1003    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1004    where
1005        Self: Sized,
1006        I: Iterator<Item = &'i axum::http::HeaderValue>,
1007    {
1008        let version = values
1009            .next()
1010            .ok_or_else(axum::headers::Error::invalid)?
1011            .to_str()
1012            .map_err(|_| axum::headers::Error::invalid())?
1013            .parse()
1014            .map_err(|_| axum::headers::Error::invalid())?;
1015        Ok(Self(version))
1016    }
1017
1018    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1019        values.extend([self.0.to_string().parse().unwrap()]);
1020    }
1021}
1022
1023pub struct AppVersionHeader(SemanticVersion);
1024impl Header for AppVersionHeader {
1025    fn name() -> &'static HeaderName {
1026        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1027        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1028    }
1029
1030    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1031    where
1032        Self: Sized,
1033        I: Iterator<Item = &'i axum::http::HeaderValue>,
1034    {
1035        let version = values
1036            .next()
1037            .ok_or_else(axum::headers::Error::invalid)?
1038            .to_str()
1039            .map_err(|_| axum::headers::Error::invalid())?
1040            .parse()
1041            .map_err(|_| axum::headers::Error::invalid())?;
1042        Ok(Self(version))
1043    }
1044
1045    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1046        values.extend([self.0.to_string().parse().unwrap()]);
1047    }
1048}
1049
1050pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1051    Router::new()
1052        .route("/rpc", get(handle_websocket_request))
1053        .layer(
1054            ServiceBuilder::new()
1055                .layer(Extension(server.app_state.clone()))
1056                .layer(middleware::from_fn(auth::validate_header)),
1057        )
1058        .route("/metrics", get(handle_metrics))
1059        .layer(Extension(server))
1060}
1061
1062#[allow(clippy::too_many_arguments)]
1063pub async fn handle_websocket_request(
1064    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1065    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1066    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1067    Extension(server): Extension<Arc<Server>>,
1068    Extension(principal): Extension<Principal>,
1069    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1070    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1071    ws: WebSocketUpgrade,
1072) -> axum::response::Response {
1073    if protocol_version != rpc::PROTOCOL_VERSION {
1074        return (
1075            StatusCode::UPGRADE_REQUIRED,
1076            "client must be upgraded".to_string(),
1077        )
1078            .into_response();
1079    }
1080
1081    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1082        return (
1083            StatusCode::UPGRADE_REQUIRED,
1084            "no version header found".to_string(),
1085        )
1086            .into_response();
1087    };
1088
1089    if !version.can_collaborate() {
1090        return (
1091            StatusCode::UPGRADE_REQUIRED,
1092            "client must be upgraded".to_string(),
1093        )
1094            .into_response();
1095    }
1096
1097    let socket_address = socket_address.to_string();
1098    ws.on_upgrade(move |socket| {
1099        let socket = socket
1100            .map_ok(to_tungstenite_message)
1101            .err_into()
1102            .with(|message| async move { to_axum_message(message) });
1103        let connection = Connection::new(Box::pin(socket));
1104        async move {
1105            server
1106                .handle_connection(
1107                    connection,
1108                    socket_address,
1109                    principal,
1110                    version,
1111                    country_code_header.map(|header| header.to_string()),
1112                    system_id_header.map(|header| header.to_string()),
1113                    None,
1114                    Executor::Production,
1115                )
1116                .await;
1117        }
1118    })
1119}
1120
1121pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1122    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1123    let connections_metric = CONNECTIONS_METRIC
1124        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1125
1126    let connections = server
1127        .connection_pool
1128        .lock()
1129        .connections()
1130        .filter(|connection| !connection.admin)
1131        .count();
1132    connections_metric.set(connections as _);
1133
1134    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1135    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1136        register_int_gauge!(
1137            "shared_projects",
1138            "number of open projects with one or more guests"
1139        )
1140        .unwrap()
1141    });
1142
1143    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1144    shared_projects_metric.set(shared_projects as _);
1145
1146    let encoder = prometheus::TextEncoder::new();
1147    let metric_families = prometheus::gather();
1148    let encoded_metrics = encoder
1149        .encode_to_string(&metric_families)
1150        .map_err(|err| anyhow!("{}", err))?;
1151    Ok(encoded_metrics)
1152}
1153
/// Signs a session out after its connection has ended.
///
/// Immediately disconnects the peer, removes the connection from the pool (an error here
/// is returned) and reports the lost connection to the database (an error here is only
/// traced). It then waits for whichever comes first:
///
/// * `RECONNECT_TIMEOUT` elapsing: the session leaves its room and channel buffers, a
///   pending call is declined if the user has no connection left online, and the user's
///   contacts are updated;
/// * the teardown signal changing: nothing further is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout arm before the teardown arm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Failures in the two cleanup steps below are traced, not propagated.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Decline a pending call only when the user has no other connection online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1200
/// Acknowledges a ping from a client, used to keep the connection alive.
///
/// The request payload and the session are ignored; the only effect is replying with
/// `proto::Ack`. An error from sending that reply is returned to the caller.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
    response.send(proto::Ack {})?;
    Ok(())
}
1206
1207/// Creates a new room for calling (outside of channels)
1208async fn create_room(
1209    _request: proto::CreateRoom,
1210    response: Response<proto::CreateRoom>,
1211    session: Session,
1212) -> Result<()> {
1213    let live_kit_room = nanoid::nanoid!(30);
1214
1215    let live_kit_connection_info = util::maybe!(async {
1216        let live_kit = session.app_state.live_kit_client.as_ref();
1217        let live_kit = live_kit?;
1218        let user_id = session.user_id().to_string();
1219
1220        let token = live_kit
1221            .room_token(&live_kit_room, &user_id.to_string())
1222            .trace_err()?;
1223
1224        Some(proto::LiveKitConnectionInfo {
1225            server_url: live_kit.url().into(),
1226            token,
1227            can_publish: true,
1228        })
1229    })
1230    .await;
1231
1232    let room = session
1233        .db()
1234        .await
1235        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1236        .await?;
1237
1238    response.send(proto::CreateRoomResponse {
1239        room: Some(room.clone()),
1240        live_kit_connection_info,
1241    })?;
1242
1243    update_user_contacts(session.user_id(), &session).await?;
1244    Ok(())
1245}
1246
1247/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1248async fn join_room(
1249    request: proto::JoinRoom,
1250    response: Response<proto::JoinRoom>,
1251    session: Session,
1252) -> Result<()> {
1253    let room_id = RoomId::from_proto(request.id);
1254
1255    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1256
1257    if let Some(channel_id) = channel_id {
1258        return join_channel_internal(channel_id, Box::new(response), session).await;
1259    }
1260
1261    let joined_room = {
1262        let room = session
1263            .db()
1264            .await
1265            .join_room(room_id, session.user_id(), session.connection_id)
1266            .await?;
1267        room_updated(&room.room, &session.peer);
1268        room.into_inner()
1269    };
1270
1271    for connection_id in session
1272        .connection_pool()
1273        .await
1274        .user_connection_ids(session.user_id())
1275    {
1276        session
1277            .peer
1278            .send(
1279                connection_id,
1280                proto::CallCanceled {
1281                    room_id: room_id.to_proto(),
1282                },
1283            )
1284            .trace_err();
1285    }
1286
1287    let live_kit_connection_info =
1288        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1289            live_kit
1290                .room_token(
1291                    &joined_room.room.live_kit_room,
1292                    &session.user_id().to_string(),
1293                )
1294                .trace_err()
1295                .map(|token| proto::LiveKitConnectionInfo {
1296                    server_url: live_kit.url().into(),
1297                    token,
1298                    can_publish: true,
1299                })
1300        } else {
1301            None
1302        };
1303
1304    response.send(proto::JoinRoomResponse {
1305        room: Some(joined_room.room),
1306        channel_id: None,
1307        live_kit_connection_info,
1308    })?;
1309
1310    update_user_contacts(session.user_id(), &session).await?;
1311    Ok(())
1312}
1313
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call returns the room together with the projects this connection
/// reshared and rejoined. The response is sent first; afterwards the room update is
/// announced, collaborators of reshared projects learn the new peer id and receive the
/// project's worktrees, and rejoined projects are replayed via
/// `notify_rejoined_projects`. Finally the channel (if any) and the user's contacts are
/// updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below so that `rejoined_room` itself does not outlive it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this project's old peer id is now this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1405
1406fn notify_rejoined_projects(
1407    rejoined_projects: &mut Vec<RejoinedProject>,
1408    session: &Session,
1409) -> Result<()> {
1410    for project in rejoined_projects.iter() {
1411        for collaborator in &project.collaborators {
1412            session
1413                .peer
1414                .send(
1415                    collaborator.connection_id,
1416                    proto::UpdateProjectCollaborator {
1417                        project_id: project.id.to_proto(),
1418                        old_peer_id: Some(project.old_connection_id.into()),
1419                        new_peer_id: Some(session.connection_id.into()),
1420                    },
1421                )
1422                .trace_err();
1423        }
1424    }
1425
1426    for project in rejoined_projects {
1427        for worktree in mem::take(&mut project.worktrees) {
1428            // Stream this worktree's entries.
1429            let message = proto::UpdateWorktree {
1430                project_id: project.id.to_proto(),
1431                worktree_id: worktree.id,
1432                abs_path: worktree.abs_path.clone(),
1433                root_name: worktree.root_name,
1434                updated_entries: worktree.updated_entries,
1435                removed_entries: worktree.removed_entries,
1436                scan_id: worktree.scan_id,
1437                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1438                updated_repositories: worktree.updated_repositories,
1439                removed_repositories: worktree.removed_repositories,
1440            };
1441            for update in proto::split_worktree_update(message) {
1442                session.peer.send(session.connection_id, update.clone())?;
1443            }
1444
1445            // Stream this worktree's diagnostics.
1446            for summary in worktree.diagnostic_summaries {
1447                session.peer.send(
1448                    session.connection_id,
1449                    proto::UpdateDiagnosticSummary {
1450                        project_id: project.id.to_proto(),
1451                        worktree_id: worktree.id,
1452                        summary: Some(summary),
1453                    },
1454                )?;
1455            }
1456
1457            for settings_file in worktree.settings_files {
1458                session.peer.send(
1459                    session.connection_id,
1460                    proto::UpdateWorktreeSettings {
1461                        project_id: project.id.to_proto(),
1462                        worktree_id: worktree.id,
1463                        path: settings_file.path,
1464                        content: Some(settings_file.content),
1465                        kind: Some(settings_file.kind.to_proto().into()),
1466                    },
1467                )?;
1468            }
1469        }
1470
1471        for language_server in &project.language_servers {
1472            session.peer.send(
1473                session.connection_id,
1474                proto::UpdateLanguageServer {
1475                    project_id: project.id.to_proto(),
1476                    language_server_id: language_server.id,
1477                    variant: Some(
1478                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1479                            proto::LspDiskBasedDiagnosticsUpdated {},
1480                        ),
1481                    ),
1482                },
1483            )?;
1484        }
1485    }
1486    Ok(())
1487}
1488
1489/// leave room disconnects from the room.
1490async fn leave_room(
1491    _: proto::LeaveRoom,
1492    response: Response<proto::LeaveRoom>,
1493    session: Session,
1494) -> Result<()> {
1495    leave_room_for_session(&session, session.connection_id).await?;
1496    response.send(proto::Ack {})?;
1497    Ok(())
1498}
1499
1500/// Updates the permissions of someone else in the room.
1501async fn set_room_participant_role(
1502    request: proto::SetRoomParticipantRole,
1503    response: Response<proto::SetRoomParticipantRole>,
1504    session: Session,
1505) -> Result<()> {
1506    let user_id = UserId::from_proto(request.user_id);
1507    let role = ChannelRole::from(request.role());
1508
1509    let (live_kit_room, can_publish) = {
1510        let room = session
1511            .db()
1512            .await
1513            .set_room_participant_role(
1514                session.user_id(),
1515                RoomId::from_proto(request.room_id),
1516                user_id,
1517                role,
1518            )
1519            .await?;
1520
1521        let live_kit_room = room.live_kit_room.clone();
1522        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1523        room_updated(&room, &session.peer);
1524        (live_kit_room, can_publish)
1525    };
1526
1527    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1528        live_kit
1529            .update_participant(
1530                live_kit_room.clone(),
1531                request.user_id.to_string(),
1532                live_kit_server::proto::ParticipantPermission {
1533                    can_subscribe: true,
1534                    can_publish,
1535                    can_publish_data: can_publish,
1536                    hidden: false,
1537                    recorder: false,
1538                },
1539            )
1540            .await
1541            .trace_err();
1542    }
1543
1544    response.send(proto::Ack {})?;
1545    Ok(())
1546}
1547
1548/// Call someone else into the current room
1549async fn call(
1550    request: proto::Call,
1551    response: Response<proto::Call>,
1552    session: Session,
1553) -> Result<()> {
1554    let room_id = RoomId::from_proto(request.room_id);
1555    let calling_user_id = session.user_id();
1556    let calling_connection_id = session.connection_id;
1557    let called_user_id = UserId::from_proto(request.called_user_id);
1558    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1559    if !session
1560        .db()
1561        .await
1562        .has_contact(calling_user_id, called_user_id)
1563        .await?
1564    {
1565        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1566    }
1567
1568    let incoming_call = {
1569        let (room, incoming_call) = &mut *session
1570            .db()
1571            .await
1572            .call(
1573                room_id,
1574                calling_user_id,
1575                calling_connection_id,
1576                called_user_id,
1577                initial_project_id,
1578            )
1579            .await?;
1580        room_updated(room, &session.peer);
1581        mem::take(incoming_call)
1582    };
1583    update_user_contacts(called_user_id, &session).await?;
1584
1585    let mut calls = session
1586        .connection_pool()
1587        .await
1588        .user_connection_ids(called_user_id)
1589        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1590        .collect::<FuturesUnordered<_>>();
1591
1592    while let Some(call_response) = calls.next().await {
1593        match call_response.as_ref() {
1594            Ok(_) => {
1595                response.send(proto::Ack {})?;
1596                return Ok(());
1597            }
1598            Err(_) => {
1599                call_response.trace_err();
1600            }
1601        }
1602    }
1603
1604    {
1605        let room = session
1606            .db()
1607            .await
1608            .call_failed(room_id, called_user_id)
1609            .await?;
1610        room_updated(&room, &session.peer);
1611    }
1612    update_user_contacts(called_user_id, &session).await?;
1613
1614    Err(anyhow!("failed to ring user"))?
1615}
1616
1617/// Cancel an outgoing call.
1618async fn cancel_call(
1619    request: proto::CancelCall,
1620    response: Response<proto::CancelCall>,
1621    session: Session,
1622) -> Result<()> {
1623    let called_user_id = UserId::from_proto(request.called_user_id);
1624    let room_id = RoomId::from_proto(request.room_id);
1625    {
1626        let room = session
1627            .db()
1628            .await
1629            .cancel_call(room_id, session.connection_id, called_user_id)
1630            .await?;
1631        room_updated(&room, &session.peer);
1632    }
1633
1634    for connection_id in session
1635        .connection_pool()
1636        .await
1637        .user_connection_ids(called_user_id)
1638    {
1639        session
1640            .peer
1641            .send(
1642                connection_id,
1643                proto::CallCanceled {
1644                    room_id: room_id.to_proto(),
1645                },
1646            )
1647            .trace_err();
1648    }
1649    response.send(proto::Ack {})?;
1650
1651    update_user_contacts(called_user_id, &session).await?;
1652    Ok(())
1653}
1654
1655/// Decline an incoming call.
1656async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1657    let room_id = RoomId::from_proto(message.room_id);
1658    {
1659        let room = session
1660            .db()
1661            .await
1662            .decline_call(Some(room_id), session.user_id())
1663            .await?
1664            .ok_or_else(|| anyhow!("failed to decline call"))?;
1665        room_updated(&room, &session.peer);
1666    }
1667
1668    for connection_id in session
1669        .connection_pool()
1670        .await
1671        .user_connection_ids(session.user_id())
1672    {
1673        session
1674            .peer
1675            .send(
1676                connection_id,
1677                proto::CallCanceled {
1678                    room_id: room_id.to_proto(),
1679                },
1680            )
1681            .trace_err();
1682    }
1683    update_user_contacts(session.user_id(), &session).await?;
1684    Ok(())
1685}
1686
1687/// Updates other participants in the room with your current location.
1688async fn update_participant_location(
1689    request: proto::UpdateParticipantLocation,
1690    response: Response<proto::UpdateParticipantLocation>,
1691    session: Session,
1692) -> Result<()> {
1693    let room_id = RoomId::from_proto(request.room_id);
1694    let location = request
1695        .location
1696        .ok_or_else(|| anyhow!("invalid location"))?;
1697
1698    let db = session.db().await;
1699    let room = db
1700        .update_room_participant_location(room_id, session.connection_id, location)
1701        .await?;
1702
1703    room_updated(&room, &session.peer);
1704    response.send(proto::Ack {})?;
1705    Ok(())
1706}
1707
1708/// Share a project into the room.
1709async fn share_project(
1710    request: proto::ShareProject,
1711    response: Response<proto::ShareProject>,
1712    session: Session,
1713) -> Result<()> {
1714    let (project_id, room) = &*session
1715        .db()
1716        .await
1717        .share_project(
1718            RoomId::from_proto(request.room_id),
1719            session.connection_id,
1720            &request.worktrees,
1721            request.is_ssh_project,
1722        )
1723        .await?;
1724    response.send(proto::ShareProjectResponse {
1725        project_id: project_id.to_proto(),
1726    })?;
1727    room_updated(room, &session.peer);
1728
1729    Ok(())
1730}
1731
1732/// Unshare a project from the room.
1733async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1734    let project_id = ProjectId::from_proto(message.project_id);
1735    unshare_project_internal(project_id, session.connection_id, &session).await
1736}
1737
1738async fn unshare_project_internal(
1739    project_id: ProjectId,
1740    connection_id: ConnectionId,
1741    session: &Session,
1742) -> Result<()> {
1743    let delete = {
1744        let room_guard = session
1745            .db()
1746            .await
1747            .unshare_project(project_id, connection_id)
1748            .await?;
1749
1750        let (delete, room, guest_connection_ids) = &*room_guard;
1751
1752        let message = proto::UnshareProject {
1753            project_id: project_id.to_proto(),
1754        };
1755
1756        broadcast(
1757            Some(connection_id),
1758            guest_connection_ids.iter().copied(),
1759            |conn_id| session.peer.send(conn_id, message.clone()),
1760        );
1761        if let Some(room) = room {
1762            room_updated(room, &session.peer);
1763        }
1764
1765        *delete
1766    };
1767
1768    if delete {
1769        let db = session.db().await;
1770        db.delete_project(project_id).await?;
1771    }
1772
1773    Ok(())
1774}
1775
1776/// Join someone elses shared project.
1777async fn join_project(
1778    request: proto::JoinProject,
1779    response: Response<proto::JoinProject>,
1780    session: Session,
1781) -> Result<()> {
1782    let project_id = ProjectId::from_proto(request.project_id);
1783
1784    tracing::info!(%project_id, "join project");
1785
1786    let db = session.db().await;
1787    let (project, replica_id) = &mut *db
1788        .join_project(project_id, session.connection_id, session.user_id())
1789        .await?;
1790    drop(db);
1791    tracing::info!(%project_id, "join remote project");
1792    join_project_internal(response, session, project, replica_id)
1793}
1794
/// Abstraction over how `join_project_internal` delivers its
/// `JoinProjectResponse`, so that function can accept any responder type.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result` back to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to `Response`'s `send`.
        // NOTE(review): presumably this resolves to an inherent method on
        // `Response`; confirm it cannot resolve back to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1803
1804fn join_project_internal(
1805    response: impl JoinProjectInternalResponse,
1806    session: Session,
1807    project: &mut Project,
1808    replica_id: &ReplicaId,
1809) -> Result<()> {
1810    let collaborators = project
1811        .collaborators
1812        .iter()
1813        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1814        .map(|collaborator| collaborator.to_proto())
1815        .collect::<Vec<_>>();
1816    let project_id = project.id;
1817    let guest_user_id = session.user_id();
1818
1819    let worktrees = project
1820        .worktrees
1821        .iter()
1822        .map(|(id, worktree)| proto::WorktreeMetadata {
1823            id: *id,
1824            root_name: worktree.root_name.clone(),
1825            visible: worktree.visible,
1826            abs_path: worktree.abs_path.clone(),
1827        })
1828        .collect::<Vec<_>>();
1829
1830    let add_project_collaborator = proto::AddProjectCollaborator {
1831        project_id: project_id.to_proto(),
1832        collaborator: Some(proto::Collaborator {
1833            peer_id: Some(session.connection_id.into()),
1834            replica_id: replica_id.0 as u32,
1835            user_id: guest_user_id.to_proto(),
1836            is_host: false,
1837        }),
1838    };
1839
1840    for collaborator in &collaborators {
1841        session
1842            .peer
1843            .send(
1844                collaborator.peer_id.unwrap().into(),
1845                add_project_collaborator.clone(),
1846            )
1847            .trace_err();
1848    }
1849
1850    // First, we send the metadata associated with each worktree.
1851    response.send(proto::JoinProjectResponse {
1852        project_id: project.id.0 as u64,
1853        worktrees: worktrees.clone(),
1854        replica_id: replica_id.0 as u32,
1855        collaborators: collaborators.clone(),
1856        language_servers: project.language_servers.clone(),
1857        role: project.role.into(),
1858    })?;
1859
1860    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1861        // Stream this worktree's entries.
1862        let message = proto::UpdateWorktree {
1863            project_id: project_id.to_proto(),
1864            worktree_id,
1865            abs_path: worktree.abs_path.clone(),
1866            root_name: worktree.root_name,
1867            updated_entries: worktree.entries,
1868            removed_entries: Default::default(),
1869            scan_id: worktree.scan_id,
1870            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1871            updated_repositories: worktree.repository_entries.into_values().collect(),
1872            removed_repositories: Default::default(),
1873        };
1874        for update in proto::split_worktree_update(message) {
1875            session.peer.send(session.connection_id, update.clone())?;
1876        }
1877
1878        // Stream this worktree's diagnostics.
1879        for summary in worktree.diagnostic_summaries {
1880            session.peer.send(
1881                session.connection_id,
1882                proto::UpdateDiagnosticSummary {
1883                    project_id: project_id.to_proto(),
1884                    worktree_id: worktree.id,
1885                    summary: Some(summary),
1886                },
1887            )?;
1888        }
1889
1890        for settings_file in worktree.settings_files {
1891            session.peer.send(
1892                session.connection_id,
1893                proto::UpdateWorktreeSettings {
1894                    project_id: project_id.to_proto(),
1895                    worktree_id: worktree.id,
1896                    path: settings_file.path,
1897                    content: Some(settings_file.content),
1898                    kind: Some(settings_file.kind.to_proto() as i32),
1899                },
1900            )?;
1901        }
1902    }
1903
1904    for language_server in &project.language_servers {
1905        session.peer.send(
1906            session.connection_id,
1907            proto::UpdateLanguageServer {
1908                project_id: project_id.to_proto(),
1909                language_server_id: language_server.id,
1910                variant: Some(
1911                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1912                        proto::LspDiskBasedDiagnosticsUpdated {},
1913                    ),
1914                ),
1915            },
1916        )?;
1917    }
1918
1919    Ok(())
1920}
1921
1922/// Leave someone elses shared project.
1923async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1924    let sender_id = session.connection_id;
1925    let project_id = ProjectId::from_proto(request.project_id);
1926    let db = session.db().await;
1927
1928    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1929    tracing::info!(
1930        %project_id,
1931        "leave project"
1932    );
1933
1934    project_left(project, &session);
1935    if let Some(room) = room {
1936        room_updated(room, &session.peer);
1937    }
1938
1939    Ok(())
1940}
1941
1942/// Updates other participants with changes to the project
1943async fn update_project(
1944    request: proto::UpdateProject,
1945    response: Response<proto::UpdateProject>,
1946    session: Session,
1947) -> Result<()> {
1948    let project_id = ProjectId::from_proto(request.project_id);
1949    let (room, guest_connection_ids) = &*session
1950        .db()
1951        .await
1952        .update_project(project_id, session.connection_id, &request.worktrees)
1953        .await?;
1954    broadcast(
1955        Some(session.connection_id),
1956        guest_connection_ids.iter().copied(),
1957        |connection_id| {
1958            session
1959                .peer
1960                .forward_send(session.connection_id, connection_id, request.clone())
1961        },
1962    );
1963    if let Some(room) = room {
1964        room_updated(room, &session.peer);
1965    }
1966    response.send(proto::Ack {})?;
1967
1968    Ok(())
1969}
1970
1971/// Updates other participants with changes to the worktree
1972async fn update_worktree(
1973    request: proto::UpdateWorktree,
1974    response: Response<proto::UpdateWorktree>,
1975    session: Session,
1976) -> Result<()> {
1977    let guest_connection_ids = session
1978        .db()
1979        .await
1980        .update_worktree(&request, session.connection_id)
1981        .await?;
1982
1983    broadcast(
1984        Some(session.connection_id),
1985        guest_connection_ids.iter().copied(),
1986        |connection_id| {
1987            session
1988                .peer
1989                .forward_send(session.connection_id, connection_id, request.clone())
1990        },
1991    );
1992    response.send(proto::Ack {})?;
1993    Ok(())
1994}
1995
1996/// Updates other participants with changes to the diagnostics
1997async fn update_diagnostic_summary(
1998    message: proto::UpdateDiagnosticSummary,
1999    session: Session,
2000) -> Result<()> {
2001    let guest_connection_ids = session
2002        .db()
2003        .await
2004        .update_diagnostic_summary(&message, session.connection_id)
2005        .await?;
2006
2007    broadcast(
2008        Some(session.connection_id),
2009        guest_connection_ids.iter().copied(),
2010        |connection_id| {
2011            session
2012                .peer
2013                .forward_send(session.connection_id, connection_id, message.clone())
2014        },
2015    );
2016
2017    Ok(())
2018}
2019
2020/// Updates other participants with changes to the worktree settings
2021async fn update_worktree_settings(
2022    message: proto::UpdateWorktreeSettings,
2023    session: Session,
2024) -> Result<()> {
2025    let guest_connection_ids = session
2026        .db()
2027        .await
2028        .update_worktree_settings(&message, session.connection_id)
2029        .await?;
2030
2031    broadcast(
2032        Some(session.connection_id),
2033        guest_connection_ids.iter().copied(),
2034        |connection_id| {
2035            session
2036                .peer
2037                .forward_send(session.connection_id, connection_id, message.clone())
2038        },
2039    );
2040
2041    Ok(())
2042}
2043
2044/// Notify other participants that a  language server has started.
2045async fn start_language_server(
2046    request: proto::StartLanguageServer,
2047    session: Session,
2048) -> Result<()> {
2049    let guest_connection_ids = session
2050        .db()
2051        .await
2052        .start_language_server(&request, session.connection_id)
2053        .await?;
2054
2055    broadcast(
2056        Some(session.connection_id),
2057        guest_connection_ids.iter().copied(),
2058        |connection_id| {
2059            session
2060                .peer
2061                .forward_send(session.connection_id, connection_id, request.clone())
2062        },
2063    );
2064    Ok(())
2065}
2066
2067/// Notify other participants that a language server has changed.
2068async fn update_language_server(
2069    request: proto::UpdateLanguageServer,
2070    session: Session,
2071) -> Result<()> {
2072    let project_id = ProjectId::from_proto(request.project_id);
2073    let project_connection_ids = session
2074        .db()
2075        .await
2076        .project_connection_ids(project_id, session.connection_id, true)
2077        .await?;
2078    broadcast(
2079        Some(session.connection_id),
2080        project_connection_ids.iter().copied(),
2081        |connection_id| {
2082            session
2083                .peer
2084                .forward_send(session.connection_id, connection_id, request.clone())
2085        },
2086    );
2087    Ok(())
2088}
2089
2090/// forward a project request to the host. These requests should be read only
2091/// as guests are allowed to send them.
2092async fn forward_read_only_project_request<T>(
2093    request: T,
2094    response: Response<T>,
2095    session: Session,
2096) -> Result<()>
2097where
2098    T: EntityMessage + RequestMessage,
2099{
2100    let project_id = ProjectId::from_proto(request.remote_entity_id());
2101    let host_connection_id = session
2102        .db()
2103        .await
2104        .host_for_read_only_project_request(project_id, session.connection_id)
2105        .await?;
2106    let payload = session
2107        .peer
2108        .forward_request(session.connection_id, host_connection_id, request)
2109        .await?;
2110    response.send(payload)?;
2111    Ok(())
2112}
2113
2114async fn forward_find_search_candidates_request(
2115    request: proto::FindSearchCandidates,
2116    response: Response<proto::FindSearchCandidates>,
2117    session: Session,
2118) -> Result<()> {
2119    let project_id = ProjectId::from_proto(request.remote_entity_id());
2120    let host_connection_id = session
2121        .db()
2122        .await
2123        .host_for_read_only_project_request(project_id, session.connection_id)
2124        .await?;
2125    let payload = session
2126        .peer
2127        .forward_request(session.connection_id, host_connection_id, request)
2128        .await?;
2129    response.send(payload)?;
2130    Ok(())
2131}
2132
2133/// forward a project request to the host. These requests are disallowed
2134/// for guests.
2135async fn forward_mutating_project_request<T>(
2136    request: T,
2137    response: Response<T>,
2138    session: Session,
2139) -> Result<()>
2140where
2141    T: EntityMessage + RequestMessage,
2142{
2143    let project_id = ProjectId::from_proto(request.remote_entity_id());
2144
2145    let host_connection_id = session
2146        .db()
2147        .await
2148        .host_for_mutating_project_request(project_id, session.connection_id)
2149        .await?;
2150    let payload = session
2151        .peer
2152        .forward_request(session.connection_id, host_connection_id, request)
2153        .await?;
2154    response.send(payload)?;
2155    Ok(())
2156}
2157
2158/// Notify other participants that a new buffer has been created
2159async fn create_buffer_for_peer(
2160    request: proto::CreateBufferForPeer,
2161    session: Session,
2162) -> Result<()> {
2163    session
2164        .db()
2165        .await
2166        .check_user_is_project_host(
2167            ProjectId::from_proto(request.project_id),
2168            session.connection_id,
2169        )
2170        .await?;
2171    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2172    session
2173        .peer
2174        .forward_send(session.connection_id, peer_id.into(), request)?;
2175    Ok(())
2176}
2177
2178/// Notify other participants that a buffer has been updated. This is
2179/// allowed for guests as long as the update is limited to selections.
2180async fn update_buffer(
2181    request: proto::UpdateBuffer,
2182    response: Response<proto::UpdateBuffer>,
2183    session: Session,
2184) -> Result<()> {
2185    let project_id = ProjectId::from_proto(request.project_id);
2186    let mut capability = Capability::ReadOnly;
2187
2188    for op in request.operations.iter() {
2189        match op.variant {
2190            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2191            Some(_) => capability = Capability::ReadWrite,
2192        }
2193    }
2194
2195    let host = {
2196        let guard = session
2197            .db()
2198            .await
2199            .connections_for_buffer_update(project_id, session.connection_id, capability)
2200            .await?;
2201
2202        let (host, guests) = &*guard;
2203
2204        broadcast(
2205            Some(session.connection_id),
2206            guests.clone(),
2207            |connection_id| {
2208                session
2209                    .peer
2210                    .forward_send(session.connection_id, connection_id, request.clone())
2211            },
2212        );
2213
2214        *host
2215    };
2216
2217    if host != session.connection_id {
2218        session
2219            .peer
2220            .forward_request(session.connection_id, host, request.clone())
2221            .await?;
2222    }
2223
2224    response.send(proto::Ack {})?;
2225    Ok(())
2226}
2227
2228async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2229    let project_id = ProjectId::from_proto(message.project_id);
2230
2231    let operation = message.operation.as_ref().context("invalid operation")?;
2232    let capability = match operation.variant.as_ref() {
2233        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2234            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2235                match buffer_op.variant {
2236                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2237                        Capability::ReadOnly
2238                    }
2239                    _ => Capability::ReadWrite,
2240                }
2241            } else {
2242                Capability::ReadWrite
2243            }
2244        }
2245        Some(_) => Capability::ReadWrite,
2246        None => Capability::ReadOnly,
2247    };
2248
2249    let guard = session
2250        .db()
2251        .await
2252        .connections_for_buffer_update(project_id, session.connection_id, capability)
2253        .await?;
2254
2255    let (host, guests) = &*guard;
2256
2257    broadcast(
2258        Some(session.connection_id),
2259        guests.iter().chain([host]).copied(),
2260        |connection_id| {
2261            session
2262                .peer
2263                .forward_send(session.connection_id, connection_id, message.clone())
2264        },
2265    );
2266
2267    Ok(())
2268}
2269
2270/// Notify other participants that a project has been updated.
2271async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2272    request: T,
2273    session: Session,
2274) -> Result<()> {
2275    let project_id = ProjectId::from_proto(request.remote_entity_id());
2276    let project_connection_ids = session
2277        .db()
2278        .await
2279        .project_connection_ids(project_id, session.connection_id, false)
2280        .await?;
2281
2282    broadcast(
2283        Some(session.connection_id),
2284        project_connection_ids.iter().copied(),
2285        |connection_id| {
2286            session
2287                .peer
2288                .forward_send(session.connection_id, connection_id, request.clone())
2289        },
2290    );
2291    Ok(())
2292}
2293
2294/// Start following another user in a call.
2295async fn follow(
2296    request: proto::Follow,
2297    response: Response<proto::Follow>,
2298    session: Session,
2299) -> Result<()> {
2300    let room_id = RoomId::from_proto(request.room_id);
2301    let project_id = request.project_id.map(ProjectId::from_proto);
2302    let leader_id = request
2303        .leader_id
2304        .ok_or_else(|| anyhow!("invalid leader id"))?
2305        .into();
2306    let follower_id = session.connection_id;
2307
2308    session
2309        .db()
2310        .await
2311        .check_room_participants(room_id, leader_id, session.connection_id)
2312        .await?;
2313
2314    let response_payload = session
2315        .peer
2316        .forward_request(session.connection_id, leader_id, request)
2317        .await?;
2318    response.send(response_payload)?;
2319
2320    if let Some(project_id) = project_id {
2321        let room = session
2322            .db()
2323            .await
2324            .follow(room_id, project_id, leader_id, follower_id)
2325            .await?;
2326        room_updated(&room, &session.peer);
2327    }
2328
2329    Ok(())
2330}
2331
2332/// Stop following another user in a call.
2333async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2334    let room_id = RoomId::from_proto(request.room_id);
2335    let project_id = request.project_id.map(ProjectId::from_proto);
2336    let leader_id = request
2337        .leader_id
2338        .ok_or_else(|| anyhow!("invalid leader id"))?
2339        .into();
2340    let follower_id = session.connection_id;
2341
2342    session
2343        .db()
2344        .await
2345        .check_room_participants(room_id, leader_id, session.connection_id)
2346        .await?;
2347
2348    session
2349        .peer
2350        .forward_send(session.connection_id, leader_id, request)?;
2351
2352    if let Some(project_id) = project_id {
2353        let room = session
2354            .db()
2355            .await
2356            .unfollow(room_id, project_id, leader_id, follower_id)
2357            .await?;
2358        room_updated(&room, &session.peer);
2359    }
2360
2361    Ok(())
2362}
2363
2364/// Notify everyone following you of your current location.
2365async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2366    let room_id = RoomId::from_proto(request.room_id);
2367    let database = session.db.lock().await;
2368
2369    let connection_ids = if let Some(project_id) = request.project_id {
2370        let project_id = ProjectId::from_proto(project_id);
2371        database
2372            .project_connection_ids(project_id, session.connection_id, true)
2373            .await?
2374    } else {
2375        database
2376            .room_connection_ids(room_id, session.connection_id)
2377            .await?
2378    };
2379
2380    // For now, don't send view update messages back to that view's current leader.
2381    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2382        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2383        _ => None,
2384    });
2385
2386    for connection_id in connection_ids.iter().cloned() {
2387        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2388            session
2389                .peer
2390                .forward_send(session.connection_id, connection_id, request.clone())?;
2391        }
2392    }
2393    Ok(())
2394}
2395
2396/// Get public data about users.
2397async fn get_users(
2398    request: proto::GetUsers,
2399    response: Response<proto::GetUsers>,
2400    session: Session,
2401) -> Result<()> {
2402    let user_ids = request
2403        .user_ids
2404        .into_iter()
2405        .map(UserId::from_proto)
2406        .collect();
2407    let users = session
2408        .db()
2409        .await
2410        .get_users_by_ids(user_ids)
2411        .await?
2412        .into_iter()
2413        .map(|user| proto::User {
2414            id: user.id.to_proto(),
2415            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2416            github_login: user.github_login,
2417        })
2418        .collect();
2419    response.send(proto::UsersResponse { users })?;
2420    Ok(())
2421}
2422
2423/// Search for users (to invite) buy Github login
2424async fn fuzzy_search_users(
2425    request: proto::FuzzySearchUsers,
2426    response: Response<proto::FuzzySearchUsers>,
2427    session: Session,
2428) -> Result<()> {
2429    let query = request.query;
2430    let users = match query.len() {
2431        0 => vec![],
2432        1 | 2 => session
2433            .db()
2434            .await
2435            .get_user_by_github_login(&query)
2436            .await?
2437            .into_iter()
2438            .collect(),
2439        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2440    };
2441    let users = users
2442        .into_iter()
2443        .filter(|user| user.id != session.user_id())
2444        .map(|user| proto::User {
2445            id: user.id.to_proto(),
2446            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2447            github_login: user.github_login,
2448        })
2449        .collect();
2450    response.send(proto::UsersResponse { users })?;
2451    Ok(())
2452}
2453
2454/// Send a contact request to another user.
2455async fn request_contact(
2456    request: proto::RequestContact,
2457    response: Response<proto::RequestContact>,
2458    session: Session,
2459) -> Result<()> {
2460    let requester_id = session.user_id();
2461    let responder_id = UserId::from_proto(request.responder_id);
2462    if requester_id == responder_id {
2463        return Err(anyhow!("cannot add yourself as a contact"))?;
2464    }
2465
2466    let notifications = session
2467        .db()
2468        .await
2469        .send_contact_request(requester_id, responder_id)
2470        .await?;
2471
2472    // Update outgoing contact requests of requester
2473    let mut update = proto::UpdateContacts::default();
2474    update.outgoing_requests.push(responder_id.to_proto());
2475    for connection_id in session
2476        .connection_pool()
2477        .await
2478        .user_connection_ids(requester_id)
2479    {
2480        session.peer.send(connection_id, update.clone())?;
2481    }
2482
2483    // Update incoming contact requests of responder
2484    let mut update = proto::UpdateContacts::default();
2485    update
2486        .incoming_requests
2487        .push(proto::IncomingContactRequest {
2488            requester_id: requester_id.to_proto(),
2489        });
2490    let connection_pool = session.connection_pool().await;
2491    for connection_id in connection_pool.user_connection_ids(responder_id) {
2492        session.peer.send(connection_id, update.clone())?;
2493    }
2494
2495    send_notifications(&connection_pool, &session.peer, notifications);
2496
2497    response.send(proto::Ack {})?;
2498    Ok(())
2499}
2500
2501/// Accept or decline a contact request
2502async fn respond_to_contact_request(
2503    request: proto::RespondToContactRequest,
2504    response: Response<proto::RespondToContactRequest>,
2505    session: Session,
2506) -> Result<()> {
2507    let responder_id = session.user_id();
2508    let requester_id = UserId::from_proto(request.requester_id);
2509    let db = session.db().await;
2510    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2511        db.dismiss_contact_notification(responder_id, requester_id)
2512            .await?;
2513    } else {
2514        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2515
2516        let notifications = db
2517            .respond_to_contact_request(responder_id, requester_id, accept)
2518            .await?;
2519        let requester_busy = db.is_user_busy(requester_id).await?;
2520        let responder_busy = db.is_user_busy(responder_id).await?;
2521
2522        let pool = session.connection_pool().await;
2523        // Update responder with new contact
2524        let mut update = proto::UpdateContacts::default();
2525        if accept {
2526            update
2527                .contacts
2528                .push(contact_for_user(requester_id, requester_busy, &pool));
2529        }
2530        update
2531            .remove_incoming_requests
2532            .push(requester_id.to_proto());
2533        for connection_id in pool.user_connection_ids(responder_id) {
2534            session.peer.send(connection_id, update.clone())?;
2535        }
2536
2537        // Update requester with new contact
2538        let mut update = proto::UpdateContacts::default();
2539        if accept {
2540            update
2541                .contacts
2542                .push(contact_for_user(responder_id, responder_busy, &pool));
2543        }
2544        update
2545            .remove_outgoing_requests
2546            .push(responder_id.to_proto());
2547
2548        for connection_id in pool.user_connection_ids(requester_id) {
2549            session.peer.send(connection_id, update.clone())?;
2550        }
2551
2552        send_notifications(&pool, &session.peer, notifications);
2553    }
2554
2555    response.send(proto::Ack {})?;
2556    Ok(())
2557}
2558
2559/// Remove a contact.
2560async fn remove_contact(
2561    request: proto::RemoveContact,
2562    response: Response<proto::RemoveContact>,
2563    session: Session,
2564) -> Result<()> {
2565    let requester_id = session.user_id();
2566    let responder_id = UserId::from_proto(request.user_id);
2567    let db = session.db().await;
2568    let (contact_accepted, deleted_notification_id) =
2569        db.remove_contact(requester_id, responder_id).await?;
2570
2571    let pool = session.connection_pool().await;
2572    // Update outgoing contact requests of requester
2573    let mut update = proto::UpdateContacts::default();
2574    if contact_accepted {
2575        update.remove_contacts.push(responder_id.to_proto());
2576    } else {
2577        update
2578            .remove_outgoing_requests
2579            .push(responder_id.to_proto());
2580    }
2581    for connection_id in pool.user_connection_ids(requester_id) {
2582        session.peer.send(connection_id, update.clone())?;
2583    }
2584
2585    // Update incoming contact requests of responder
2586    let mut update = proto::UpdateContacts::default();
2587    if contact_accepted {
2588        update.remove_contacts.push(requester_id.to_proto());
2589    } else {
2590        update
2591            .remove_incoming_requests
2592            .push(requester_id.to_proto());
2593    }
2594    for connection_id in pool.user_connection_ids(responder_id) {
2595        session.peer.send(connection_id, update.clone())?;
2596        if let Some(notification_id) = deleted_notification_id {
2597            session.peer.send(
2598                connection_id,
2599                proto::DeleteNotification {
2600                    notification_id: notification_id.to_proto(),
2601                },
2602            )?;
2603        }
2604    }
2605
2606    response.send(proto::Ack {})?;
2607    Ok(())
2608}
2609
2610fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2611    version.0.minor() < 139
2612}
2613
2614async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2615    let plan = session.current_plan(&session.db().await).await?;
2616
2617    session
2618        .peer
2619        .send(
2620            session.connection_id,
2621            proto::UpdateUserPlan { plan: plan.into() },
2622        )
2623        .trace_err();
2624
2625    Ok(())
2626}
2627
2628async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2629    subscribe_user_to_channels(session.user_id(), &session).await?;
2630    Ok(())
2631}
2632
2633async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2634    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2635    let mut pool = session.connection_pool().await;
2636    for membership in &channels_for_user.channel_memberships {
2637        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2638    }
2639    session.peer.send(
2640        session.connection_id,
2641        build_update_user_channels(&channels_for_user),
2642    )?;
2643    session.peer.send(
2644        session.connection_id,
2645        build_channels_update(channels_for_user),
2646    )?;
2647    Ok(())
2648}
2649
2650/// Creates a new channel.
2651async fn create_channel(
2652    request: proto::CreateChannel,
2653    response: Response<proto::CreateChannel>,
2654    session: Session,
2655) -> Result<()> {
2656    let db = session.db().await;
2657
2658    let parent_id = request.parent_id.map(ChannelId::from_proto);
2659    let (channel, membership) = db
2660        .create_channel(&request.name, parent_id, session.user_id())
2661        .await?;
2662
2663    let root_id = channel.root_id();
2664    let channel = Channel::from_model(channel);
2665
2666    response.send(proto::CreateChannelResponse {
2667        channel: Some(channel.to_proto()),
2668        parent_id: request.parent_id,
2669    })?;
2670
2671    let mut connection_pool = session.connection_pool().await;
2672    if let Some(membership) = membership {
2673        connection_pool.subscribe_to_channel(
2674            membership.user_id,
2675            membership.channel_id,
2676            membership.role,
2677        );
2678        let update = proto::UpdateUserChannels {
2679            channel_memberships: vec![proto::ChannelMembership {
2680                channel_id: membership.channel_id.to_proto(),
2681                role: membership.role.into(),
2682            }],
2683            ..Default::default()
2684        };
2685        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2686            session.peer.send(connection_id, update.clone())?;
2687        }
2688    }
2689
2690    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2691        if !role.can_see_channel(channel.visibility) {
2692            continue;
2693        }
2694
2695        let update = proto::UpdateChannels {
2696            channels: vec![channel.to_proto()],
2697            ..Default::default()
2698        };
2699        session.peer.send(connection_id, update.clone())?;
2700    }
2701
2702    Ok(())
2703}
2704
2705/// Delete a channel
2706async fn delete_channel(
2707    request: proto::DeleteChannel,
2708    response: Response<proto::DeleteChannel>,
2709    session: Session,
2710) -> Result<()> {
2711    let db = session.db().await;
2712
2713    let channel_id = request.channel_id;
2714    let (root_channel, removed_channels) = db
2715        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2716        .await?;
2717    response.send(proto::Ack {})?;
2718
2719    // Notify members of removed channels
2720    let mut update = proto::UpdateChannels::default();
2721    update
2722        .delete_channels
2723        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2724
2725    let connection_pool = session.connection_pool().await;
2726    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2727        session.peer.send(connection_id, update.clone())?;
2728    }
2729
2730    Ok(())
2731}
2732
2733/// Invite someone to join a channel.
2734async fn invite_channel_member(
2735    request: proto::InviteChannelMember,
2736    response: Response<proto::InviteChannelMember>,
2737    session: Session,
2738) -> Result<()> {
2739    let db = session.db().await;
2740    let channel_id = ChannelId::from_proto(request.channel_id);
2741    let invitee_id = UserId::from_proto(request.user_id);
2742    let InviteMemberResult {
2743        channel,
2744        notifications,
2745    } = db
2746        .invite_channel_member(
2747            channel_id,
2748            invitee_id,
2749            session.user_id(),
2750            request.role().into(),
2751        )
2752        .await?;
2753
2754    let update = proto::UpdateChannels {
2755        channel_invitations: vec![channel.to_proto()],
2756        ..Default::default()
2757    };
2758
2759    let connection_pool = session.connection_pool().await;
2760    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2761        session.peer.send(connection_id, update.clone())?;
2762    }
2763
2764    send_notifications(&connection_pool, &session.peer, notifications);
2765
2766    response.send(proto::Ack {})?;
2767    Ok(())
2768}
2769
2770/// remove someone from a channel
2771async fn remove_channel_member(
2772    request: proto::RemoveChannelMember,
2773    response: Response<proto::RemoveChannelMember>,
2774    session: Session,
2775) -> Result<()> {
2776    let db = session.db().await;
2777    let channel_id = ChannelId::from_proto(request.channel_id);
2778    let member_id = UserId::from_proto(request.user_id);
2779
2780    let RemoveChannelMemberResult {
2781        membership_update,
2782        notification_id,
2783    } = db
2784        .remove_channel_member(channel_id, member_id, session.user_id())
2785        .await?;
2786
2787    let mut connection_pool = session.connection_pool().await;
2788    notify_membership_updated(
2789        &mut connection_pool,
2790        membership_update,
2791        member_id,
2792        &session.peer,
2793    );
2794    for connection_id in connection_pool.user_connection_ids(member_id) {
2795        if let Some(notification_id) = notification_id {
2796            session
2797                .peer
2798                .send(
2799                    connection_id,
2800                    proto::DeleteNotification {
2801                        notification_id: notification_id.to_proto(),
2802                    },
2803                )
2804                .trace_err();
2805        }
2806    }
2807
2808    response.send(proto::Ack {})?;
2809    Ok(())
2810}
2811
2812/// Toggle the channel between public and private.
2813/// Care is taken to maintain the invariant that public channels only descend from public channels,
2814/// (though members-only channels can appear at any point in the hierarchy).
2815async fn set_channel_visibility(
2816    request: proto::SetChannelVisibility,
2817    response: Response<proto::SetChannelVisibility>,
2818    session: Session,
2819) -> Result<()> {
2820    let db = session.db().await;
2821    let channel_id = ChannelId::from_proto(request.channel_id);
2822    let visibility = request.visibility().into();
2823
2824    let channel_model = db
2825        .set_channel_visibility(channel_id, visibility, session.user_id())
2826        .await?;
2827    let root_id = channel_model.root_id();
2828    let channel = Channel::from_model(channel_model);
2829
2830    let mut connection_pool = session.connection_pool().await;
2831    for (user_id, role) in connection_pool
2832        .channel_user_ids(root_id)
2833        .collect::<Vec<_>>()
2834        .into_iter()
2835    {
2836        let update = if role.can_see_channel(channel.visibility) {
2837            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2838            proto::UpdateChannels {
2839                channels: vec![channel.to_proto()],
2840                ..Default::default()
2841            }
2842        } else {
2843            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2844            proto::UpdateChannels {
2845                delete_channels: vec![channel.id.to_proto()],
2846                ..Default::default()
2847            }
2848        };
2849
2850        for connection_id in connection_pool.user_connection_ids(user_id) {
2851            session.peer.send(connection_id, update.clone())?;
2852        }
2853    }
2854
2855    response.send(proto::Ack {})?;
2856    Ok(())
2857}
2858
2859/// Alter the role for a user in the channel.
2860async fn set_channel_member_role(
2861    request: proto::SetChannelMemberRole,
2862    response: Response<proto::SetChannelMemberRole>,
2863    session: Session,
2864) -> Result<()> {
2865    let db = session.db().await;
2866    let channel_id = ChannelId::from_proto(request.channel_id);
2867    let member_id = UserId::from_proto(request.user_id);
2868    let result = db
2869        .set_channel_member_role(
2870            channel_id,
2871            session.user_id(),
2872            member_id,
2873            request.role().into(),
2874        )
2875        .await?;
2876
2877    match result {
2878        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2879            let mut connection_pool = session.connection_pool().await;
2880            notify_membership_updated(
2881                &mut connection_pool,
2882                membership_update,
2883                member_id,
2884                &session.peer,
2885            )
2886        }
2887        db::SetMemberRoleResult::InviteUpdated(channel) => {
2888            let update = proto::UpdateChannels {
2889                channel_invitations: vec![channel.to_proto()],
2890                ..Default::default()
2891            };
2892
2893            for connection_id in session
2894                .connection_pool()
2895                .await
2896                .user_connection_ids(member_id)
2897            {
2898                session.peer.send(connection_id, update.clone())?;
2899            }
2900        }
2901    }
2902
2903    response.send(proto::Ack {})?;
2904    Ok(())
2905}
2906
2907/// Change the name of a channel
2908async fn rename_channel(
2909    request: proto::RenameChannel,
2910    response: Response<proto::RenameChannel>,
2911    session: Session,
2912) -> Result<()> {
2913    let db = session.db().await;
2914    let channel_id = ChannelId::from_proto(request.channel_id);
2915    let channel_model = db
2916        .rename_channel(channel_id, session.user_id(), &request.name)
2917        .await?;
2918    let root_id = channel_model.root_id();
2919    let channel = Channel::from_model(channel_model);
2920
2921    response.send(proto::RenameChannelResponse {
2922        channel: Some(channel.to_proto()),
2923    })?;
2924
2925    let connection_pool = session.connection_pool().await;
2926    let update = proto::UpdateChannels {
2927        channels: vec![channel.to_proto()],
2928        ..Default::default()
2929    };
2930    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2931        if role.can_see_channel(channel.visibility) {
2932            session.peer.send(connection_id, update.clone())?;
2933        }
2934    }
2935
2936    Ok(())
2937}
2938
2939/// Move a channel to a new parent.
2940async fn move_channel(
2941    request: proto::MoveChannel,
2942    response: Response<proto::MoveChannel>,
2943    session: Session,
2944) -> Result<()> {
2945    let channel_id = ChannelId::from_proto(request.channel_id);
2946    let to = ChannelId::from_proto(request.to);
2947
2948    let (root_id, channels) = session
2949        .db()
2950        .await
2951        .move_channel(channel_id, to, session.user_id())
2952        .await?;
2953
2954    let connection_pool = session.connection_pool().await;
2955    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2956        let channels = channels
2957            .iter()
2958            .filter_map(|channel| {
2959                if role.can_see_channel(channel.visibility) {
2960                    Some(channel.to_proto())
2961                } else {
2962                    None
2963                }
2964            })
2965            .collect::<Vec<_>>();
2966        if channels.is_empty() {
2967            continue;
2968        }
2969
2970        let update = proto::UpdateChannels {
2971            channels,
2972            ..Default::default()
2973        };
2974
2975        session.peer.send(connection_id, update.clone())?;
2976    }
2977
2978    response.send(Ack {})?;
2979    Ok(())
2980}
2981
2982/// Get the list of channel members
2983async fn get_channel_members(
2984    request: proto::GetChannelMembers,
2985    response: Response<proto::GetChannelMembers>,
2986    session: Session,
2987) -> Result<()> {
2988    let db = session.db().await;
2989    let channel_id = ChannelId::from_proto(request.channel_id);
2990    let limit = if request.limit == 0 {
2991        u16::MAX as u64
2992    } else {
2993        request.limit
2994    };
2995    let (members, users) = db
2996        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
2997        .await?;
2998    response.send(proto::GetChannelMembersResponse { members, users })?;
2999    Ok(())
3000}
3001
3002/// Accept or decline a channel invitation.
3003async fn respond_to_channel_invite(
3004    request: proto::RespondToChannelInvite,
3005    response: Response<proto::RespondToChannelInvite>,
3006    session: Session,
3007) -> Result<()> {
3008    let db = session.db().await;
3009    let channel_id = ChannelId::from_proto(request.channel_id);
3010    let RespondToChannelInvite {
3011        membership_update,
3012        notifications,
3013    } = db
3014        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3015        .await?;
3016
3017    let mut connection_pool = session.connection_pool().await;
3018    if let Some(membership_update) = membership_update {
3019        notify_membership_updated(
3020            &mut connection_pool,
3021            membership_update,
3022            session.user_id(),
3023            &session.peer,
3024        );
3025    } else {
3026        let update = proto::UpdateChannels {
3027            remove_channel_invitations: vec![channel_id.to_proto()],
3028            ..Default::default()
3029        };
3030
3031        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3032            session.peer.send(connection_id, update.clone())?;
3033        }
3034    };
3035
3036    send_notifications(&connection_pool, &session.peer, notifications);
3037
3038    response.send(proto::Ack {})?;
3039
3040    Ok(())
3041}
3042
3043/// Join the channels' room
3044async fn join_channel(
3045    request: proto::JoinChannel,
3046    response: Response<proto::JoinChannel>,
3047    session: Session,
3048) -> Result<()> {
3049    let channel_id = ChannelId::from_proto(request.channel_id);
3050    join_channel_internal(channel_id, Box::new(response), session).await
3051}
3052
3053trait JoinChannelInternalResponse {
3054    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3055}
3056impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3057    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3058        Response::<proto::JoinChannel>::send(self, result)
3059    }
3060}
3061impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3062    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3063        Response::<proto::JoinRoom>::send(self, result)
3064    }
3065}
3066
3067async fn join_channel_internal(
3068    channel_id: ChannelId,
3069    response: Box<impl JoinChannelInternalResponse>,
3070    session: Session,
3071) -> Result<()> {
3072    let joined_room = {
3073        let mut db = session.db().await;
3074        // If zed quits without leaving the room, and the user re-opens zed before the
3075        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3076        // room they were in.
3077        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3078            tracing::info!(
3079                stale_connection_id = %connection,
3080                "cleaning up stale connection",
3081            );
3082            drop(db);
3083            leave_room_for_session(&session, connection).await?;
3084            db = session.db().await;
3085        }
3086
3087        let (joined_room, membership_updated, role) = db
3088            .join_channel(channel_id, session.user_id(), session.connection_id)
3089            .await?;
3090
3091        let live_kit_connection_info =
3092            session
3093                .app_state
3094                .live_kit_client
3095                .as_ref()
3096                .and_then(|live_kit| {
3097                    let (can_publish, token) = if role == ChannelRole::Guest {
3098                        (
3099                            false,
3100                            live_kit
3101                                .guest_token(
3102                                    &joined_room.room.live_kit_room,
3103                                    &session.user_id().to_string(),
3104                                )
3105                                .trace_err()?,
3106                        )
3107                    } else {
3108                        (
3109                            true,
3110                            live_kit
3111                                .room_token(
3112                                    &joined_room.room.live_kit_room,
3113                                    &session.user_id().to_string(),
3114                                )
3115                                .trace_err()?,
3116                        )
3117                    };
3118
3119                    Some(LiveKitConnectionInfo {
3120                        server_url: live_kit.url().into(),
3121                        token,
3122                        can_publish,
3123                    })
3124                });
3125
3126        response.send(proto::JoinRoomResponse {
3127            room: Some(joined_room.room.clone()),
3128            channel_id: joined_room
3129                .channel
3130                .as_ref()
3131                .map(|channel| channel.id.to_proto()),
3132            live_kit_connection_info,
3133        })?;
3134
3135        let mut connection_pool = session.connection_pool().await;
3136        if let Some(membership_updated) = membership_updated {
3137            notify_membership_updated(
3138                &mut connection_pool,
3139                membership_updated,
3140                session.user_id(),
3141                &session.peer,
3142            );
3143        }
3144
3145        room_updated(&joined_room.room, &session.peer);
3146
3147        joined_room
3148    };
3149
3150    channel_updated(
3151        &joined_room
3152            .channel
3153            .ok_or_else(|| anyhow!("channel not returned"))?,
3154        &joined_room.room,
3155        &session.peer,
3156        &*session.connection_pool().await,
3157    );
3158
3159    update_user_contacts(session.user_id(), &session).await?;
3160    Ok(())
3161}
3162
3163/// Start editing the channel notes
3164async fn join_channel_buffer(
3165    request: proto::JoinChannelBuffer,
3166    response: Response<proto::JoinChannelBuffer>,
3167    session: Session,
3168) -> Result<()> {
3169    let db = session.db().await;
3170    let channel_id = ChannelId::from_proto(request.channel_id);
3171
3172    let open_response = db
3173        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3174        .await?;
3175
3176    let collaborators = open_response.collaborators.clone();
3177    response.send(open_response)?;
3178
3179    let update = UpdateChannelBufferCollaborators {
3180        channel_id: channel_id.to_proto(),
3181        collaborators: collaborators.clone(),
3182    };
3183    channel_buffer_updated(
3184        session.connection_id,
3185        collaborators
3186            .iter()
3187            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3188        &update,
3189        &session.peer,
3190    );
3191
3192    Ok(())
3193}
3194
3195/// Edit the channel notes
3196async fn update_channel_buffer(
3197    request: proto::UpdateChannelBuffer,
3198    session: Session,
3199) -> Result<()> {
3200    let db = session.db().await;
3201    let channel_id = ChannelId::from_proto(request.channel_id);
3202
3203    let (collaborators, epoch, version) = db
3204        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3205        .await?;
3206
3207    channel_buffer_updated(
3208        session.connection_id,
3209        collaborators.clone(),
3210        &proto::UpdateChannelBuffer {
3211            channel_id: channel_id.to_proto(),
3212            operations: request.operations,
3213        },
3214        &session.peer,
3215    );
3216
3217    let pool = &*session.connection_pool().await;
3218
3219    let non_collaborators =
3220        pool.channel_connection_ids(channel_id)
3221            .filter_map(|(connection_id, _)| {
3222                if collaborators.contains(&connection_id) {
3223                    None
3224                } else {
3225                    Some(connection_id)
3226                }
3227            });
3228
3229    broadcast(None, non_collaborators, |peer_id| {
3230        session.peer.send(
3231            peer_id,
3232            proto::UpdateChannels {
3233                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3234                    channel_id: channel_id.to_proto(),
3235                    epoch: epoch as u64,
3236                    version: version.clone(),
3237                }],
3238                ..Default::default()
3239            },
3240        )
3241    });
3242
3243    Ok(())
3244}
3245
3246/// Rejoin the channel notes after a connection blip
3247async fn rejoin_channel_buffers(
3248    request: proto::RejoinChannelBuffers,
3249    response: Response<proto::RejoinChannelBuffers>,
3250    session: Session,
3251) -> Result<()> {
3252    let db = session.db().await;
3253    let buffers = db
3254        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3255        .await?;
3256
3257    for rejoined_buffer in &buffers {
3258        let collaborators_to_notify = rejoined_buffer
3259            .buffer
3260            .collaborators
3261            .iter()
3262            .filter_map(|c| Some(c.peer_id?.into()));
3263        channel_buffer_updated(
3264            session.connection_id,
3265            collaborators_to_notify,
3266            &proto::UpdateChannelBufferCollaborators {
3267                channel_id: rejoined_buffer.buffer.channel_id,
3268                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3269            },
3270            &session.peer,
3271        );
3272    }
3273
3274    response.send(proto::RejoinChannelBuffersResponse {
3275        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3276    })?;
3277
3278    Ok(())
3279}
3280
3281/// Stop editing the channel notes
3282async fn leave_channel_buffer(
3283    request: proto::LeaveChannelBuffer,
3284    response: Response<proto::LeaveChannelBuffer>,
3285    session: Session,
3286) -> Result<()> {
3287    let db = session.db().await;
3288    let channel_id = ChannelId::from_proto(request.channel_id);
3289
3290    let left_buffer = db
3291        .leave_channel_buffer(channel_id, session.connection_id)
3292        .await?;
3293
3294    response.send(Ack {})?;
3295
3296    channel_buffer_updated(
3297        session.connection_id,
3298        left_buffer.connections,
3299        &proto::UpdateChannelBufferCollaborators {
3300            channel_id: channel_id.to_proto(),
3301            collaborators: left_buffer.collaborators,
3302        },
3303        &session.peer,
3304    );
3305
3306    Ok(())
3307}
3308
3309fn channel_buffer_updated<T: EnvelopedMessage>(
3310    sender_id: ConnectionId,
3311    collaborators: impl IntoIterator<Item = ConnectionId>,
3312    message: &T,
3313    peer: &Peer,
3314) {
3315    broadcast(Some(sender_id), collaborators, |peer_id| {
3316        peer.send(peer_id, message.clone())
3317    });
3318}
3319
3320fn send_notifications(
3321    connection_pool: &ConnectionPool,
3322    peer: &Peer,
3323    notifications: db::NotificationBatch,
3324) {
3325    for (user_id, notification) in notifications {
3326        for connection_id in connection_pool.user_connection_ids(user_id) {
3327            if let Err(error) = peer.send(
3328                connection_id,
3329                proto::AddNotification {
3330                    notification: Some(notification.clone()),
3331                },
3332            ) {
3333                tracing::error!(
3334                    "failed to send notification to {:?} {}",
3335                    connection_id,
3336                    error
3337                );
3338            }
3339        }
3340    }
3341}
3342
3343/// Send a message to the channel
3344async fn send_channel_message(
3345    request: proto::SendChannelMessage,
3346    response: Response<proto::SendChannelMessage>,
3347    session: Session,
3348) -> Result<()> {
3349    // Validate the message body.
3350    let body = request.body.trim().to_string();
3351    if body.len() > MAX_MESSAGE_LEN {
3352        return Err(anyhow!("message is too long"))?;
3353    }
3354    if body.is_empty() {
3355        return Err(anyhow!("message can't be blank"))?;
3356    }
3357
3358    // TODO: adjust mentions if body is trimmed
3359
3360    let timestamp = OffsetDateTime::now_utc();
3361    let nonce = request
3362        .nonce
3363        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3364
3365    let channel_id = ChannelId::from_proto(request.channel_id);
3366    let CreatedChannelMessage {
3367        message_id,
3368        participant_connection_ids,
3369        notifications,
3370    } = session
3371        .db()
3372        .await
3373        .create_channel_message(
3374            channel_id,
3375            session.user_id(),
3376            &body,
3377            &request.mentions,
3378            timestamp,
3379            nonce.clone().into(),
3380            request.reply_to_message_id.map(MessageId::from_proto),
3381        )
3382        .await?;
3383
3384    let message = proto::ChannelMessage {
3385        sender_id: session.user_id().to_proto(),
3386        id: message_id.to_proto(),
3387        body,
3388        mentions: request.mentions,
3389        timestamp: timestamp.unix_timestamp() as u64,
3390        nonce: Some(nonce),
3391        reply_to_message_id: request.reply_to_message_id,
3392        edited_at: None,
3393    };
3394    broadcast(
3395        Some(session.connection_id),
3396        participant_connection_ids.clone(),
3397        |connection| {
3398            session.peer.send(
3399                connection,
3400                proto::ChannelMessageSent {
3401                    channel_id: channel_id.to_proto(),
3402                    message: Some(message.clone()),
3403                },
3404            )
3405        },
3406    );
3407    response.send(proto::SendChannelMessageResponse {
3408        message: Some(message),
3409    })?;
3410
3411    let pool = &*session.connection_pool().await;
3412    let non_participants =
3413        pool.channel_connection_ids(channel_id)
3414            .filter_map(|(connection_id, _)| {
3415                if participant_connection_ids.contains(&connection_id) {
3416                    None
3417                } else {
3418                    Some(connection_id)
3419                }
3420            });
3421    broadcast(None, non_participants, |peer_id| {
3422        session.peer.send(
3423            peer_id,
3424            proto::UpdateChannels {
3425                latest_channel_message_ids: vec![proto::ChannelMessageId {
3426                    channel_id: channel_id.to_proto(),
3427                    message_id: message_id.to_proto(),
3428                }],
3429                ..Default::default()
3430            },
3431        )
3432    });
3433    send_notifications(pool, &session.peer, notifications);
3434
3435    Ok(())
3436}
3437
3438/// Delete a channel message
3439async fn remove_channel_message(
3440    request: proto::RemoveChannelMessage,
3441    response: Response<proto::RemoveChannelMessage>,
3442    session: Session,
3443) -> Result<()> {
3444    let channel_id = ChannelId::from_proto(request.channel_id);
3445    let message_id = MessageId::from_proto(request.message_id);
3446    let (connection_ids, existing_notification_ids) = session
3447        .db()
3448        .await
3449        .remove_channel_message(channel_id, message_id, session.user_id())
3450        .await?;
3451
3452    broadcast(
3453        Some(session.connection_id),
3454        connection_ids,
3455        move |connection| {
3456            session.peer.send(connection, request.clone())?;
3457
3458            for notification_id in &existing_notification_ids {
3459                session.peer.send(
3460                    connection,
3461                    proto::DeleteNotification {
3462                        notification_id: (*notification_id).to_proto(),
3463                    },
3464                )?;
3465            }
3466
3467            Ok(())
3468        },
3469    );
3470    response.send(proto::Ack {})?;
3471    Ok(())
3472}
3473
3474async fn update_channel_message(
3475    request: proto::UpdateChannelMessage,
3476    response: Response<proto::UpdateChannelMessage>,
3477    session: Session,
3478) -> Result<()> {
3479    let channel_id = ChannelId::from_proto(request.channel_id);
3480    let message_id = MessageId::from_proto(request.message_id);
3481    let updated_at = OffsetDateTime::now_utc();
3482    let UpdatedChannelMessage {
3483        message_id,
3484        participant_connection_ids,
3485        notifications,
3486        reply_to_message_id,
3487        timestamp,
3488        deleted_mention_notification_ids,
3489        updated_mention_notifications,
3490    } = session
3491        .db()
3492        .await
3493        .update_channel_message(
3494            channel_id,
3495            message_id,
3496            session.user_id(),
3497            request.body.as_str(),
3498            &request.mentions,
3499            updated_at,
3500        )
3501        .await?;
3502
3503    let nonce = request
3504        .nonce
3505        .clone()
3506        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3507
3508    let message = proto::ChannelMessage {
3509        sender_id: session.user_id().to_proto(),
3510        id: message_id.to_proto(),
3511        body: request.body.clone(),
3512        mentions: request.mentions.clone(),
3513        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3514        nonce: Some(nonce),
3515        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3516        edited_at: Some(updated_at.unix_timestamp() as u64),
3517    };
3518
3519    response.send(proto::Ack {})?;
3520
3521    let pool = &*session.connection_pool().await;
3522    broadcast(
3523        Some(session.connection_id),
3524        participant_connection_ids,
3525        |connection| {
3526            session.peer.send(
3527                connection,
3528                proto::ChannelMessageUpdate {
3529                    channel_id: channel_id.to_proto(),
3530                    message: Some(message.clone()),
3531                },
3532            )?;
3533
3534            for notification_id in &deleted_mention_notification_ids {
3535                session.peer.send(
3536                    connection,
3537                    proto::DeleteNotification {
3538                        notification_id: (*notification_id).to_proto(),
3539                    },
3540                )?;
3541            }
3542
3543            for notification in &updated_mention_notifications {
3544                session.peer.send(
3545                    connection,
3546                    proto::UpdateNotification {
3547                        notification: Some(notification.clone()),
3548                    },
3549                )?;
3550            }
3551
3552            Ok(())
3553        },
3554    );
3555
3556    send_notifications(pool, &session.peer, notifications);
3557
3558    Ok(())
3559}
3560
3561/// Mark a channel message as read
3562async fn acknowledge_channel_message(
3563    request: proto::AckChannelMessage,
3564    session: Session,
3565) -> Result<()> {
3566    let channel_id = ChannelId::from_proto(request.channel_id);
3567    let message_id = MessageId::from_proto(request.message_id);
3568    let notifications = session
3569        .db()
3570        .await
3571        .observe_channel_message(channel_id, session.user_id(), message_id)
3572        .await?;
3573    send_notifications(
3574        &*session.connection_pool().await,
3575        &session.peer,
3576        notifications,
3577    );
3578    Ok(())
3579}
3580
3581/// Mark a buffer version as synced
3582async fn acknowledge_buffer_version(
3583    request: proto::AckBufferOperation,
3584    session: Session,
3585) -> Result<()> {
3586    let buffer_id = BufferId::from_proto(request.buffer_id);
3587    session
3588        .db()
3589        .await
3590        .observe_buffer_version(
3591            buffer_id,
3592            session.user_id(),
3593            request.epoch as i32,
3594            &request.version,
3595        )
3596        .await?;
3597    Ok(())
3598}
3599
3600async fn count_language_model_tokens(
3601    request: proto::CountLanguageModelTokens,
3602    response: Response<proto::CountLanguageModelTokens>,
3603    session: Session,
3604    config: &Config,
3605) -> Result<()> {
3606    authorize_access_to_legacy_llm_endpoints(&session).await?;
3607
3608    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3609        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3610        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3611    };
3612
3613    session
3614        .app_state
3615        .rate_limiter
3616        .check(&*rate_limit, session.user_id())
3617        .await?;
3618
3619    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3620        Some(proto::LanguageModelProvider::Google) => {
3621            let api_key = config
3622                .google_ai_api_key
3623                .as_ref()
3624                .context("no Google AI API key configured on the server")?;
3625            google_ai::count_tokens(
3626                session.http_client.as_ref(),
3627                google_ai::API_URL,
3628                api_key,
3629                serde_json::from_str(&request.request)?,
3630            )
3631            .await?
3632        }
3633        _ => return Err(anyhow!("unsupported provider"))?,
3634    };
3635
3636    response.send(proto::CountLanguageModelTokensResponse {
3637        token_count: result.total_tokens as u32,
3638    })?;
3639
3640    Ok(())
3641}
3642
3643struct ZedProCountLanguageModelTokensRateLimit;
3644
3645impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3646    fn capacity(&self) -> usize {
3647        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3648            .ok()
3649            .and_then(|v| v.parse().ok())
3650            .unwrap_or(600) // Picked arbitrarily
3651    }
3652
3653    fn refill_duration(&self) -> chrono::Duration {
3654        chrono::Duration::hours(1)
3655    }
3656
3657    fn db_name(&self) -> &'static str {
3658        "zed-pro:count-language-model-tokens"
3659    }
3660}
3661
3662struct FreeCountLanguageModelTokensRateLimit;
3663
3664impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3665    fn capacity(&self) -> usize {
3666        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3667            .ok()
3668            .and_then(|v| v.parse().ok())
3669            .unwrap_or(600 / 10) // Picked arbitrarily
3670    }
3671
3672    fn refill_duration(&self) -> chrono::Duration {
3673        chrono::Duration::hours(1)
3674    }
3675
3676    fn db_name(&self) -> &'static str {
3677        "free:count-language-model-tokens"
3678    }
3679}
3680
3681struct ZedProComputeEmbeddingsRateLimit;
3682
3683impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3684    fn capacity(&self) -> usize {
3685        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3686            .ok()
3687            .and_then(|v| v.parse().ok())
3688            .unwrap_or(5000) // Picked arbitrarily
3689    }
3690
3691    fn refill_duration(&self) -> chrono::Duration {
3692        chrono::Duration::hours(1)
3693    }
3694
3695    fn db_name(&self) -> &'static str {
3696        "zed-pro:compute-embeddings"
3697    }
3698}
3699
3700struct FreeComputeEmbeddingsRateLimit;
3701
3702impl RateLimit for FreeComputeEmbeddingsRateLimit {
3703    fn capacity(&self) -> usize {
3704        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3705            .ok()
3706            .and_then(|v| v.parse().ok())
3707            .unwrap_or(5000 / 10) // Picked arbitrarily
3708    }
3709
3710    fn refill_duration(&self) -> chrono::Duration {
3711        chrono::Duration::hours(1)
3712    }
3713
3714    fn db_name(&self) -> &'static str {
3715        "free:compute-embeddings"
3716    }
3717}
3718
3719async fn compute_embeddings(
3720    request: proto::ComputeEmbeddings,
3721    response: Response<proto::ComputeEmbeddings>,
3722    session: Session,
3723    api_key: Option<Arc<str>>,
3724) -> Result<()> {
3725    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3726    authorize_access_to_legacy_llm_endpoints(&session).await?;
3727
3728    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3729        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3730        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3731    };
3732
3733    session
3734        .app_state
3735        .rate_limiter
3736        .check(&*rate_limit, session.user_id())
3737        .await?;
3738
3739    let embeddings = match request.model.as_str() {
3740        "openai/text-embedding-3-small" => {
3741            open_ai::embed(
3742                session.http_client.as_ref(),
3743                OPEN_AI_API_URL,
3744                &api_key,
3745                OpenAiEmbeddingModel::TextEmbedding3Small,
3746                request.texts.iter().map(|text| text.as_str()),
3747            )
3748            .await?
3749        }
3750        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3751    };
3752
3753    let embeddings = request
3754        .texts
3755        .iter()
3756        .map(|text| {
3757            let mut hasher = sha2::Sha256::new();
3758            hasher.update(text.as_bytes());
3759            let result = hasher.finalize();
3760            result.to_vec()
3761        })
3762        .zip(
3763            embeddings
3764                .data
3765                .into_iter()
3766                .map(|embedding| embedding.embedding),
3767        )
3768        .collect::<HashMap<_, _>>();
3769
3770    let db = session.db().await;
3771    db.save_embeddings(&request.model, &embeddings)
3772        .await
3773        .context("failed to save embeddings")
3774        .trace_err();
3775
3776    response.send(proto::ComputeEmbeddingsResponse {
3777        embeddings: embeddings
3778            .into_iter()
3779            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3780            .collect(),
3781    })?;
3782    Ok(())
3783}
3784
3785async fn get_cached_embeddings(
3786    request: proto::GetCachedEmbeddings,
3787    response: Response<proto::GetCachedEmbeddings>,
3788    session: Session,
3789) -> Result<()> {
3790    authorize_access_to_legacy_llm_endpoints(&session).await?;
3791
3792    let db = session.db().await;
3793    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3794
3795    response.send(proto::GetCachedEmbeddingsResponse {
3796        embeddings: embeddings
3797            .into_iter()
3798            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3799            .collect(),
3800    })?;
3801    Ok(())
3802}
3803
3804/// This is leftover from before the LLM service.
3805///
3806/// The endpoints protected by this check will be moved there eventually.
3807async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3808    if session.is_staff() {
3809        Ok(())
3810    } else {
3811        Err(anyhow!("permission denied"))?
3812    }
3813}
3814
3815/// Get a Supermaven API key for the user
3816async fn get_supermaven_api_key(
3817    _request: proto::GetSupermavenApiKey,
3818    response: Response<proto::GetSupermavenApiKey>,
3819    session: Session,
3820) -> Result<()> {
3821    let user_id: String = session.user_id().to_string();
3822    if !session.is_staff() {
3823        return Err(anyhow!("supermaven not enabled for this account"))?;
3824    }
3825
3826    let email = session
3827        .email()
3828        .ok_or_else(|| anyhow!("user must have an email"))?;
3829
3830    let supermaven_admin_api = session
3831        .supermaven_client
3832        .as_ref()
3833        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3834
3835    let result = supermaven_admin_api
3836        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3837        .await?;
3838
3839    response.send(proto::GetSupermavenApiKeyResponse {
3840        api_key: result.api_key,
3841    })?;
3842
3843    Ok(())
3844}
3845
3846/// Start receiving chat updates for a channel
3847async fn join_channel_chat(
3848    request: proto::JoinChannelChat,
3849    response: Response<proto::JoinChannelChat>,
3850    session: Session,
3851) -> Result<()> {
3852    let channel_id = ChannelId::from_proto(request.channel_id);
3853
3854    let db = session.db().await;
3855    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3856        .await?;
3857    let messages = db
3858        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3859        .await?;
3860    response.send(proto::JoinChannelChatResponse {
3861        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3862        messages,
3863    })?;
3864    Ok(())
3865}
3866
3867/// Stop receiving chat updates for a channel
3868async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3869    let channel_id = ChannelId::from_proto(request.channel_id);
3870    session
3871        .db()
3872        .await
3873        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3874        .await?;
3875    Ok(())
3876}
3877
3878/// Retrieve the chat history for a channel
3879async fn get_channel_messages(
3880    request: proto::GetChannelMessages,
3881    response: Response<proto::GetChannelMessages>,
3882    session: Session,
3883) -> Result<()> {
3884    let channel_id = ChannelId::from_proto(request.channel_id);
3885    let messages = session
3886        .db()
3887        .await
3888        .get_channel_messages(
3889            channel_id,
3890            session.user_id(),
3891            MESSAGE_COUNT_PER_PAGE,
3892            Some(MessageId::from_proto(request.before_message_id)),
3893        )
3894        .await?;
3895    response.send(proto::GetChannelMessagesResponse {
3896        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3897        messages,
3898    })?;
3899    Ok(())
3900}
3901
3902/// Retrieve specific chat messages
3903async fn get_channel_messages_by_id(
3904    request: proto::GetChannelMessagesById,
3905    response: Response<proto::GetChannelMessagesById>,
3906    session: Session,
3907) -> Result<()> {
3908    let message_ids = request
3909        .message_ids
3910        .iter()
3911        .map(|id| MessageId::from_proto(*id))
3912        .collect::<Vec<_>>();
3913    let messages = session
3914        .db()
3915        .await
3916        .get_channel_messages_by_id(session.user_id(), &message_ids)
3917        .await?;
3918    response.send(proto::GetChannelMessagesResponse {
3919        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3920        messages,
3921    })?;
3922    Ok(())
3923}
3924
3925/// Retrieve the current users notifications
3926async fn get_notifications(
3927    request: proto::GetNotifications,
3928    response: Response<proto::GetNotifications>,
3929    session: Session,
3930) -> Result<()> {
3931    let notifications = session
3932        .db()
3933        .await
3934        .get_notifications(
3935            session.user_id(),
3936            NOTIFICATION_COUNT_PER_PAGE,
3937            request.before_id.map(db::NotificationId::from_proto),
3938        )
3939        .await?;
3940    response.send(proto::GetNotificationsResponse {
3941        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3942        notifications,
3943    })?;
3944    Ok(())
3945}
3946
3947/// Mark notifications as read
3948async fn mark_notification_as_read(
3949    request: proto::MarkNotificationRead,
3950    response: Response<proto::MarkNotificationRead>,
3951    session: Session,
3952) -> Result<()> {
3953    let database = &session.db().await;
3954    let notifications = database
3955        .mark_notification_as_read_by_id(
3956            session.user_id(),
3957            NotificationId::from_proto(request.notification_id),
3958        )
3959        .await?;
3960    send_notifications(
3961        &*session.connection_pool().await,
3962        &session.peer,
3963        notifications,
3964    );
3965    response.send(proto::Ack {})?;
3966    Ok(())
3967}
3968
3969/// Get the current users information
3970async fn get_private_user_info(
3971    _request: proto::GetPrivateUserInfo,
3972    response: Response<proto::GetPrivateUserInfo>,
3973    session: Session,
3974) -> Result<()> {
3975    let db = session.db().await;
3976
3977    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3978    let user = db
3979        .get_user_by_id(session.user_id())
3980        .await?
3981        .ok_or_else(|| anyhow!("user not found"))?;
3982    let flags = db.get_user_flags(session.user_id()).await?;
3983
3984    response.send(proto::GetPrivateUserInfoResponse {
3985        metrics_id,
3986        staff: user.admin,
3987        flags,
3988        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3989    })?;
3990    Ok(())
3991}
3992
3993/// Accept the terms of service (tos) on behalf of the current user
3994async fn accept_terms_of_service(
3995    _request: proto::AcceptTermsOfService,
3996    response: Response<proto::AcceptTermsOfService>,
3997    session: Session,
3998) -> Result<()> {
3999    let db = session.db().await;
4000
4001    let accepted_tos_at = Utc::now();
4002    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4003        .await?;
4004
4005    response.send(proto::AcceptTermsOfServiceResponse {
4006        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4007    })?;
4008    Ok(())
4009}
4010
4011/// The minimum account age an account must have in order to use the LLM service.
4012const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4013
4014async fn get_llm_api_token(
4015    _request: proto::GetLlmToken,
4016    response: Response<proto::GetLlmToken>,
4017    session: Session,
4018) -> Result<()> {
4019    let db = session.db().await;
4020
4021    let flags = db.get_user_flags(session.user_id()).await?;
4022    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4023    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4024
4025    if !session.is_staff() && !has_language_models_feature_flag {
4026        Err(anyhow!("permission denied"))?
4027    }
4028
4029    let user_id = session.user_id();
4030    let user = db
4031        .get_user_by_id(user_id)
4032        .await?
4033        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4034
4035    if user.accepted_tos_at.is_none() {
4036        Err(anyhow!("terms of service not accepted"))?
4037    }
4038
4039    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4040
4041    let bypass_account_age_check =
4042        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4043    if !bypass_account_age_check {
4044        let mut account_created_at = user.created_at;
4045        if let Some(github_created_at) = user.github_user_created_at {
4046            account_created_at = account_created_at.min(github_created_at);
4047        }
4048        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4049            Err(anyhow!("account too young"))?
4050        }
4051    }
4052
4053    let billing_preferences = db.get_billing_preferences(user.id).await?;
4054
4055    let token = LlmTokenClaims::create(
4056        &user,
4057        session.is_staff(),
4058        billing_preferences,
4059        has_llm_closed_beta_feature_flag,
4060        has_llm_subscription,
4061        session.current_plan(&db).await?,
4062        session.system_id.clone(),
4063        &session.app_state.config,
4064    )?;
4065    response.send(proto::GetLlmTokenResponse { token })?;
4066    Ok(())
4067}
4068
4069fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4070    let message = match message {
4071        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4072        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4073        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4074        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4075        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4076            code: frame.code.into(),
4077            reason: frame.reason,
4078        })),
4079        // We should never receive a frame while reading the message, according
4080        // to the `tungstenite` maintainers:
4081        //
4082        // > It cannot occur when you read messages from the WebSocket, but it
4083        // > can be used when you want to send the raw frames (e.g. you want to
4084        // > send the frames to the WebSocket without composing the full message first).
4085        // >
4086        // > — https://github.com/snapview/tungstenite-rs/issues/268
4087        TungsteniteMessage::Frame(_) => {
4088            bail!("received an unexpected frame while reading the message")
4089        }
4090    };
4091
4092    Ok(message)
4093}
4094
4095fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4096    match message {
4097        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4098        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4099        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4100        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4101        AxumMessage::Close(frame) => {
4102            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4103                code: frame.code.into(),
4104                reason: frame.reason,
4105            }))
4106        }
4107    }
4108}
4109
4110fn notify_membership_updated(
4111    connection_pool: &mut ConnectionPool,
4112    result: MembershipUpdated,
4113    user_id: UserId,
4114    peer: &Peer,
4115) {
4116    for membership in &result.new_channels.channel_memberships {
4117        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4118    }
4119    for channel_id in &result.removed_channels {
4120        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4121    }
4122
4123    let user_channels_update = proto::UpdateUserChannels {
4124        channel_memberships: result
4125            .new_channels
4126            .channel_memberships
4127            .iter()
4128            .map(|cm| proto::ChannelMembership {
4129                channel_id: cm.channel_id.to_proto(),
4130                role: cm.role.into(),
4131            })
4132            .collect(),
4133        ..Default::default()
4134    };
4135
4136    let mut update = build_channels_update(result.new_channels);
4137    update.delete_channels = result
4138        .removed_channels
4139        .into_iter()
4140        .map(|id| id.to_proto())
4141        .collect();
4142    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4143
4144    for connection_id in connection_pool.user_connection_ids(user_id) {
4145        peer.send(connection_id, user_channels_update.clone())
4146            .trace_err();
4147        peer.send(connection_id, update.clone()).trace_err();
4148    }
4149}
4150
4151fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4152    proto::UpdateUserChannels {
4153        channel_memberships: channels
4154            .channel_memberships
4155            .iter()
4156            .map(|m| proto::ChannelMembership {
4157                channel_id: m.channel_id.to_proto(),
4158                role: m.role.into(),
4159            })
4160            .collect(),
4161        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4162        observed_channel_message_id: channels.observed_channel_messages.clone(),
4163    }
4164}
4165
4166fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4167    let mut update = proto::UpdateChannels::default();
4168
4169    for channel in channels.channels {
4170        update.channels.push(channel.to_proto());
4171    }
4172
4173    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4174    update.latest_channel_message_ids = channels.latest_channel_messages;
4175
4176    for (channel_id, participants) in channels.channel_participants {
4177        update
4178            .channel_participants
4179            .push(proto::ChannelParticipants {
4180                channel_id: channel_id.to_proto(),
4181                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4182            });
4183    }
4184
4185    for channel in channels.invited_channels {
4186        update.channel_invitations.push(channel.to_proto());
4187    }
4188
4189    update
4190}
4191
4192fn build_initial_contacts_update(
4193    contacts: Vec<db::Contact>,
4194    pool: &ConnectionPool,
4195) -> proto::UpdateContacts {
4196    let mut update = proto::UpdateContacts::default();
4197
4198    for contact in contacts {
4199        match contact {
4200            db::Contact::Accepted { user_id, busy } => {
4201                update.contacts.push(contact_for_user(user_id, busy, pool));
4202            }
4203            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4204            db::Contact::Incoming { user_id } => {
4205                update
4206                    .incoming_requests
4207                    .push(proto::IncomingContactRequest {
4208                        requester_id: user_id.to_proto(),
4209                    })
4210            }
4211        }
4212    }
4213
4214    update
4215}
4216
4217fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4218    proto::Contact {
4219        user_id: user_id.to_proto(),
4220        online: pool.is_user_online(user_id),
4221        busy,
4222    }
4223}
4224
4225fn room_updated(room: &proto::Room, peer: &Peer) {
4226    broadcast(
4227        None,
4228        room.participants
4229            .iter()
4230            .filter_map(|participant| Some(participant.peer_id?.into())),
4231        |peer_id| {
4232            peer.send(
4233                peer_id,
4234                proto::RoomUpdated {
4235                    room: Some(room.clone()),
4236                },
4237            )
4238        },
4239    );
4240}
4241
4242fn channel_updated(
4243    channel: &db::channel::Model,
4244    room: &proto::Room,
4245    peer: &Peer,
4246    pool: &ConnectionPool,
4247) {
4248    let participants = room
4249        .participants
4250        .iter()
4251        .map(|p| p.user_id)
4252        .collect::<Vec<_>>();
4253
4254    broadcast(
4255        None,
4256        pool.channel_connection_ids(channel.root_id())
4257            .filter_map(|(channel_id, role)| {
4258                role.can_see_channel(channel.visibility)
4259                    .then_some(channel_id)
4260            }),
4261        |peer_id| {
4262            peer.send(
4263                peer_id,
4264                proto::UpdateChannels {
4265                    channel_participants: vec![proto::ChannelParticipants {
4266                        channel_id: channel.id.to_proto(),
4267                        participant_user_ids: participants.clone(),
4268                    }],
4269                    ..Default::default()
4270                },
4271            )
4272        },
4273    );
4274}
4275
4276async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4277    let db = session.db().await;
4278
4279    let contacts = db.get_contacts(user_id).await?;
4280    let busy = db.is_user_busy(user_id).await?;
4281
4282    let pool = session.connection_pool().await;
4283    let updated_contact = contact_for_user(user_id, busy, &pool);
4284    for contact in contacts {
4285        if let db::Contact::Accepted {
4286            user_id: contact_user_id,
4287            ..
4288        } = contact
4289        {
4290            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4291                session
4292                    .peer
4293                    .send(
4294                        contact_conn_id,
4295                        proto::UpdateContacts {
4296                            contacts: vec![updated_contact.clone()],
4297                            remove_contacts: Default::default(),
4298                            incoming_requests: Default::default(),
4299                            remove_incoming_requests: Default::default(),
4300                            outgoing_requests: Default::default(),
4301                            remove_outgoing_requests: Default::default(),
4302                        },
4303                    )
4304                    .trace_err();
4305            }
4306        }
4307    }
4308    Ok(())
4309}
4310
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// If `leave_room` reports no room for the connection (`None`), this returns `Ok(())`
/// without sending anything. Otherwise it:
/// - calls `project_left` for every project recorded in `left_projects`,
/// - broadcasts the room via `room_updated`, and via `channel_updated` when the left room
///   has an associated channel,
/// - sends `proto::CallCanceled` to every connection of each user whose call was canceled,
/// - calls `update_user_contacts` for the session's user and for each of those users, and
/// - when a LiveKit client is configured, removes the session's user from the LiveKit room
///   and deletes that room if the database marked it as deleted.
///
/// NOTE(review): the room is looked up by the `connection_id` argument, while `project_left`
/// and the LiveKit cleanup use the session's own connection/user — presumably callers pass
/// `session.connection_id`; confirm.
///
/// # Errors
///
/// Returns an error if `leave_room` fails or if any `update_user_contacts` call fails; in
/// the latter case the LiveKit cleanup at the end is not reached. Failed peer sends and
/// LiveKit calls are passed through `trace_err` instead of being propagated.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so the extracted values
    // remain usable after `left_room` goes out of scope.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `live_kit_room` is taken out of `left_room.room` before `room` itself is taken on
        // the line after next, so the `room` value used from here on carries a defaulted
        // `live_kit_room`. NOTE(review): confirm that is intended for the broadcasts below.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in any room; nothing to announce or clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The pool handle is confined to this block so it is dropped before
    // `update_user_contacts` runs, which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            // Users whose call was canceled also get their contacts refreshed below.
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup only happens when a client is configured in the app state.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4384
4385async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4386    let left_channel_buffers = session
4387        .db()
4388        .await
4389        .leave_channel_buffers(session.connection_id)
4390        .await?;
4391
4392    for left_buffer in left_channel_buffers {
4393        channel_buffer_updated(
4394            session.connection_id,
4395            left_buffer.connections,
4396            &proto::UpdateChannelBufferCollaborators {
4397                channel_id: left_buffer.channel_id.to_proto(),
4398                collaborators: left_buffer.collaborators,
4399            },
4400            &session.peer,
4401        );
4402    }
4403
4404    Ok(())
4405}
4406
4407fn project_left(project: &db::LeftProject, session: &Session) {
4408    for connection_id in &project.connection_ids {
4409        if project.should_unshare {
4410            session
4411                .peer
4412                .send(
4413                    *connection_id,
4414                    proto::UnshareProject {
4415                        project_id: project.id.to_proto(),
4416                    },
4417                )
4418                .trace_err();
4419        } else {
4420            session
4421                .peer
4422                .send(
4423                    *connection_id,
4424                    proto::RemoveProjectCollaborator {
4425                        project_id: project.id.to_proto(),
4426                        peer_id: Some(session.connection_id.into()),
4427                    },
4428                )
4429                .trace_err();
4430        }
4431    }
4432}
4433
/// Extension trait for converting a fallible value into an `Option`.
///
/// The implementation for `Result` in this file logs the error with `tracing::error!`
/// before discarding it.
pub trait ResultExt {
    /// The success type produced by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self` and returns its success value, if any, as an `Option`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4439
4440impl<T, E> ResultExt for Result<T, E>
4441where
4442    E: std::fmt::Debug,
4443{
4444    type Ok = T;
4445
4446    #[track_caller]
4447    fn trace_err(self) -> Option<T> {
4448        match self {
4449            Ok(value) => Some(value),
4450            Err(error) => {
4451                tracing::error!("{:?}", error);
4452                None
4453            }
4454        }
4455    }
4456}