rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    AppState, Config, Error, RateLimit, Result, auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14};
  15use anyhow::{Context as _, anyhow, bail};
  16use async_tungstenite::tungstenite::{
  17    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  18};
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use chrono::Utc;
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use http_client::HttpClient;
  37use reqwest_client::ReqwestClient;
  38use rpc::proto::split_repository_update;
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{MutexGuard, Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
  78pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  79
  80// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  81pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  82
  83const MESSAGE_COUNT_PER_PAGE: usize = 100;
  84const MAX_MESSAGE_LEN: usize = 1024;
  85const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  86
  87type MessageHandler =
  88    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
 104#[derive(Clone, Debug)]
 105pub enum Principal {
 106    User(User),
 107    Impersonated { user: User, admin: User },
 108}
 109
 110impl Principal {
 111    fn update_span(&self, span: &tracing::Span) {
 112        match &self {
 113            Principal::User(user) => {
 114                span.record("user_id", user.id.0);
 115                span.record("login", &user.github_login);
 116            }
 117            Principal::Impersonated { user, admin } => {
 118                span.record("user_id", user.id.0);
 119                span.record("login", &user.github_login);
 120                span.record("impersonator", &admin.github_login);
 121            }
 122        }
 123    }
 124}
 125
 126#[derive(Clone)]
 127struct Session {
 128    principal: Principal,
 129    connection_id: ConnectionId,
 130    db: Arc<tokio::sync::Mutex<DbHandle>>,
 131    peer: Arc<Peer>,
 132    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 133    app_state: Arc<AppState>,
 134    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 135    http_client: Arc<dyn HttpClient>,
 136    /// The GeoIP country code for the user.
 137    #[allow(unused)]
 138    geoip_country_code: Option<String>,
 139    system_id: Option<String>,
 140    _executor: Executor,
 141}
 142
 143impl Session {
 144    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 145        #[cfg(test)]
 146        tokio::task::yield_now().await;
 147        let guard = self.db.lock().await;
 148        #[cfg(test)]
 149        tokio::task::yield_now().await;
 150        guard
 151    }
 152
 153    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.connection_pool.lock();
 157        ConnectionPoolGuard {
 158            guard,
 159            _not_send: PhantomData,
 160        }
 161    }
 162
 163    fn is_staff(&self) -> bool {
 164        match &self.principal {
 165            Principal::User(user) => user.admin,
 166            Principal::Impersonated { .. } => true,
 167        }
 168    }
 169
 170    pub async fn has_llm_subscription(
 171        &self,
 172        db: &MutexGuard<'_, DbHandle>,
 173    ) -> anyhow::Result<bool> {
 174        if self.is_staff() {
 175            return Ok(true);
 176        }
 177
 178        let user_id = self.user_id();
 179
 180        Ok(db.has_active_billing_subscription(user_id).await?)
 181    }
 182
 183    pub async fn current_plan(
 184        &self,
 185        _db: &MutexGuard<'_, DbHandle>,
 186    ) -> anyhow::Result<proto::Plan> {
 187        if self.is_staff() {
 188            Ok(proto::Plan::ZedPro)
 189        } else {
 190            Ok(proto::Plan::Free)
 191        }
 192    }
 193
 194    fn user_id(&self) -> UserId {
 195        match &self.principal {
 196            Principal::User(user) => user.id,
 197            Principal::Impersonated { user, .. } => user.id,
 198        }
 199    }
 200
 201    pub fn email(&self) -> Option<String> {
 202        match &self.principal {
 203            Principal::User(user) => user.email_address.clone(),
 204            Principal::Impersonated { user, .. } => user.email_address.clone(),
 205        }
 206    }
 207}
 208
 209impl Debug for Session {
 210    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 211        let mut result = f.debug_struct("Session");
 212        match &self.principal {
 213            Principal::User(user) => {
 214                result.field("user", &user.github_login);
 215            }
 216            Principal::Impersonated { user, admin } => {
 217                result.field("user", &user.github_login);
 218                result.field("impersonator", &admin.github_login);
 219            }
 220        }
 221        result.field("connection_id", &self.connection_id).finish()
 222    }
 223}
 224
 225struct DbHandle(Arc<Database>);
 226
 227impl Deref for DbHandle {
 228    type Target = Database;
 229
 230    fn deref(&self) -> &Self::Target {
 231        self.0.as_ref()
 232    }
 233}
 234
 235pub struct Server {
 236    id: parking_lot::Mutex<ServerId>,
 237    peer: Arc<Peer>,
 238    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 239    app_state: Arc<AppState>,
 240    handlers: HashMap<TypeId, MessageHandler>,
 241    teardown: watch::Sender<bool>,
 242}
 243
 244pub(crate) struct ConnectionPoolGuard<'a> {
 245    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 246    _not_send: PhantomData<Rc<()>>,
 247}
 248
 249#[derive(Serialize)]
 250pub struct ServerSnapshot<'a> {
 251    peer: &'a Peer,
 252    #[serde(serialize_with = "serialize_deref")]
 253    connection_pool: ConnectionPoolGuard<'a>,
 254}
 255
 256pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 257where
 258    S: Serializer,
 259    T: Deref<Target = U>,
 260    U: Serialize,
 261{
 262    Serialize::serialize(value.deref(), serializer)
 263}
 264
 265impl Server {
 266    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 267        let mut server = Self {
 268            id: parking_lot::Mutex::new(id),
 269            peer: Peer::new(id.0 as u32),
 270            app_state: app_state.clone(),
 271            connection_pool: Default::default(),
 272            handlers: Default::default(),
 273            teardown: watch::channel(false).0,
 274        };
 275
 276        server
 277            .add_request_handler(ping)
 278            .add_request_handler(create_room)
 279            .add_request_handler(join_room)
 280            .add_request_handler(rejoin_room)
 281            .add_request_handler(leave_room)
 282            .add_request_handler(set_room_participant_role)
 283            .add_request_handler(call)
 284            .add_request_handler(cancel_call)
 285            .add_message_handler(decline_call)
 286            .add_request_handler(update_participant_location)
 287            .add_request_handler(share_project)
 288            .add_message_handler(unshare_project)
 289            .add_request_handler(join_project)
 290            .add_message_handler(leave_project)
 291            .add_request_handler(update_project)
 292            .add_request_handler(update_worktree)
 293            .add_request_handler(update_repository)
 294            .add_request_handler(remove_repository)
 295            .add_message_handler(start_language_server)
 296            .add_message_handler(update_language_server)
 297            .add_message_handler(update_diagnostic_summary)
 298            .add_message_handler(update_worktree_settings)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 302            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 303            .add_request_handler(forward_find_search_candidates_request)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 305            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 306            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 307            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 308            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 309            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 310            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 311            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 312            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 315            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 316            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 317            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 318            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 319            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 320            .add_request_handler(
 321                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 322            )
 323            .add_request_handler(
 324                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 325            )
 326            .add_request_handler(
 327                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 328            )
 329            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 330            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 331            .add_request_handler(
 332                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 333            )
 334            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 335            .add_request_handler(
 336                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 337            )
 338            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 339            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 340            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 341            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 342            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 343            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 344            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 345            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 346            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 347            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 348            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 349            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 350            .add_request_handler(
 351                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 352            )
 353            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 354            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 355            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 356            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 357            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 358            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 359            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 360            .add_message_handler(create_buffer_for_peer)
 361            .add_request_handler(update_buffer)
 362            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 363            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 364            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 365            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 366            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 367            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 368            .add_request_handler(get_users)
 369            .add_request_handler(fuzzy_search_users)
 370            .add_request_handler(request_contact)
 371            .add_request_handler(remove_contact)
 372            .add_request_handler(respond_to_contact_request)
 373            .add_message_handler(subscribe_to_channels)
 374            .add_request_handler(create_channel)
 375            .add_request_handler(delete_channel)
 376            .add_request_handler(invite_channel_member)
 377            .add_request_handler(remove_channel_member)
 378            .add_request_handler(set_channel_member_role)
 379            .add_request_handler(set_channel_visibility)
 380            .add_request_handler(rename_channel)
 381            .add_request_handler(join_channel_buffer)
 382            .add_request_handler(leave_channel_buffer)
 383            .add_message_handler(update_channel_buffer)
 384            .add_request_handler(rejoin_channel_buffers)
 385            .add_request_handler(get_channel_members)
 386            .add_request_handler(respond_to_channel_invite)
 387            .add_request_handler(join_channel)
 388            .add_request_handler(join_channel_chat)
 389            .add_message_handler(leave_channel_chat)
 390            .add_request_handler(send_channel_message)
 391            .add_request_handler(remove_channel_message)
 392            .add_request_handler(update_channel_message)
 393            .add_request_handler(get_channel_messages)
 394            .add_request_handler(get_channel_messages_by_id)
 395            .add_request_handler(get_notifications)
 396            .add_request_handler(mark_notification_as_read)
 397            .add_request_handler(move_channel)
 398            .add_request_handler(follow)
 399            .add_message_handler(unfollow)
 400            .add_message_handler(update_followers)
 401            .add_request_handler(get_private_user_info)
 402            .add_request_handler(get_llm_api_token)
 403            .add_request_handler(accept_terms_of_service)
 404            .add_message_handler(acknowledge_channel_message)
 405            .add_message_handler(acknowledge_buffer_version)
 406            .add_request_handler(get_supermaven_api_key)
 407            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 409            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 410            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 411            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 412            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 413            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 414            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 415            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 416            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 417            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 418            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 419            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 420            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 421            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 422            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 423            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 424            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 425            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 426            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 427            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 428            .add_message_handler(update_context)
 429            .add_request_handler({
 430                let app_state = app_state.clone();
 431                move |request, response, session| {
 432                    let app_state = app_state.clone();
 433                    async move {
 434                        count_language_model_tokens(request, response, session, &app_state.config)
 435                            .await
 436                    }
 437                }
 438            });
 439
 440        Arc::new(server)
 441    }
 442
    /// Schedules a one-shot background cleanup of resources left behind by earlier
    /// server instances, then returns immediately.
    ///
    /// The detached task waits for [`CLEANUP_TIMEOUT`]. It then asks the database
    /// for the stale room ids and channel-buffer ids for the configured
    /// `zed_environment` and this server id, and does the following:
    /// - For each stale channel buffer, clears stale collaborators and sends the
    ///   refreshed collaborator list to every connection the database returns.
    /// - For each stale room, clears stale participants, broadcasts the refreshed
    ///   room (and its channel, if any), and sends `CallCanceled` to users whose calls
    ///   were canceled. It also pushes refreshed contact info for affected users to
    ///   their accepted contacts, and deletes the LiveKit room when no participants
    ///   remain and a LiveKit client is configured.
    /// - Finally, deletes stale server records. This runs even if fetching the
    ///   stale ids failed.
    ///
    /// Individual failures go through `.trace_err()` and are skipped rather than
    /// aborting the cleanup. This function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Snapshot everything the detached task needs so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Each remaining connection gets the full, refreshed
                            // collaborator list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults for when clearing the room fails: nobody to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees need
                            // their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees of canceled calls. The pool lock is scoped
                        // to this block, so it is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their
                        // accepted contacts. The pool is locked only after both
                        // database awaits have completed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when a client is configured and
                        // clearing stale participants left the room empty.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Outside the `if let` above: runs even when fetching stale ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 588
 589    pub fn teardown(&self) {
 590        self.peer.teardown();
 591        self.connection_pool.lock().reset();
 592        let _ = self.teardown.send(true);
 593    }
 594
 595    #[cfg(test)]
 596    pub fn reset(&self, id: ServerId) {
 597        self.teardown();
 598        *self.id.lock() = id;
 599        self.peer.reset(id.0 as u32);
 600        let _ = self.teardown.send(false);
 601    }
 602
 603    #[cfg(test)]
 604    pub fn id(&self) -> ServerId {
 605        *self.id.lock()
 606    }
 607
 608    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 609    where
 610        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 611        Fut: 'static + Send + Future<Output = Result<()>>,
 612        M: EnvelopedMessage,
 613    {
 614        let prev_handler = self.handlers.insert(
 615            TypeId::of::<M>(),
 616            Box::new(move |envelope, session| {
 617                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 618                let received_at = envelope.received_at;
 619                tracing::info!("message received");
 620                let start_time = Instant::now();
 621                let future = (handler)(*envelope, session);
 622                async move {
 623                    let result = future.await;
 624                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 625                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 626                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 627                    let payload_type = M::NAME;
 628
 629                    match result {
 630                        Err(error) => {
 631                            tracing::error!(
 632                                ?error,
 633                                total_duration_ms,
 634                                processing_duration_ms,
 635                                queue_duration_ms,
 636                                payload_type,
 637                                "error handling message"
 638                            )
 639                        }
 640                        Ok(()) => tracing::info!(
 641                            total_duration_ms,
 642                            processing_duration_ms,
 643                            queue_duration_ms,
 644                            "finished handling message"
 645                        ),
 646                    }
 647                }
 648                .boxed()
 649            }),
 650        );
 651        if prev_handler.is_some() {
 652            panic!("registered a handler for the same message twice");
 653        }
 654        self
 655    }
 656
 657    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 658    where
 659        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 660        Fut: 'static + Send + Future<Output = Result<()>>,
 661        M: EnvelopedMessage,
 662    {
 663        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 664        self
 665    }
 666
 667    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 668    where
 669        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 670        Fut: Send + Future<Output = Result<()>>,
 671        M: RequestMessage,
 672    {
 673        let handler = Arc::new(handler);
 674        self.add_handler(move |envelope, session| {
 675            let receipt = envelope.receipt();
 676            let handler = handler.clone();
 677            async move {
 678                let peer = session.peer.clone();
 679                let responded = Arc::new(AtomicBool::default());
 680                let response = Response {
 681                    peer: peer.clone(),
 682                    responded: responded.clone(),
 683                    receipt,
 684                };
 685                match (handler)(envelope.payload, response, session).await {
 686                    Ok(()) => {
 687                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 688                            Ok(())
 689                        } else {
 690                            Err(anyhow!("handler did not send a response"))?
 691                        }
 692                    }
 693                    Err(error) => {
 694                        let proto_err = match &error {
 695                            Error::Internal(err) => err.to_proto(),
 696                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 697                        };
 698                        peer.respond_with_error(receipt, proto_err)?;
 699                        Err(error)
 700                    }
 701                }
 702            }
 703        })
 704    }
 705
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future does the following:
    /// 1. Registers `connection` with the peer and builds the per-connection [`Session`].
    /// 2. Sends the initial client update (see `send_initial_client_update`).
    /// 3. Dispatches incoming messages to the registered handlers.
    ///
    /// Dispatch stops when the connection's I/O future completes or the incoming stream
    /// ends; in both cases `connection_lost` then runs. If the server starts tearing down,
    /// the future returns immediately without signing out.
    ///
    /// * `address` – peer address, recorded on the tracing spans.
    /// * `geoip_country_code` / `system_id` – optional client metadata stored on the session.
    /// * `send_connection_id` – if provided, receives the assigned [`ConnectionId`] once the
    ///   `Hello` message has been sent.
    /// * `executor` – used for the peer's timers and to spawn background message handlers.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        // Clone the Arc so the returned future owns the server instead of borrowing `self`
        // (`+ use<>` captures no lifetimes).
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed before the future starts so a teardown signalled in the meantime is seen.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): the connection was already registered with the peer
                    // above; returning here skips `connection_lost`/`peer.disconnect`.
                    return;
                }
            };

            // Only created when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): `send_initial_client_update` adds the connection to the pool
                // before several later fallible steps; returning here skips
                // `connection_lost`, so such an entry would never be removed. Confirm intent.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) for this
            // connection: a permit is taken before reading each message and released when
            // its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, then I/O, then
                // progress on in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Entered only while the handler future is constructed; the
                            // future itself runs under the span via `.instrument(span)`.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached; foreground ones are polled
                                // by this loop (see the comment above).
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight; detached
            // background handlers are unaffected.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 854
 855    async fn send_initial_client_update(
 856        &self,
 857        connection_id: ConnectionId,
 858        principal: &Principal,
 859        zed_version: ZedVersion,
 860        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 861        session: &Session,
 862    ) -> Result<()> {
 863        self.peer.send(
 864            connection_id,
 865            proto::Hello {
 866                peer_id: Some(connection_id.into()),
 867            },
 868        )?;
 869        tracing::info!("sent hello message");
 870        if let Some(send_connection_id) = send_connection_id.take() {
 871            let _ = send_connection_id.send(connection_id);
 872        }
 873
 874        match principal {
 875            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 876                if !user.connected_once {
 877                    self.peer.send(connection_id, proto::ShowContacts {})?;
 878                    self.app_state
 879                        .db
 880                        .set_user_connected_once(user.id, true)
 881                        .await?;
 882                }
 883
 884                update_user_plan(user.id, session).await?;
 885
 886                let contacts = self.app_state.db.get_contacts(user.id).await?;
 887
 888                {
 889                    let mut pool = self.connection_pool.lock();
 890                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 891                    self.peer.send(
 892                        connection_id,
 893                        build_initial_contacts_update(contacts, &pool),
 894                    )?;
 895                }
 896
 897                if should_auto_subscribe_to_channels(zed_version) {
 898                    subscribe_user_to_channels(user.id, session).await?;
 899                }
 900
 901                if let Some(incoming_call) =
 902                    self.app_state.db.incoming_call_for_user(user.id).await?
 903                {
 904                    self.peer.send(connection_id, incoming_call)?;
 905                }
 906
 907                update_user_contacts(user.id, session).await?;
 908            }
 909        }
 910
 911        Ok(())
 912    }
 913
 914    pub async fn invite_code_redeemed(
 915        self: &Arc<Self>,
 916        inviter_id: UserId,
 917        invitee_id: UserId,
 918    ) -> Result<()> {
 919        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 920            if let Some(code) = &user.invite_code {
 921                let pool = self.connection_pool.lock();
 922                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 923                for connection_id in pool.user_connection_ids(inviter_id) {
 924                    self.peer.send(
 925                        connection_id,
 926                        proto::UpdateContacts {
 927                            contacts: vec![invitee_contact.clone()],
 928                            ..Default::default()
 929                        },
 930                    )?;
 931                    self.peer.send(
 932                        connection_id,
 933                        proto::UpdateInviteInfo {
 934                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 935                            count: user.invite_count as u32,
 936                        },
 937                    )?;
 938                }
 939            }
 940        }
 941        Ok(())
 942    }
 943
 944    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 945        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 946            if let Some(invite_code) = &user.invite_code {
 947                let pool = self.connection_pool.lock();
 948                for connection_id in pool.user_connection_ids(user_id) {
 949                    self.peer.send(
 950                        connection_id,
 951                        proto::UpdateInviteInfo {
 952                            url: format!(
 953                                "{}{}",
 954                                self.app_state.config.invite_link_prefix, invite_code
 955                            ),
 956                            count: user.invite_count as u32,
 957                        },
 958                    )?;
 959                }
 960            }
 961        }
 962        Ok(())
 963    }
 964
 965    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 966        let pool = self.connection_pool.lock();
 967        for connection_id in pool.user_connection_ids(user_id) {
 968            self.peer
 969                .send(connection_id, proto::RefreshLlmToken {})
 970                .trace_err();
 971        }
 972    }
 973
 974    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
 975        ServerSnapshot {
 976            connection_pool: ConnectionPoolGuard {
 977                guard: self.connection_pool.lock(),
 978                _not_send: PhantomData,
 979            },
 980            peer: &self.peer,
 981        }
 982    }
 983}
 984
 985impl Deref for ConnectionPoolGuard<'_> {
 986    type Target = ConnectionPool;
 987
 988    fn deref(&self) -> &Self::Target {
 989        &self.guard
 990    }
 991}
 992
 993impl DerefMut for ConnectionPoolGuard<'_> {
 994    fn deref_mut(&mut self) -> &mut Self::Target {
 995        &mut self.guard
 996    }
 997}
 998
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, runs the pool's `check_invariants` each time a guard is released.
    /// Presumably it panics on violation — see `ConnectionPool`. In non-test builds this
    /// does nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1005
1006fn broadcast<F>(
1007    sender_id: Option<ConnectionId>,
1008    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1009    mut f: F,
1010) where
1011    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1012{
1013    for receiver_id in receiver_ids {
1014        if Some(receiver_id) != sender_id {
1015            if let Err(error) = f(receiver_id) {
1016                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1017            }
1018        }
1019    }
1020}
1021
1022pub struct ProtocolVersion(u32);
1023
1024impl Header for ProtocolVersion {
1025    fn name() -> &'static HeaderName {
1026        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1027        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1028    }
1029
1030    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1031    where
1032        Self: Sized,
1033        I: Iterator<Item = &'i axum::http::HeaderValue>,
1034    {
1035        let version = values
1036            .next()
1037            .ok_or_else(axum::headers::Error::invalid)?
1038            .to_str()
1039            .map_err(|_| axum::headers::Error::invalid())?
1040            .parse()
1041            .map_err(|_| axum::headers::Error::invalid())?;
1042        Ok(Self(version))
1043    }
1044
1045    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1046        values.extend([self.0.to_string().parse().unwrap()]);
1047    }
1048}
1049
1050pub struct AppVersionHeader(SemanticVersion);
1051impl Header for AppVersionHeader {
1052    fn name() -> &'static HeaderName {
1053        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1054        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1055    }
1056
1057    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1058    where
1059        Self: Sized,
1060        I: Iterator<Item = &'i axum::http::HeaderValue>,
1061    {
1062        let version = values
1063            .next()
1064            .ok_or_else(axum::headers::Error::invalid)?
1065            .to_str()
1066            .map_err(|_| axum::headers::Error::invalid())?
1067            .parse()
1068            .map_err(|_| axum::headers::Error::invalid())?;
1069        Ok(Self(version))
1070    }
1071
1072    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1073        values.extend([self.0.to_string().parse().unwrap()]);
1074    }
1075}
1076
1077pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1078    Router::new()
1079        .route("/rpc", get(handle_websocket_request))
1080        .layer(
1081            ServiceBuilder::new()
1082                .layer(Extension(server.app_state.clone()))
1083                .layer(middleware::from_fn(auth::validate_header)),
1084        )
1085        .route("/metrics", get(handle_metrics))
1086        .layer(Extension(server))
1087}
1088
1089pub async fn handle_websocket_request(
1090    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1091    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1092    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1093    Extension(server): Extension<Arc<Server>>,
1094    Extension(principal): Extension<Principal>,
1095    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1096    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1097    ws: WebSocketUpgrade,
1098) -> axum::response::Response {
1099    if protocol_version != rpc::PROTOCOL_VERSION {
1100        return (
1101            StatusCode::UPGRADE_REQUIRED,
1102            "client must be upgraded".to_string(),
1103        )
1104            .into_response();
1105    }
1106
1107    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1108        return (
1109            StatusCode::UPGRADE_REQUIRED,
1110            "no version header found".to_string(),
1111        )
1112            .into_response();
1113    };
1114
1115    if !version.can_collaborate() {
1116        return (
1117            StatusCode::UPGRADE_REQUIRED,
1118            "client must be upgraded".to_string(),
1119        )
1120            .into_response();
1121    }
1122
1123    let socket_address = socket_address.to_string();
1124    ws.on_upgrade(move |socket| {
1125        let socket = socket
1126            .map_ok(to_tungstenite_message)
1127            .err_into()
1128            .with(|message| async move { to_axum_message(message) });
1129        let connection = Connection::new(Box::pin(socket));
1130        async move {
1131            server
1132                .handle_connection(
1133                    connection,
1134                    socket_address,
1135                    principal,
1136                    version,
1137                    country_code_header.map(|header| header.to_string()),
1138                    system_id_header.map(|header| header.to_string()),
1139                    None,
1140                    Executor::Production,
1141                )
1142                .await;
1143        }
1144    })
1145}
1146
1147pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1148    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1149    let connections_metric = CONNECTIONS_METRIC
1150        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1151
1152    let connections = server
1153        .connection_pool
1154        .lock()
1155        .connections()
1156        .filter(|connection| !connection.admin)
1157        .count();
1158    connections_metric.set(connections as _);
1159
1160    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1161    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1162        register_int_gauge!(
1163            "shared_projects",
1164            "number of open projects with one or more guests"
1165        )
1166        .unwrap()
1167    });
1168
1169    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1170    shared_projects_metric.set(shared_projects as _);
1171
1172    let encoder = prometheus::TextEncoder::new();
1173    let metric_families = prometheus::gather();
1174    let encoded_metrics = encoder
1175        .encode_to_string(&metric_families)
1176        .map_err(|err| anyhow!("{}", err))?;
1177    Ok(encoded_metrics)
1178}
1179
/// Signs a connection out after its message loop has ended.
///
/// The connection is disconnected from the peer and removed from the pool
/// immediately. Removal failure aborts with an error. The database is told the
/// connection was lost; a failure there is only traced.
///
/// The function then waits for `RECONNECT_TIMEOUT` unless the server starts tearing
/// down first, in which case it returns without further cleanup. Once the timeout
/// elapses, it:
/// * leaves the session's room and channel buffers (errors traced);
/// * declines any pending call if the pool reports the user offline (presumably no
///   other connection remains), broadcasting the affected room;
/// * updates the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls in order: the reconnect timer first, then teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the pending call if the user has no remaining presence.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip resource cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1226
1227/// Acknowledges a ping from a client, used to keep the connection alive.
1228async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1229    response.send(proto::Ack {})?;
1230    Ok(())
1231}
1232
1233/// Creates a new room for calling (outside of channels)
1234async fn create_room(
1235    _request: proto::CreateRoom,
1236    response: Response<proto::CreateRoom>,
1237    session: Session,
1238) -> Result<()> {
1239    let livekit_room = nanoid::nanoid!(30);
1240
1241    let live_kit_connection_info = util::maybe!(async {
1242        let live_kit = session.app_state.livekit_client.as_ref();
1243        let live_kit = live_kit?;
1244        let user_id = session.user_id().to_string();
1245
1246        let token = live_kit
1247            .room_token(&livekit_room, &user_id.to_string())
1248            .trace_err()?;
1249
1250        Some(proto::LiveKitConnectionInfo {
1251            server_url: live_kit.url().into(),
1252            token,
1253            can_publish: true,
1254        })
1255    })
1256    .await;
1257
1258    let room = session
1259        .db()
1260        .await
1261        .create_room(session.user_id(), session.connection_id, &livekit_room)
1262        .await?;
1263
1264    response.send(proto::CreateRoomResponse {
1265        room: Some(room.clone()),
1266        live_kit_connection_info,
1267    })?;
1268
1269    update_user_contacts(session.user_id(), &session).await?;
1270    Ok(())
1271}
1272
1273/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1274async fn join_room(
1275    request: proto::JoinRoom,
1276    response: Response<proto::JoinRoom>,
1277    session: Session,
1278) -> Result<()> {
1279    let room_id = RoomId::from_proto(request.id);
1280
1281    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1282
1283    if let Some(channel_id) = channel_id {
1284        return join_channel_internal(channel_id, Box::new(response), session).await;
1285    }
1286
1287    let joined_room = {
1288        let room = session
1289            .db()
1290            .await
1291            .join_room(room_id, session.user_id(), session.connection_id)
1292            .await?;
1293        room_updated(&room.room, &session.peer);
1294        room.into_inner()
1295    };
1296
1297    for connection_id in session
1298        .connection_pool()
1299        .await
1300        .user_connection_ids(session.user_id())
1301    {
1302        session
1303            .peer
1304            .send(
1305                connection_id,
1306                proto::CallCanceled {
1307                    room_id: room_id.to_proto(),
1308                },
1309            )
1310            .trace_err();
1311    }
1312
1313    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1314    {
1315        live_kit
1316            .room_token(
1317                &joined_room.room.livekit_room,
1318                &session.user_id().to_string(),
1319            )
1320            .trace_err()
1321            .map(|token| proto::LiveKitConnectionInfo {
1322                server_url: live_kit.url().into(),
1323                token,
1324                can_publish: true,
1325            })
1326    } else {
1327        None
1328    };
1329
1330    response.send(proto::JoinRoomResponse {
1331        room: Some(joined_room.room),
1332        channel_id: None,
1333        live_kit_connection_info,
1334    })?;
1335
1336    update_user_contacts(session.user_id(), &session).await?;
1337    Ok(())
1338}
1339
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// First the connection is re-admitted to the room in the database. The reply carries
/// the room, the caller's re-shared projects (with their collaborators) and its
/// re-joined projects. After replying:
/// * the room is broadcast;
/// * collaborators of each re-shared project learn this connection's new peer id and
///   are forwarded the project's worktrees;
/// * state for re-joined projects is streamed back via `notify_rejoined_projects`.
///
/// If the room belongs to a channel, the channel is updated as well. Finally the
/// caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so they outlive the consumed `rejoined_room` value.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator that this connection's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the database wrapper before the channel update below.
        // NOTE(review): presumably this releases a room lock held by the wrapper — confirm.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1431
1432fn notify_rejoined_projects(
1433    rejoined_projects: &mut Vec<RejoinedProject>,
1434    session: &Session,
1435) -> Result<()> {
1436    for project in rejoined_projects.iter() {
1437        for collaborator in &project.collaborators {
1438            session
1439                .peer
1440                .send(
1441                    collaborator.connection_id,
1442                    proto::UpdateProjectCollaborator {
1443                        project_id: project.id.to_proto(),
1444                        old_peer_id: Some(project.old_connection_id.into()),
1445                        new_peer_id: Some(session.connection_id.into()),
1446                    },
1447                )
1448                .trace_err();
1449        }
1450    }
1451
1452    for project in rejoined_projects {
1453        for worktree in mem::take(&mut project.worktrees) {
1454            // Stream this worktree's entries.
1455            let message = proto::UpdateWorktree {
1456                project_id: project.id.to_proto(),
1457                worktree_id: worktree.id,
1458                abs_path: worktree.abs_path.clone(),
1459                root_name: worktree.root_name,
1460                updated_entries: worktree.updated_entries,
1461                removed_entries: worktree.removed_entries,
1462                scan_id: worktree.scan_id,
1463                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1464                updated_repositories: worktree.updated_repositories,
1465                removed_repositories: worktree.removed_repositories,
1466            };
1467            for update in proto::split_worktree_update(message) {
1468                session.peer.send(session.connection_id, update)?;
1469            }
1470
1471            // Stream this worktree's diagnostics.
1472            for summary in worktree.diagnostic_summaries {
1473                session.peer.send(
1474                    session.connection_id,
1475                    proto::UpdateDiagnosticSummary {
1476                        project_id: project.id.to_proto(),
1477                        worktree_id: worktree.id,
1478                        summary: Some(summary),
1479                    },
1480                )?;
1481            }
1482
1483            for settings_file in worktree.settings_files {
1484                session.peer.send(
1485                    session.connection_id,
1486                    proto::UpdateWorktreeSettings {
1487                        project_id: project.id.to_proto(),
1488                        worktree_id: worktree.id,
1489                        path: settings_file.path,
1490                        content: Some(settings_file.content),
1491                        kind: Some(settings_file.kind.to_proto().into()),
1492                    },
1493                )?;
1494            }
1495        }
1496
1497        for repository in mem::take(&mut project.updated_repositories) {
1498            for update in split_repository_update(repository) {
1499                session.peer.send(session.connection_id, update)?;
1500            }
1501        }
1502
1503        for id in mem::take(&mut project.removed_repositories) {
1504            session.peer.send(
1505                session.connection_id,
1506                proto::RemoveRepository {
1507                    project_id: project.id.to_proto(),
1508                    id,
1509                },
1510            )?;
1511        }
1512    }
1513
1514    Ok(())
1515}
1516
1517/// leave room disconnects from the room.
1518async fn leave_room(
1519    _: proto::LeaveRoom,
1520    response: Response<proto::LeaveRoom>,
1521    session: Session,
1522) -> Result<()> {
1523    leave_room_for_session(&session, session.connection_id).await?;
1524    response.send(proto::Ack {})?;
1525    Ok(())
1526}
1527
1528/// Updates the permissions of someone else in the room.
1529async fn set_room_participant_role(
1530    request: proto::SetRoomParticipantRole,
1531    response: Response<proto::SetRoomParticipantRole>,
1532    session: Session,
1533) -> Result<()> {
1534    let user_id = UserId::from_proto(request.user_id);
1535    let role = ChannelRole::from(request.role());
1536
1537    let (livekit_room, can_publish) = {
1538        let room = session
1539            .db()
1540            .await
1541            .set_room_participant_role(
1542                session.user_id(),
1543                RoomId::from_proto(request.room_id),
1544                user_id,
1545                role,
1546            )
1547            .await?;
1548
1549        let livekit_room = room.livekit_room.clone();
1550        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1551        room_updated(&room, &session.peer);
1552        (livekit_room, can_publish)
1553    };
1554
1555    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1556        live_kit
1557            .update_participant(
1558                livekit_room.clone(),
1559                request.user_id.to_string(),
1560                livekit_api::proto::ParticipantPermission {
1561                    can_subscribe: true,
1562                    can_publish,
1563                    can_publish_data: can_publish,
1564                    hidden: false,
1565                    recorder: false,
1566                },
1567            )
1568            .await
1569            .trace_err();
1570    }
1571
1572    response.send(proto::Ack {})?;
1573    Ok(())
1574}
1575
1576/// Call someone else into the current room
1577async fn call(
1578    request: proto::Call,
1579    response: Response<proto::Call>,
1580    session: Session,
1581) -> Result<()> {
1582    let room_id = RoomId::from_proto(request.room_id);
1583    let calling_user_id = session.user_id();
1584    let calling_connection_id = session.connection_id;
1585    let called_user_id = UserId::from_proto(request.called_user_id);
1586    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1587    if !session
1588        .db()
1589        .await
1590        .has_contact(calling_user_id, called_user_id)
1591        .await?
1592    {
1593        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1594    }
1595
1596    let incoming_call = {
1597        let (room, incoming_call) = &mut *session
1598            .db()
1599            .await
1600            .call(
1601                room_id,
1602                calling_user_id,
1603                calling_connection_id,
1604                called_user_id,
1605                initial_project_id,
1606            )
1607            .await?;
1608        room_updated(room, &session.peer);
1609        mem::take(incoming_call)
1610    };
1611    update_user_contacts(called_user_id, &session).await?;
1612
1613    let mut calls = session
1614        .connection_pool()
1615        .await
1616        .user_connection_ids(called_user_id)
1617        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1618        .collect::<FuturesUnordered<_>>();
1619
1620    while let Some(call_response) = calls.next().await {
1621        match call_response.as_ref() {
1622            Ok(_) => {
1623                response.send(proto::Ack {})?;
1624                return Ok(());
1625            }
1626            Err(_) => {
1627                call_response.trace_err();
1628            }
1629        }
1630    }
1631
1632    {
1633        let room = session
1634            .db()
1635            .await
1636            .call_failed(room_id, called_user_id)
1637            .await?;
1638        room_updated(&room, &session.peer);
1639    }
1640    update_user_contacts(called_user_id, &session).await?;
1641
1642    Err(anyhow!("failed to ring user"))?
1643}
1644
1645/// Cancel an outgoing call.
1646async fn cancel_call(
1647    request: proto::CancelCall,
1648    response: Response<proto::CancelCall>,
1649    session: Session,
1650) -> Result<()> {
1651    let called_user_id = UserId::from_proto(request.called_user_id);
1652    let room_id = RoomId::from_proto(request.room_id);
1653    {
1654        let room = session
1655            .db()
1656            .await
1657            .cancel_call(room_id, session.connection_id, called_user_id)
1658            .await?;
1659        room_updated(&room, &session.peer);
1660    }
1661
1662    for connection_id in session
1663        .connection_pool()
1664        .await
1665        .user_connection_ids(called_user_id)
1666    {
1667        session
1668            .peer
1669            .send(
1670                connection_id,
1671                proto::CallCanceled {
1672                    room_id: room_id.to_proto(),
1673                },
1674            )
1675            .trace_err();
1676    }
1677    response.send(proto::Ack {})?;
1678
1679    update_user_contacts(called_user_id, &session).await?;
1680    Ok(())
1681}
1682
1683/// Decline an incoming call.
1684async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1685    let room_id = RoomId::from_proto(message.room_id);
1686    {
1687        let room = session
1688            .db()
1689            .await
1690            .decline_call(Some(room_id), session.user_id())
1691            .await?
1692            .ok_or_else(|| anyhow!("failed to decline call"))?;
1693        room_updated(&room, &session.peer);
1694    }
1695
1696    for connection_id in session
1697        .connection_pool()
1698        .await
1699        .user_connection_ids(session.user_id())
1700    {
1701        session
1702            .peer
1703            .send(
1704                connection_id,
1705                proto::CallCanceled {
1706                    room_id: room_id.to_proto(),
1707                },
1708            )
1709            .trace_err();
1710    }
1711    update_user_contacts(session.user_id(), &session).await?;
1712    Ok(())
1713}
1714
1715/// Updates other participants in the room with your current location.
1716async fn update_participant_location(
1717    request: proto::UpdateParticipantLocation,
1718    response: Response<proto::UpdateParticipantLocation>,
1719    session: Session,
1720) -> Result<()> {
1721    let room_id = RoomId::from_proto(request.room_id);
1722    let location = request
1723        .location
1724        .ok_or_else(|| anyhow!("invalid location"))?;
1725
1726    let db = session.db().await;
1727    let room = db
1728        .update_room_participant_location(room_id, session.connection_id, location)
1729        .await?;
1730
1731    room_updated(&room, &session.peer);
1732    response.send(proto::Ack {})?;
1733    Ok(())
1734}
1735
1736/// Share a project into the room.
1737async fn share_project(
1738    request: proto::ShareProject,
1739    response: Response<proto::ShareProject>,
1740    session: Session,
1741) -> Result<()> {
1742    let (project_id, room) = &*session
1743        .db()
1744        .await
1745        .share_project(
1746            RoomId::from_proto(request.room_id),
1747            session.connection_id,
1748            &request.worktrees,
1749            request.is_ssh_project,
1750        )
1751        .await?;
1752    response.send(proto::ShareProjectResponse {
1753        project_id: project_id.to_proto(),
1754    })?;
1755    room_updated(room, &session.peer);
1756
1757    Ok(())
1758}
1759
1760/// Unshare a project from the room.
1761async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1762    let project_id = ProjectId::from_proto(message.project_id);
1763    unshare_project_internal(project_id, session.connection_id, &session).await
1764}
1765
1766async fn unshare_project_internal(
1767    project_id: ProjectId,
1768    connection_id: ConnectionId,
1769    session: &Session,
1770) -> Result<()> {
1771    let delete = {
1772        let room_guard = session
1773            .db()
1774            .await
1775            .unshare_project(project_id, connection_id)
1776            .await?;
1777
1778        let (delete, room, guest_connection_ids) = &*room_guard;
1779
1780        let message = proto::UnshareProject {
1781            project_id: project_id.to_proto(),
1782        };
1783
1784        broadcast(
1785            Some(connection_id),
1786            guest_connection_ids.iter().copied(),
1787            |conn_id| session.peer.send(conn_id, message.clone()),
1788        );
1789        if let Some(room) = room {
1790            room_updated(room, &session.peer);
1791        }
1792
1793        *delete
1794    };
1795
1796    if delete {
1797        let db = session.db().await;
1798        db.delete_project(project_id).await?;
1799    }
1800
1801    Ok(())
1802}
1803
1804/// Join someone elses shared project.
1805async fn join_project(
1806    request: proto::JoinProject,
1807    response: Response<proto::JoinProject>,
1808    session: Session,
1809) -> Result<()> {
1810    let project_id = ProjectId::from_proto(request.project_id);
1811
1812    tracing::info!(%project_id, "join project");
1813
1814    let db = session.db().await;
1815    let (project, replica_id) = &mut *db
1816        .join_project(project_id, session.connection_id, session.user_id())
1817        .await?;
1818    drop(db);
1819    tracing::info!(%project_id, "join remote project");
1820    join_project_internal(response, session, project, replica_id)
1821}
1822
1823trait JoinProjectInternalResponse {
1824    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1825}
1826impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1827    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1828        Response::<proto::JoinProject>::send(self, result)
1829    }
1830}
1831
1832fn join_project_internal(
1833    response: impl JoinProjectInternalResponse,
1834    session: Session,
1835    project: &mut Project,
1836    replica_id: &ReplicaId,
1837) -> Result<()> {
1838    let collaborators = project
1839        .collaborators
1840        .iter()
1841        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1842        .map(|collaborator| collaborator.to_proto())
1843        .collect::<Vec<_>>();
1844    let project_id = project.id;
1845    let guest_user_id = session.user_id();
1846
1847    let worktrees = project
1848        .worktrees
1849        .iter()
1850        .map(|(id, worktree)| proto::WorktreeMetadata {
1851            id: *id,
1852            root_name: worktree.root_name.clone(),
1853            visible: worktree.visible,
1854            abs_path: worktree.abs_path.clone(),
1855        })
1856        .collect::<Vec<_>>();
1857
1858    let add_project_collaborator = proto::AddProjectCollaborator {
1859        project_id: project_id.to_proto(),
1860        collaborator: Some(proto::Collaborator {
1861            peer_id: Some(session.connection_id.into()),
1862            replica_id: replica_id.0 as u32,
1863            user_id: guest_user_id.to_proto(),
1864            is_host: false,
1865        }),
1866    };
1867
1868    for collaborator in &collaborators {
1869        session
1870            .peer
1871            .send(
1872                collaborator.peer_id.unwrap().into(),
1873                add_project_collaborator.clone(),
1874            )
1875            .trace_err();
1876    }
1877
1878    // First, we send the metadata associated with each worktree.
1879    response.send(proto::JoinProjectResponse {
1880        project_id: project.id.0 as u64,
1881        worktrees: worktrees.clone(),
1882        replica_id: replica_id.0 as u32,
1883        collaborators: collaborators.clone(),
1884        language_servers: project.language_servers.clone(),
1885        role: project.role.into(),
1886    })?;
1887
1888    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1889        // Stream this worktree's entries.
1890        let message = proto::UpdateWorktree {
1891            project_id: project_id.to_proto(),
1892            worktree_id,
1893            abs_path: worktree.abs_path.clone(),
1894            root_name: worktree.root_name,
1895            updated_entries: worktree.entries,
1896            removed_entries: Default::default(),
1897            scan_id: worktree.scan_id,
1898            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1899            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1900            removed_repositories: Default::default(),
1901        };
1902        for update in proto::split_worktree_update(message) {
1903            session.peer.send(session.connection_id, update.clone())?;
1904        }
1905
1906        // Stream this worktree's diagnostics.
1907        for summary in worktree.diagnostic_summaries {
1908            session.peer.send(
1909                session.connection_id,
1910                proto::UpdateDiagnosticSummary {
1911                    project_id: project_id.to_proto(),
1912                    worktree_id: worktree.id,
1913                    summary: Some(summary),
1914                },
1915            )?;
1916        }
1917
1918        for settings_file in worktree.settings_files {
1919            session.peer.send(
1920                session.connection_id,
1921                proto::UpdateWorktreeSettings {
1922                    project_id: project_id.to_proto(),
1923                    worktree_id: worktree.id,
1924                    path: settings_file.path,
1925                    content: Some(settings_file.content),
1926                    kind: Some(settings_file.kind.to_proto() as i32),
1927                },
1928            )?;
1929        }
1930    }
1931
1932    for repository in mem::take(&mut project.repositories) {
1933        for update in split_repository_update(repository) {
1934            session.peer.send(session.connection_id, update)?;
1935        }
1936    }
1937
1938    for language_server in &project.language_servers {
1939        session.peer.send(
1940            session.connection_id,
1941            proto::UpdateLanguageServer {
1942                project_id: project_id.to_proto(),
1943                language_server_id: language_server.id,
1944                variant: Some(
1945                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1946                        proto::LspDiskBasedDiagnosticsUpdated {},
1947                    ),
1948                ),
1949            },
1950        )?;
1951    }
1952
1953    Ok(())
1954}
1955
1956/// Leave someone elses shared project.
1957async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1958    let sender_id = session.connection_id;
1959    let project_id = ProjectId::from_proto(request.project_id);
1960    let db = session.db().await;
1961
1962    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1963    tracing::info!(
1964        %project_id,
1965        "leave project"
1966    );
1967
1968    project_left(project, &session);
1969    if let Some(room) = room {
1970        room_updated(room, &session.peer);
1971    }
1972
1973    Ok(())
1974}
1975
1976/// Updates other participants with changes to the project
1977async fn update_project(
1978    request: proto::UpdateProject,
1979    response: Response<proto::UpdateProject>,
1980    session: Session,
1981) -> Result<()> {
1982    let project_id = ProjectId::from_proto(request.project_id);
1983    let (room, guest_connection_ids) = &*session
1984        .db()
1985        .await
1986        .update_project(project_id, session.connection_id, &request.worktrees)
1987        .await?;
1988    broadcast(
1989        Some(session.connection_id),
1990        guest_connection_ids.iter().copied(),
1991        |connection_id| {
1992            session
1993                .peer
1994                .forward_send(session.connection_id, connection_id, request.clone())
1995        },
1996    );
1997    if let Some(room) = room {
1998        room_updated(room, &session.peer);
1999    }
2000    response.send(proto::Ack {})?;
2001
2002    Ok(())
2003}
2004
2005/// Updates other participants with changes to the worktree
2006async fn update_worktree(
2007    request: proto::UpdateWorktree,
2008    response: Response<proto::UpdateWorktree>,
2009    session: Session,
2010) -> Result<()> {
2011    let guest_connection_ids = session
2012        .db()
2013        .await
2014        .update_worktree(&request, session.connection_id)
2015        .await?;
2016
2017    broadcast(
2018        Some(session.connection_id),
2019        guest_connection_ids.iter().copied(),
2020        |connection_id| {
2021            session
2022                .peer
2023                .forward_send(session.connection_id, connection_id, request.clone())
2024        },
2025    );
2026    response.send(proto::Ack {})?;
2027    Ok(())
2028}
2029
2030async fn update_repository(
2031    request: proto::UpdateRepository,
2032    response: Response<proto::UpdateRepository>,
2033    session: Session,
2034) -> Result<()> {
2035    let guest_connection_ids = session
2036        .db()
2037        .await
2038        .update_repository(&request, session.connection_id)
2039        .await?;
2040
2041    broadcast(
2042        Some(session.connection_id),
2043        guest_connection_ids.iter().copied(),
2044        |connection_id| {
2045            session
2046                .peer
2047                .forward_send(session.connection_id, connection_id, request.clone())
2048        },
2049    );
2050    response.send(proto::Ack {})?;
2051    Ok(())
2052}
2053
2054async fn remove_repository(
2055    request: proto::RemoveRepository,
2056    response: Response<proto::RemoveRepository>,
2057    session: Session,
2058) -> Result<()> {
2059    let guest_connection_ids = session
2060        .db()
2061        .await
2062        .remove_repository(&request, session.connection_id)
2063        .await?;
2064
2065    broadcast(
2066        Some(session.connection_id),
2067        guest_connection_ids.iter().copied(),
2068        |connection_id| {
2069            session
2070                .peer
2071                .forward_send(session.connection_id, connection_id, request.clone())
2072        },
2073    );
2074    response.send(proto::Ack {})?;
2075    Ok(())
2076}
2077
2078/// Updates other participants with changes to the diagnostics
2079async fn update_diagnostic_summary(
2080    message: proto::UpdateDiagnosticSummary,
2081    session: Session,
2082) -> Result<()> {
2083    let guest_connection_ids = session
2084        .db()
2085        .await
2086        .update_diagnostic_summary(&message, session.connection_id)
2087        .await?;
2088
2089    broadcast(
2090        Some(session.connection_id),
2091        guest_connection_ids.iter().copied(),
2092        |connection_id| {
2093            session
2094                .peer
2095                .forward_send(session.connection_id, connection_id, message.clone())
2096        },
2097    );
2098
2099    Ok(())
2100}
2101
2102/// Updates other participants with changes to the worktree settings
2103async fn update_worktree_settings(
2104    message: proto::UpdateWorktreeSettings,
2105    session: Session,
2106) -> Result<()> {
2107    let guest_connection_ids = session
2108        .db()
2109        .await
2110        .update_worktree_settings(&message, session.connection_id)
2111        .await?;
2112
2113    broadcast(
2114        Some(session.connection_id),
2115        guest_connection_ids.iter().copied(),
2116        |connection_id| {
2117            session
2118                .peer
2119                .forward_send(session.connection_id, connection_id, message.clone())
2120        },
2121    );
2122
2123    Ok(())
2124}
2125
2126/// Notify other participants that a language server has started.
2127async fn start_language_server(
2128    request: proto::StartLanguageServer,
2129    session: Session,
2130) -> Result<()> {
2131    let guest_connection_ids = session
2132        .db()
2133        .await
2134        .start_language_server(&request, session.connection_id)
2135        .await?;
2136
2137    broadcast(
2138        Some(session.connection_id),
2139        guest_connection_ids.iter().copied(),
2140        |connection_id| {
2141            session
2142                .peer
2143                .forward_send(session.connection_id, connection_id, request.clone())
2144        },
2145    );
2146    Ok(())
2147}
2148
2149/// Notify other participants that a language server has changed.
2150async fn update_language_server(
2151    request: proto::UpdateLanguageServer,
2152    session: Session,
2153) -> Result<()> {
2154    let project_id = ProjectId::from_proto(request.project_id);
2155    let project_connection_ids = session
2156        .db()
2157        .await
2158        .project_connection_ids(project_id, session.connection_id, true)
2159        .await?;
2160    broadcast(
2161        Some(session.connection_id),
2162        project_connection_ids.iter().copied(),
2163        |connection_id| {
2164            session
2165                .peer
2166                .forward_send(session.connection_id, connection_id, request.clone())
2167        },
2168    );
2169    Ok(())
2170}
2171
2172/// forward a project request to the host. These requests should be read only
2173/// as guests are allowed to send them.
2174async fn forward_read_only_project_request<T>(
2175    request: T,
2176    response: Response<T>,
2177    session: Session,
2178) -> Result<()>
2179where
2180    T: EntityMessage + RequestMessage,
2181{
2182    let project_id = ProjectId::from_proto(request.remote_entity_id());
2183    let host_connection_id = session
2184        .db()
2185        .await
2186        .host_for_read_only_project_request(project_id, session.connection_id)
2187        .await?;
2188    let payload = session
2189        .peer
2190        .forward_request(session.connection_id, host_connection_id, request)
2191        .await?;
2192    response.send(payload)?;
2193    Ok(())
2194}
2195
2196async fn forward_find_search_candidates_request(
2197    request: proto::FindSearchCandidates,
2198    response: Response<proto::FindSearchCandidates>,
2199    session: Session,
2200) -> Result<()> {
2201    let project_id = ProjectId::from_proto(request.remote_entity_id());
2202    let host_connection_id = session
2203        .db()
2204        .await
2205        .host_for_read_only_project_request(project_id, session.connection_id)
2206        .await?;
2207    let payload = session
2208        .peer
2209        .forward_request(session.connection_id, host_connection_id, request)
2210        .await?;
2211    response.send(payload)?;
2212    Ok(())
2213}
2214
2215/// forward a project request to the host. These requests are disallowed
2216/// for guests.
2217async fn forward_mutating_project_request<T>(
2218    request: T,
2219    response: Response<T>,
2220    session: Session,
2221) -> Result<()>
2222where
2223    T: EntityMessage + RequestMessage,
2224{
2225    let project_id = ProjectId::from_proto(request.remote_entity_id());
2226
2227    let host_connection_id = session
2228        .db()
2229        .await
2230        .host_for_mutating_project_request(project_id, session.connection_id)
2231        .await?;
2232    let payload = session
2233        .peer
2234        .forward_request(session.connection_id, host_connection_id, request)
2235        .await?;
2236    response.send(payload)?;
2237    Ok(())
2238}
2239
2240/// Notify other participants that a new buffer has been created
2241async fn create_buffer_for_peer(
2242    request: proto::CreateBufferForPeer,
2243    session: Session,
2244) -> Result<()> {
2245    session
2246        .db()
2247        .await
2248        .check_user_is_project_host(
2249            ProjectId::from_proto(request.project_id),
2250            session.connection_id,
2251        )
2252        .await?;
2253    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2254    session
2255        .peer
2256        .forward_send(session.connection_id, peer_id.into(), request)?;
2257    Ok(())
2258}
2259
2260/// Notify other participants that a buffer has been updated. This is
2261/// allowed for guests as long as the update is limited to selections.
2262async fn update_buffer(
2263    request: proto::UpdateBuffer,
2264    response: Response<proto::UpdateBuffer>,
2265    session: Session,
2266) -> Result<()> {
2267    let project_id = ProjectId::from_proto(request.project_id);
2268    let mut capability = Capability::ReadOnly;
2269
2270    for op in request.operations.iter() {
2271        match op.variant {
2272            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2273            Some(_) => capability = Capability::ReadWrite,
2274        }
2275    }
2276
2277    let host = {
2278        let guard = session
2279            .db()
2280            .await
2281            .connections_for_buffer_update(project_id, session.connection_id, capability)
2282            .await?;
2283
2284        let (host, guests) = &*guard;
2285
2286        broadcast(
2287            Some(session.connection_id),
2288            guests.clone(),
2289            |connection_id| {
2290                session
2291                    .peer
2292                    .forward_send(session.connection_id, connection_id, request.clone())
2293            },
2294        );
2295
2296        *host
2297    };
2298
2299    if host != session.connection_id {
2300        session
2301            .peer
2302            .forward_request(session.connection_id, host, request.clone())
2303            .await?;
2304    }
2305
2306    response.send(proto::Ack {})?;
2307    Ok(())
2308}
2309
2310async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2311    let project_id = ProjectId::from_proto(message.project_id);
2312
2313    let operation = message.operation.as_ref().context("invalid operation")?;
2314    let capability = match operation.variant.as_ref() {
2315        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2316            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2317                match buffer_op.variant {
2318                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2319                        Capability::ReadOnly
2320                    }
2321                    _ => Capability::ReadWrite,
2322                }
2323            } else {
2324                Capability::ReadWrite
2325            }
2326        }
2327        Some(_) => Capability::ReadWrite,
2328        None => Capability::ReadOnly,
2329    };
2330
2331    let guard = session
2332        .db()
2333        .await
2334        .connections_for_buffer_update(project_id, session.connection_id, capability)
2335        .await?;
2336
2337    let (host, guests) = &*guard;
2338
2339    broadcast(
2340        Some(session.connection_id),
2341        guests.iter().chain([host]).copied(),
2342        |connection_id| {
2343            session
2344                .peer
2345                .forward_send(session.connection_id, connection_id, message.clone())
2346        },
2347    );
2348
2349    Ok(())
2350}
2351
2352/// Notify other participants that a project has been updated.
2353async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2354    request: T,
2355    session: Session,
2356) -> Result<()> {
2357    let project_id = ProjectId::from_proto(request.remote_entity_id());
2358    let project_connection_ids = session
2359        .db()
2360        .await
2361        .project_connection_ids(project_id, session.connection_id, false)
2362        .await?;
2363
2364    broadcast(
2365        Some(session.connection_id),
2366        project_connection_ids.iter().copied(),
2367        |connection_id| {
2368            session
2369                .peer
2370                .forward_send(session.connection_id, connection_id, request.clone())
2371        },
2372    );
2373    Ok(())
2374}
2375
2376/// Start following another user in a call.
2377async fn follow(
2378    request: proto::Follow,
2379    response: Response<proto::Follow>,
2380    session: Session,
2381) -> Result<()> {
2382    let room_id = RoomId::from_proto(request.room_id);
2383    let project_id = request.project_id.map(ProjectId::from_proto);
2384    let leader_id = request
2385        .leader_id
2386        .ok_or_else(|| anyhow!("invalid leader id"))?
2387        .into();
2388    let follower_id = session.connection_id;
2389
2390    session
2391        .db()
2392        .await
2393        .check_room_participants(room_id, leader_id, session.connection_id)
2394        .await?;
2395
2396    let response_payload = session
2397        .peer
2398        .forward_request(session.connection_id, leader_id, request)
2399        .await?;
2400    response.send(response_payload)?;
2401
2402    if let Some(project_id) = project_id {
2403        let room = session
2404            .db()
2405            .await
2406            .follow(room_id, project_id, leader_id, follower_id)
2407            .await?;
2408        room_updated(&room, &session.peer);
2409    }
2410
2411    Ok(())
2412}
2413
2414/// Stop following another user in a call.
2415async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2416    let room_id = RoomId::from_proto(request.room_id);
2417    let project_id = request.project_id.map(ProjectId::from_proto);
2418    let leader_id = request
2419        .leader_id
2420        .ok_or_else(|| anyhow!("invalid leader id"))?
2421        .into();
2422    let follower_id = session.connection_id;
2423
2424    session
2425        .db()
2426        .await
2427        .check_room_participants(room_id, leader_id, session.connection_id)
2428        .await?;
2429
2430    session
2431        .peer
2432        .forward_send(session.connection_id, leader_id, request)?;
2433
2434    if let Some(project_id) = project_id {
2435        let room = session
2436            .db()
2437            .await
2438            .unfollow(room_id, project_id, leader_id, follower_id)
2439            .await?;
2440        room_updated(&room, &session.peer);
2441    }
2442
2443    Ok(())
2444}
2445
2446/// Notify everyone following you of your current location.
2447async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2448    let room_id = RoomId::from_proto(request.room_id);
2449    let database = session.db.lock().await;
2450
2451    let connection_ids = if let Some(project_id) = request.project_id {
2452        let project_id = ProjectId::from_proto(project_id);
2453        database
2454            .project_connection_ids(project_id, session.connection_id, true)
2455            .await?
2456    } else {
2457        database
2458            .room_connection_ids(room_id, session.connection_id)
2459            .await?
2460    };
2461
2462    // For now, don't send view update messages back to that view's current leader.
2463    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2464        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2465        _ => None,
2466    });
2467
2468    for connection_id in connection_ids.iter().cloned() {
2469        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2470            session
2471                .peer
2472                .forward_send(session.connection_id, connection_id, request.clone())?;
2473        }
2474    }
2475    Ok(())
2476}
2477
2478/// Get public data about users.
2479async fn get_users(
2480    request: proto::GetUsers,
2481    response: Response<proto::GetUsers>,
2482    session: Session,
2483) -> Result<()> {
2484    let user_ids = request
2485        .user_ids
2486        .into_iter()
2487        .map(UserId::from_proto)
2488        .collect();
2489    let users = session
2490        .db()
2491        .await
2492        .get_users_by_ids(user_ids)
2493        .await?
2494        .into_iter()
2495        .map(|user| proto::User {
2496            id: user.id.to_proto(),
2497            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2498            github_login: user.github_login,
2499            email: user.email_address,
2500            name: user.name,
2501        })
2502        .collect();
2503    response.send(proto::UsersResponse { users })?;
2504    Ok(())
2505}
2506
2507/// Search for users (to invite) buy Github login
2508async fn fuzzy_search_users(
2509    request: proto::FuzzySearchUsers,
2510    response: Response<proto::FuzzySearchUsers>,
2511    session: Session,
2512) -> Result<()> {
2513    let query = request.query;
2514    let users = match query.len() {
2515        0 => vec![],
2516        1 | 2 => session
2517            .db()
2518            .await
2519            .get_user_by_github_login(&query)
2520            .await?
2521            .into_iter()
2522            .collect(),
2523        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2524    };
2525    let users = users
2526        .into_iter()
2527        .filter(|user| user.id != session.user_id())
2528        .map(|user| proto::User {
2529            id: user.id.to_proto(),
2530            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2531            github_login: user.github_login,
2532            name: user.name,
2533            email: user.email_address,
2534        })
2535        .collect();
2536    response.send(proto::UsersResponse { users })?;
2537    Ok(())
2538}
2539
2540/// Send a contact request to another user.
2541async fn request_contact(
2542    request: proto::RequestContact,
2543    response: Response<proto::RequestContact>,
2544    session: Session,
2545) -> Result<()> {
2546    let requester_id = session.user_id();
2547    let responder_id = UserId::from_proto(request.responder_id);
2548    if requester_id == responder_id {
2549        return Err(anyhow!("cannot add yourself as a contact"))?;
2550    }
2551
2552    let notifications = session
2553        .db()
2554        .await
2555        .send_contact_request(requester_id, responder_id)
2556        .await?;
2557
2558    // Update outgoing contact requests of requester
2559    let mut update = proto::UpdateContacts::default();
2560    update.outgoing_requests.push(responder_id.to_proto());
2561    for connection_id in session
2562        .connection_pool()
2563        .await
2564        .user_connection_ids(requester_id)
2565    {
2566        session.peer.send(connection_id, update.clone())?;
2567    }
2568
2569    // Update incoming contact requests of responder
2570    let mut update = proto::UpdateContacts::default();
2571    update
2572        .incoming_requests
2573        .push(proto::IncomingContactRequest {
2574            requester_id: requester_id.to_proto(),
2575        });
2576    let connection_pool = session.connection_pool().await;
2577    for connection_id in connection_pool.user_connection_ids(responder_id) {
2578        session.peer.send(connection_id, update.clone())?;
2579    }
2580
2581    send_notifications(&connection_pool, &session.peer, notifications);
2582
2583    response.send(proto::Ack {})?;
2584    Ok(())
2585}
2586
2587/// Accept or decline a contact request
2588async fn respond_to_contact_request(
2589    request: proto::RespondToContactRequest,
2590    response: Response<proto::RespondToContactRequest>,
2591    session: Session,
2592) -> Result<()> {
2593    let responder_id = session.user_id();
2594    let requester_id = UserId::from_proto(request.requester_id);
2595    let db = session.db().await;
2596    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2597        db.dismiss_contact_notification(responder_id, requester_id)
2598            .await?;
2599    } else {
2600        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2601
2602        let notifications = db
2603            .respond_to_contact_request(responder_id, requester_id, accept)
2604            .await?;
2605        let requester_busy = db.is_user_busy(requester_id).await?;
2606        let responder_busy = db.is_user_busy(responder_id).await?;
2607
2608        let pool = session.connection_pool().await;
2609        // Update responder with new contact
2610        let mut update = proto::UpdateContacts::default();
2611        if accept {
2612            update
2613                .contacts
2614                .push(contact_for_user(requester_id, requester_busy, &pool));
2615        }
2616        update
2617            .remove_incoming_requests
2618            .push(requester_id.to_proto());
2619        for connection_id in pool.user_connection_ids(responder_id) {
2620            session.peer.send(connection_id, update.clone())?;
2621        }
2622
2623        // Update requester with new contact
2624        let mut update = proto::UpdateContacts::default();
2625        if accept {
2626            update
2627                .contacts
2628                .push(contact_for_user(responder_id, responder_busy, &pool));
2629        }
2630        update
2631            .remove_outgoing_requests
2632            .push(responder_id.to_proto());
2633
2634        for connection_id in pool.user_connection_ids(requester_id) {
2635            session.peer.send(connection_id, update.clone())?;
2636        }
2637
2638        send_notifications(&pool, &session.peer, notifications);
2639    }
2640
2641    response.send(proto::Ack {})?;
2642    Ok(())
2643}
2644
2645/// Remove a contact.
2646async fn remove_contact(
2647    request: proto::RemoveContact,
2648    response: Response<proto::RemoveContact>,
2649    session: Session,
2650) -> Result<()> {
2651    let requester_id = session.user_id();
2652    let responder_id = UserId::from_proto(request.user_id);
2653    let db = session.db().await;
2654    let (contact_accepted, deleted_notification_id) =
2655        db.remove_contact(requester_id, responder_id).await?;
2656
2657    let pool = session.connection_pool().await;
2658    // Update outgoing contact requests of requester
2659    let mut update = proto::UpdateContacts::default();
2660    if contact_accepted {
2661        update.remove_contacts.push(responder_id.to_proto());
2662    } else {
2663        update
2664            .remove_outgoing_requests
2665            .push(responder_id.to_proto());
2666    }
2667    for connection_id in pool.user_connection_ids(requester_id) {
2668        session.peer.send(connection_id, update.clone())?;
2669    }
2670
2671    // Update incoming contact requests of responder
2672    let mut update = proto::UpdateContacts::default();
2673    if contact_accepted {
2674        update.remove_contacts.push(requester_id.to_proto());
2675    } else {
2676        update
2677            .remove_incoming_requests
2678            .push(requester_id.to_proto());
2679    }
2680    for connection_id in pool.user_connection_ids(responder_id) {
2681        session.peer.send(connection_id, update.clone())?;
2682        if let Some(notification_id) = deleted_notification_id {
2683            session.peer.send(
2684                connection_id,
2685                proto::DeleteNotification {
2686                    notification_id: notification_id.to_proto(),
2687                },
2688            )?;
2689        }
2690    }
2691
2692    response.send(proto::Ack {})?;
2693    Ok(())
2694}
2695
2696fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2697    version.0.minor() < 139
2698}
2699
2700async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2701    let plan = session.current_plan(&session.db().await).await?;
2702
2703    session
2704        .peer
2705        .send(
2706            session.connection_id,
2707            proto::UpdateUserPlan { plan: plan.into() },
2708        )
2709        .trace_err();
2710
2711    Ok(())
2712}
2713
2714async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2715    subscribe_user_to_channels(session.user_id(), &session).await?;
2716    Ok(())
2717}
2718
2719async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2720    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2721    let mut pool = session.connection_pool().await;
2722    for membership in &channels_for_user.channel_memberships {
2723        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2724    }
2725    session.peer.send(
2726        session.connection_id,
2727        build_update_user_channels(&channels_for_user),
2728    )?;
2729    session.peer.send(
2730        session.connection_id,
2731        build_channels_update(channels_for_user),
2732    )?;
2733    Ok(())
2734}
2735
2736/// Creates a new channel.
2737async fn create_channel(
2738    request: proto::CreateChannel,
2739    response: Response<proto::CreateChannel>,
2740    session: Session,
2741) -> Result<()> {
2742    let db = session.db().await;
2743
2744    let parent_id = request.parent_id.map(ChannelId::from_proto);
2745    let (channel, membership) = db
2746        .create_channel(&request.name, parent_id, session.user_id())
2747        .await?;
2748
2749    let root_id = channel.root_id();
2750    let channel = Channel::from_model(channel);
2751
2752    response.send(proto::CreateChannelResponse {
2753        channel: Some(channel.to_proto()),
2754        parent_id: request.parent_id,
2755    })?;
2756
2757    let mut connection_pool = session.connection_pool().await;
2758    if let Some(membership) = membership {
2759        connection_pool.subscribe_to_channel(
2760            membership.user_id,
2761            membership.channel_id,
2762            membership.role,
2763        );
2764        let update = proto::UpdateUserChannels {
2765            channel_memberships: vec![proto::ChannelMembership {
2766                channel_id: membership.channel_id.to_proto(),
2767                role: membership.role.into(),
2768            }],
2769            ..Default::default()
2770        };
2771        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2772            session.peer.send(connection_id, update.clone())?;
2773        }
2774    }
2775
2776    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2777        if !role.can_see_channel(channel.visibility) {
2778            continue;
2779        }
2780
2781        let update = proto::UpdateChannels {
2782            channels: vec![channel.to_proto()],
2783            ..Default::default()
2784        };
2785        session.peer.send(connection_id, update.clone())?;
2786    }
2787
2788    Ok(())
2789}
2790
2791/// Delete a channel
2792async fn delete_channel(
2793    request: proto::DeleteChannel,
2794    response: Response<proto::DeleteChannel>,
2795    session: Session,
2796) -> Result<()> {
2797    let db = session.db().await;
2798
2799    let channel_id = request.channel_id;
2800    let (root_channel, removed_channels) = db
2801        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2802        .await?;
2803    response.send(proto::Ack {})?;
2804
2805    // Notify members of removed channels
2806    let mut update = proto::UpdateChannels::default();
2807    update
2808        .delete_channels
2809        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2810
2811    let connection_pool = session.connection_pool().await;
2812    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2813        session.peer.send(connection_id, update.clone())?;
2814    }
2815
2816    Ok(())
2817}
2818
2819/// Invite someone to join a channel.
2820async fn invite_channel_member(
2821    request: proto::InviteChannelMember,
2822    response: Response<proto::InviteChannelMember>,
2823    session: Session,
2824) -> Result<()> {
2825    let db = session.db().await;
2826    let channel_id = ChannelId::from_proto(request.channel_id);
2827    let invitee_id = UserId::from_proto(request.user_id);
2828    let InviteMemberResult {
2829        channel,
2830        notifications,
2831    } = db
2832        .invite_channel_member(
2833            channel_id,
2834            invitee_id,
2835            session.user_id(),
2836            request.role().into(),
2837        )
2838        .await?;
2839
2840    let update = proto::UpdateChannels {
2841        channel_invitations: vec![channel.to_proto()],
2842        ..Default::default()
2843    };
2844
2845    let connection_pool = session.connection_pool().await;
2846    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2847        session.peer.send(connection_id, update.clone())?;
2848    }
2849
2850    send_notifications(&connection_pool, &session.peer, notifications);
2851
2852    response.send(proto::Ack {})?;
2853    Ok(())
2854}
2855
2856/// remove someone from a channel
2857async fn remove_channel_member(
2858    request: proto::RemoveChannelMember,
2859    response: Response<proto::RemoveChannelMember>,
2860    session: Session,
2861) -> Result<()> {
2862    let db = session.db().await;
2863    let channel_id = ChannelId::from_proto(request.channel_id);
2864    let member_id = UserId::from_proto(request.user_id);
2865
2866    let RemoveChannelMemberResult {
2867        membership_update,
2868        notification_id,
2869    } = db
2870        .remove_channel_member(channel_id, member_id, session.user_id())
2871        .await?;
2872
2873    let mut connection_pool = session.connection_pool().await;
2874    notify_membership_updated(
2875        &mut connection_pool,
2876        membership_update,
2877        member_id,
2878        &session.peer,
2879    );
2880    for connection_id in connection_pool.user_connection_ids(member_id) {
2881        if let Some(notification_id) = notification_id {
2882            session
2883                .peer
2884                .send(
2885                    connection_id,
2886                    proto::DeleteNotification {
2887                        notification_id: notification_id.to_proto(),
2888                    },
2889                )
2890                .trace_err();
2891        }
2892    }
2893
2894    response.send(proto::Ack {})?;
2895    Ok(())
2896}
2897
2898/// Toggle the channel between public and private.
2899/// Care is taken to maintain the invariant that public channels only descend from public channels,
2900/// (though members-only channels can appear at any point in the hierarchy).
2901async fn set_channel_visibility(
2902    request: proto::SetChannelVisibility,
2903    response: Response<proto::SetChannelVisibility>,
2904    session: Session,
2905) -> Result<()> {
2906    let db = session.db().await;
2907    let channel_id = ChannelId::from_proto(request.channel_id);
2908    let visibility = request.visibility().into();
2909
2910    let channel_model = db
2911        .set_channel_visibility(channel_id, visibility, session.user_id())
2912        .await?;
2913    let root_id = channel_model.root_id();
2914    let channel = Channel::from_model(channel_model);
2915
2916    let mut connection_pool = session.connection_pool().await;
2917    for (user_id, role) in connection_pool
2918        .channel_user_ids(root_id)
2919        .collect::<Vec<_>>()
2920        .into_iter()
2921    {
2922        let update = if role.can_see_channel(channel.visibility) {
2923            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2924            proto::UpdateChannels {
2925                channels: vec![channel.to_proto()],
2926                ..Default::default()
2927            }
2928        } else {
2929            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2930            proto::UpdateChannels {
2931                delete_channels: vec![channel.id.to_proto()],
2932                ..Default::default()
2933            }
2934        };
2935
2936        for connection_id in connection_pool.user_connection_ids(user_id) {
2937            session.peer.send(connection_id, update.clone())?;
2938        }
2939    }
2940
2941    response.send(proto::Ack {})?;
2942    Ok(())
2943}
2944
2945/// Alter the role for a user in the channel.
2946async fn set_channel_member_role(
2947    request: proto::SetChannelMemberRole,
2948    response: Response<proto::SetChannelMemberRole>,
2949    session: Session,
2950) -> Result<()> {
2951    let db = session.db().await;
2952    let channel_id = ChannelId::from_proto(request.channel_id);
2953    let member_id = UserId::from_proto(request.user_id);
2954    let result = db
2955        .set_channel_member_role(
2956            channel_id,
2957            session.user_id(),
2958            member_id,
2959            request.role().into(),
2960        )
2961        .await?;
2962
2963    match result {
2964        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2965            let mut connection_pool = session.connection_pool().await;
2966            notify_membership_updated(
2967                &mut connection_pool,
2968                membership_update,
2969                member_id,
2970                &session.peer,
2971            )
2972        }
2973        db::SetMemberRoleResult::InviteUpdated(channel) => {
2974            let update = proto::UpdateChannels {
2975                channel_invitations: vec![channel.to_proto()],
2976                ..Default::default()
2977            };
2978
2979            for connection_id in session
2980                .connection_pool()
2981                .await
2982                .user_connection_ids(member_id)
2983            {
2984                session.peer.send(connection_id, update.clone())?;
2985            }
2986        }
2987    }
2988
2989    response.send(proto::Ack {})?;
2990    Ok(())
2991}
2992
2993/// Change the name of a channel
2994async fn rename_channel(
2995    request: proto::RenameChannel,
2996    response: Response<proto::RenameChannel>,
2997    session: Session,
2998) -> Result<()> {
2999    let db = session.db().await;
3000    let channel_id = ChannelId::from_proto(request.channel_id);
3001    let channel_model = db
3002        .rename_channel(channel_id, session.user_id(), &request.name)
3003        .await?;
3004    let root_id = channel_model.root_id();
3005    let channel = Channel::from_model(channel_model);
3006
3007    response.send(proto::RenameChannelResponse {
3008        channel: Some(channel.to_proto()),
3009    })?;
3010
3011    let connection_pool = session.connection_pool().await;
3012    let update = proto::UpdateChannels {
3013        channels: vec![channel.to_proto()],
3014        ..Default::default()
3015    };
3016    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3017        if role.can_see_channel(channel.visibility) {
3018            session.peer.send(connection_id, update.clone())?;
3019        }
3020    }
3021
3022    Ok(())
3023}
3024
3025/// Move a channel to a new parent.
3026async fn move_channel(
3027    request: proto::MoveChannel,
3028    response: Response<proto::MoveChannel>,
3029    session: Session,
3030) -> Result<()> {
3031    let channel_id = ChannelId::from_proto(request.channel_id);
3032    let to = ChannelId::from_proto(request.to);
3033
3034    let (root_id, channels) = session
3035        .db()
3036        .await
3037        .move_channel(channel_id, to, session.user_id())
3038        .await?;
3039
3040    let connection_pool = session.connection_pool().await;
3041    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3042        let channels = channels
3043            .iter()
3044            .filter_map(|channel| {
3045                if role.can_see_channel(channel.visibility) {
3046                    Some(channel.to_proto())
3047                } else {
3048                    None
3049                }
3050            })
3051            .collect::<Vec<_>>();
3052        if channels.is_empty() {
3053            continue;
3054        }
3055
3056        let update = proto::UpdateChannels {
3057            channels,
3058            ..Default::default()
3059        };
3060
3061        session.peer.send(connection_id, update.clone())?;
3062    }
3063
3064    response.send(Ack {})?;
3065    Ok(())
3066}
3067
3068/// Get the list of channel members
3069async fn get_channel_members(
3070    request: proto::GetChannelMembers,
3071    response: Response<proto::GetChannelMembers>,
3072    session: Session,
3073) -> Result<()> {
3074    let db = session.db().await;
3075    let channel_id = ChannelId::from_proto(request.channel_id);
3076    let limit = if request.limit == 0 {
3077        u16::MAX as u64
3078    } else {
3079        request.limit
3080    };
3081    let (members, users) = db
3082        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3083        .await?;
3084    response.send(proto::GetChannelMembersResponse { members, users })?;
3085    Ok(())
3086}
3087
3088/// Accept or decline a channel invitation.
3089async fn respond_to_channel_invite(
3090    request: proto::RespondToChannelInvite,
3091    response: Response<proto::RespondToChannelInvite>,
3092    session: Session,
3093) -> Result<()> {
3094    let db = session.db().await;
3095    let channel_id = ChannelId::from_proto(request.channel_id);
3096    let RespondToChannelInvite {
3097        membership_update,
3098        notifications,
3099    } = db
3100        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3101        .await?;
3102
3103    let mut connection_pool = session.connection_pool().await;
3104    if let Some(membership_update) = membership_update {
3105        notify_membership_updated(
3106            &mut connection_pool,
3107            membership_update,
3108            session.user_id(),
3109            &session.peer,
3110        );
3111    } else {
3112        let update = proto::UpdateChannels {
3113            remove_channel_invitations: vec![channel_id.to_proto()],
3114            ..Default::default()
3115        };
3116
3117        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3118            session.peer.send(connection_id, update.clone())?;
3119        }
3120    };
3121
3122    send_notifications(&connection_pool, &session.peer, notifications);
3123
3124    response.send(proto::Ack {})?;
3125
3126    Ok(())
3127}
3128
3129/// Join the channels' room
3130async fn join_channel(
3131    request: proto::JoinChannel,
3132    response: Response<proto::JoinChannel>,
3133    session: Session,
3134) -> Result<()> {
3135    let channel_id = ChannelId::from_proto(request.channel_id);
3136    join_channel_internal(channel_id, Box::new(response), session).await
3137}
3138
3139trait JoinChannelInternalResponse {
3140    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3141}
3142impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3143    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3144        Response::<proto::JoinChannel>::send(self, result)
3145    }
3146}
3147impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3148    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3149        Response::<proto::JoinRoom>::send(self, result)
3150    }
3151}
3152
/// Shared implementation for joining the call room of `channel_id`.
///
/// `response` is any reply handle implementing `JoinChannelInternalResponse`. In this
/// file that means the `JoinChannel` and `JoinRoom` responses.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, that connection is
///    removed via `leave_room_for_session`. The DB guard is dropped during this call and
///    re-acquired afterwards.
/// 2. The user's connection joins the channel's room.
/// 3. If a LiveKit client is configured, connection info is attached. Guests get a
///    `guest_token` with `can_publish: false`; every other role gets a `room_token` with
///    `can_publish: true`. A token error is logged via `trace_err` and only omits the
///    LiveKit info; it does not fail the join.
/// 4. The join response is sent to the caller. Membership changes go to
///    `notify_membership_updated` and the room goes to `room_updated`.
/// 5. Once the DB and connection-pool guards from the inner block are dropped,
///    `channel_updated` and `update_user_contacts` run.
///
/// # Errors
/// Fails on database errors, on a failed reply, or if the joined room carries no
/// channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Everything that needs the DB guard (and later the pool guard) lives in this block,
    // so both guards are dropped before the follow-up notifications below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the DB guard before `leave_room_for_session`, which takes the
            // session and may presumably need the database itself; re-acquire afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting a token fails
        // (the `?` inside the closure short-circuits on `trace_err()` returning `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish; other roles get a
                    // regular room token with publishing enabled.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A fresh connection-pool guard is taken just for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3248
3249/// Start editing the channel notes
3250async fn join_channel_buffer(
3251    request: proto::JoinChannelBuffer,
3252    response: Response<proto::JoinChannelBuffer>,
3253    session: Session,
3254) -> Result<()> {
3255    let db = session.db().await;
3256    let channel_id = ChannelId::from_proto(request.channel_id);
3257
3258    let open_response = db
3259        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3260        .await?;
3261
3262    let collaborators = open_response.collaborators.clone();
3263    response.send(open_response)?;
3264
3265    let update = UpdateChannelBufferCollaborators {
3266        channel_id: channel_id.to_proto(),
3267        collaborators: collaborators.clone(),
3268    };
3269    channel_buffer_updated(
3270        session.connection_id,
3271        collaborators
3272            .iter()
3273            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3274        &update,
3275        &session.peer,
3276    );
3277
3278    Ok(())
3279}
3280
/// Edit the channel notes
///
/// Steps:
/// 1. `request.operations` are applied through the database's `update_channel_buffer`.
///    That call returns the buffer's collaborator connections plus the new epoch and
///    version.
/// 2. The operations are passed to `channel_buffer_updated` for those collaborators,
///    with the caller's connection as sender.
/// 3. Connections subscribed to the channel that are *not* collaborators receive only
///    the latest epoch/version (`latest_channel_buffer_versions`), not the operations.
///
/// The DB guard is held for the whole function.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    // The pool guard stays alive for the rest of the function, covering the lazy
    // `non_collaborators` iterator consumed by `broadcast` below.
    let pool = &*session.connection_pool().await;

    // Channel subscribers that did not already receive the full operations above.
    // NOTE(review): `contains` is a linear scan per subscriber if `collaborators` is a Vec.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    // NOTE(review): `as` cast; presumably the epoch is never negative — confirm.
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3331
3332/// Rejoin the channel notes after a connection blip
3333async fn rejoin_channel_buffers(
3334    request: proto::RejoinChannelBuffers,
3335    response: Response<proto::RejoinChannelBuffers>,
3336    session: Session,
3337) -> Result<()> {
3338    let db = session.db().await;
3339    let buffers = db
3340        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3341        .await?;
3342
3343    for rejoined_buffer in &buffers {
3344        let collaborators_to_notify = rejoined_buffer
3345            .buffer
3346            .collaborators
3347            .iter()
3348            .filter_map(|c| Some(c.peer_id?.into()));
3349        channel_buffer_updated(
3350            session.connection_id,
3351            collaborators_to_notify,
3352            &proto::UpdateChannelBufferCollaborators {
3353                channel_id: rejoined_buffer.buffer.channel_id,
3354                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3355            },
3356            &session.peer,
3357        );
3358    }
3359
3360    response.send(proto::RejoinChannelBuffersResponse {
3361        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3362    })?;
3363
3364    Ok(())
3365}
3366
3367/// Stop editing the channel notes
3368async fn leave_channel_buffer(
3369    request: proto::LeaveChannelBuffer,
3370    response: Response<proto::LeaveChannelBuffer>,
3371    session: Session,
3372) -> Result<()> {
3373    let db = session.db().await;
3374    let channel_id = ChannelId::from_proto(request.channel_id);
3375
3376    let left_buffer = db
3377        .leave_channel_buffer(channel_id, session.connection_id)
3378        .await?;
3379
3380    response.send(Ack {})?;
3381
3382    channel_buffer_updated(
3383        session.connection_id,
3384        left_buffer.connections,
3385        &proto::UpdateChannelBufferCollaborators {
3386            channel_id: channel_id.to_proto(),
3387            collaborators: left_buffer.collaborators,
3388        },
3389        &session.peer,
3390    );
3391
3392    Ok(())
3393}
3394
3395fn channel_buffer_updated<T: EnvelopedMessage>(
3396    sender_id: ConnectionId,
3397    collaborators: impl IntoIterator<Item = ConnectionId>,
3398    message: &T,
3399    peer: &Peer,
3400) {
3401    broadcast(Some(sender_id), collaborators, |peer_id| {
3402        peer.send(peer_id, message.clone())
3403    });
3404}
3405
3406fn send_notifications(
3407    connection_pool: &ConnectionPool,
3408    peer: &Peer,
3409    notifications: db::NotificationBatch,
3410) {
3411    for (user_id, notification) in notifications {
3412        for connection_id in connection_pool.user_connection_ids(user_id) {
3413            if let Err(error) = peer.send(
3414                connection_id,
3415                proto::AddNotification {
3416                    notification: Some(notification.clone()),
3417                },
3418            ) {
3419                tracing::error!(
3420                    "failed to send notification to {:?} {}",
3421                    connection_id,
3422                    error
3423                );
3424            }
3425        }
3426    }
3427}
3428
3429/// Send a message to the channel
3430async fn send_channel_message(
3431    request: proto::SendChannelMessage,
3432    response: Response<proto::SendChannelMessage>,
3433    session: Session,
3434) -> Result<()> {
3435    // Validate the message body.
3436    let body = request.body.trim().to_string();
3437    if body.len() > MAX_MESSAGE_LEN {
3438        return Err(anyhow!("message is too long"))?;
3439    }
3440    if body.is_empty() {
3441        return Err(anyhow!("message can't be blank"))?;
3442    }
3443
3444    // TODO: adjust mentions if body is trimmed
3445
3446    let timestamp = OffsetDateTime::now_utc();
3447    let nonce = request
3448        .nonce
3449        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3450
3451    let channel_id = ChannelId::from_proto(request.channel_id);
3452    let CreatedChannelMessage {
3453        message_id,
3454        participant_connection_ids,
3455        notifications,
3456    } = session
3457        .db()
3458        .await
3459        .create_channel_message(
3460            channel_id,
3461            session.user_id(),
3462            &body,
3463            &request.mentions,
3464            timestamp,
3465            nonce.clone().into(),
3466            request.reply_to_message_id.map(MessageId::from_proto),
3467        )
3468        .await?;
3469
3470    let message = proto::ChannelMessage {
3471        sender_id: session.user_id().to_proto(),
3472        id: message_id.to_proto(),
3473        body,
3474        mentions: request.mentions,
3475        timestamp: timestamp.unix_timestamp() as u64,
3476        nonce: Some(nonce),
3477        reply_to_message_id: request.reply_to_message_id,
3478        edited_at: None,
3479    };
3480    broadcast(
3481        Some(session.connection_id),
3482        participant_connection_ids.clone(),
3483        |connection| {
3484            session.peer.send(
3485                connection,
3486                proto::ChannelMessageSent {
3487                    channel_id: channel_id.to_proto(),
3488                    message: Some(message.clone()),
3489                },
3490            )
3491        },
3492    );
3493    response.send(proto::SendChannelMessageResponse {
3494        message: Some(message),
3495    })?;
3496
3497    let pool = &*session.connection_pool().await;
3498    let non_participants =
3499        pool.channel_connection_ids(channel_id)
3500            .filter_map(|(connection_id, _)| {
3501                if participant_connection_ids.contains(&connection_id) {
3502                    None
3503                } else {
3504                    Some(connection_id)
3505                }
3506            });
3507    broadcast(None, non_participants, |peer_id| {
3508        session.peer.send(
3509            peer_id,
3510            proto::UpdateChannels {
3511                latest_channel_message_ids: vec![proto::ChannelMessageId {
3512                    channel_id: channel_id.to_proto(),
3513                    message_id: message_id.to_proto(),
3514                }],
3515                ..Default::default()
3516            },
3517        )
3518    });
3519    send_notifications(pool, &session.peer, notifications);
3520
3521    Ok(())
3522}
3523
3524/// Delete a channel message
3525async fn remove_channel_message(
3526    request: proto::RemoveChannelMessage,
3527    response: Response<proto::RemoveChannelMessage>,
3528    session: Session,
3529) -> Result<()> {
3530    let channel_id = ChannelId::from_proto(request.channel_id);
3531    let message_id = MessageId::from_proto(request.message_id);
3532    let (connection_ids, existing_notification_ids) = session
3533        .db()
3534        .await
3535        .remove_channel_message(channel_id, message_id, session.user_id())
3536        .await?;
3537
3538    broadcast(
3539        Some(session.connection_id),
3540        connection_ids,
3541        move |connection| {
3542            session.peer.send(connection, request.clone())?;
3543
3544            for notification_id in &existing_notification_ids {
3545                session.peer.send(
3546                    connection,
3547                    proto::DeleteNotification {
3548                        notification_id: (*notification_id).to_proto(),
3549                    },
3550                )?;
3551            }
3552
3553            Ok(())
3554        },
3555    );
3556    response.send(proto::Ack {})?;
3557    Ok(())
3558}
3559
3560async fn update_channel_message(
3561    request: proto::UpdateChannelMessage,
3562    response: Response<proto::UpdateChannelMessage>,
3563    session: Session,
3564) -> Result<()> {
3565    let channel_id = ChannelId::from_proto(request.channel_id);
3566    let message_id = MessageId::from_proto(request.message_id);
3567    let updated_at = OffsetDateTime::now_utc();
3568    let UpdatedChannelMessage {
3569        message_id,
3570        participant_connection_ids,
3571        notifications,
3572        reply_to_message_id,
3573        timestamp,
3574        deleted_mention_notification_ids,
3575        updated_mention_notifications,
3576    } = session
3577        .db()
3578        .await
3579        .update_channel_message(
3580            channel_id,
3581            message_id,
3582            session.user_id(),
3583            request.body.as_str(),
3584            &request.mentions,
3585            updated_at,
3586        )
3587        .await?;
3588
3589    let nonce = request
3590        .nonce
3591        .clone()
3592        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3593
3594    let message = proto::ChannelMessage {
3595        sender_id: session.user_id().to_proto(),
3596        id: message_id.to_proto(),
3597        body: request.body.clone(),
3598        mentions: request.mentions.clone(),
3599        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3600        nonce: Some(nonce),
3601        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3602        edited_at: Some(updated_at.unix_timestamp() as u64),
3603    };
3604
3605    response.send(proto::Ack {})?;
3606
3607    let pool = &*session.connection_pool().await;
3608    broadcast(
3609        Some(session.connection_id),
3610        participant_connection_ids,
3611        |connection| {
3612            session.peer.send(
3613                connection,
3614                proto::ChannelMessageUpdate {
3615                    channel_id: channel_id.to_proto(),
3616                    message: Some(message.clone()),
3617                },
3618            )?;
3619
3620            for notification_id in &deleted_mention_notification_ids {
3621                session.peer.send(
3622                    connection,
3623                    proto::DeleteNotification {
3624                        notification_id: (*notification_id).to_proto(),
3625                    },
3626                )?;
3627            }
3628
3629            for notification in &updated_mention_notifications {
3630                session.peer.send(
3631                    connection,
3632                    proto::UpdateNotification {
3633                        notification: Some(notification.clone()),
3634                    },
3635                )?;
3636            }
3637
3638            Ok(())
3639        },
3640    );
3641
3642    send_notifications(pool, &session.peer, notifications);
3643
3644    Ok(())
3645}
3646
3647/// Mark a channel message as read
3648async fn acknowledge_channel_message(
3649    request: proto::AckChannelMessage,
3650    session: Session,
3651) -> Result<()> {
3652    let channel_id = ChannelId::from_proto(request.channel_id);
3653    let message_id = MessageId::from_proto(request.message_id);
3654    let notifications = session
3655        .db()
3656        .await
3657        .observe_channel_message(channel_id, session.user_id(), message_id)
3658        .await?;
3659    send_notifications(
3660        &*session.connection_pool().await,
3661        &session.peer,
3662        notifications,
3663    );
3664    Ok(())
3665}
3666
3667/// Mark a buffer version as synced
3668async fn acknowledge_buffer_version(
3669    request: proto::AckBufferOperation,
3670    session: Session,
3671) -> Result<()> {
3672    let buffer_id = BufferId::from_proto(request.buffer_id);
3673    session
3674        .db()
3675        .await
3676        .observe_buffer_version(
3677            buffer_id,
3678            session.user_id(),
3679            request.epoch as i32,
3680            &request.version,
3681        )
3682        .await?;
3683    Ok(())
3684}
3685
3686async fn count_language_model_tokens(
3687    request: proto::CountLanguageModelTokens,
3688    response: Response<proto::CountLanguageModelTokens>,
3689    session: Session,
3690    config: &Config,
3691) -> Result<()> {
3692    authorize_access_to_legacy_llm_endpoints(&session).await?;
3693
3694    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3695        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3696        proto::Plan::Free | proto::Plan::ZedProTrial => {
3697            Box::new(FreeCountLanguageModelTokensRateLimit)
3698        }
3699    };
3700
3701    session
3702        .app_state
3703        .rate_limiter
3704        .check(&*rate_limit, session.user_id())
3705        .await?;
3706
3707    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3708        Some(proto::LanguageModelProvider::Google) => {
3709            let api_key = config
3710                .google_ai_api_key
3711                .as_ref()
3712                .context("no Google AI API key configured on the server")?;
3713            google_ai::count_tokens(
3714                session.http_client.as_ref(),
3715                google_ai::API_URL,
3716                api_key,
3717                serde_json::from_str(&request.request)?,
3718            )
3719            .await?
3720        }
3721        _ => return Err(anyhow!("unsupported provider"))?,
3722    };
3723
3724    response.send(proto::CountLanguageModelTokensResponse {
3725        token_count: result.total_tokens as u32,
3726    })?;
3727
3728    Ok(())
3729}
3730
3731struct ZedProCountLanguageModelTokensRateLimit;
3732
3733impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3734    fn capacity(&self) -> usize {
3735        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3736            .ok()
3737            .and_then(|v| v.parse().ok())
3738            .unwrap_or(600) // Picked arbitrarily
3739    }
3740
3741    fn refill_duration(&self) -> chrono::Duration {
3742        chrono::Duration::hours(1)
3743    }
3744
3745    fn db_name(&self) -> &'static str {
3746        "zed-pro:count-language-model-tokens"
3747    }
3748}
3749
3750struct FreeCountLanguageModelTokensRateLimit;
3751
3752impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3753    fn capacity(&self) -> usize {
3754        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3755            .ok()
3756            .and_then(|v| v.parse().ok())
3757            .unwrap_or(600 / 10) // Picked arbitrarily
3758    }
3759
3760    fn refill_duration(&self) -> chrono::Duration {
3761        chrono::Duration::hours(1)
3762    }
3763
3764    fn db_name(&self) -> &'static str {
3765        "free:count-language-model-tokens"
3766    }
3767}
3768
3769/// This is leftover from before the LLM service.
3770///
3771/// The endpoints protected by this check will be moved there eventually.
3772async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3773    if session.is_staff() {
3774        Ok(())
3775    } else {
3776        Err(anyhow!("permission denied"))?
3777    }
3778}
3779
3780/// Get a Supermaven API key for the user
3781async fn get_supermaven_api_key(
3782    _request: proto::GetSupermavenApiKey,
3783    response: Response<proto::GetSupermavenApiKey>,
3784    session: Session,
3785) -> Result<()> {
3786    let user_id: String = session.user_id().to_string();
3787    if !session.is_staff() {
3788        return Err(anyhow!("supermaven not enabled for this account"))?;
3789    }
3790
3791    let email = session
3792        .email()
3793        .ok_or_else(|| anyhow!("user must have an email"))?;
3794
3795    let supermaven_admin_api = session
3796        .supermaven_client
3797        .as_ref()
3798        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3799
3800    let result = supermaven_admin_api
3801        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3802        .await?;
3803
3804    response.send(proto::GetSupermavenApiKeyResponse {
3805        api_key: result.api_key,
3806    })?;
3807
3808    Ok(())
3809}
3810
3811/// Start receiving chat updates for a channel
3812async fn join_channel_chat(
3813    request: proto::JoinChannelChat,
3814    response: Response<proto::JoinChannelChat>,
3815    session: Session,
3816) -> Result<()> {
3817    let channel_id = ChannelId::from_proto(request.channel_id);
3818
3819    let db = session.db().await;
3820    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3821        .await?;
3822    let messages = db
3823        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3824        .await?;
3825    response.send(proto::JoinChannelChatResponse {
3826        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3827        messages,
3828    })?;
3829    Ok(())
3830}
3831
3832/// Stop receiving chat updates for a channel
3833async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3834    let channel_id = ChannelId::from_proto(request.channel_id);
3835    session
3836        .db()
3837        .await
3838        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3839        .await?;
3840    Ok(())
3841}
3842
3843/// Retrieve the chat history for a channel
3844async fn get_channel_messages(
3845    request: proto::GetChannelMessages,
3846    response: Response<proto::GetChannelMessages>,
3847    session: Session,
3848) -> Result<()> {
3849    let channel_id = ChannelId::from_proto(request.channel_id);
3850    let messages = session
3851        .db()
3852        .await
3853        .get_channel_messages(
3854            channel_id,
3855            session.user_id(),
3856            MESSAGE_COUNT_PER_PAGE,
3857            Some(MessageId::from_proto(request.before_message_id)),
3858        )
3859        .await?;
3860    response.send(proto::GetChannelMessagesResponse {
3861        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3862        messages,
3863    })?;
3864    Ok(())
3865}
3866
3867/// Retrieve specific chat messages
3868async fn get_channel_messages_by_id(
3869    request: proto::GetChannelMessagesById,
3870    response: Response<proto::GetChannelMessagesById>,
3871    session: Session,
3872) -> Result<()> {
3873    let message_ids = request
3874        .message_ids
3875        .iter()
3876        .map(|id| MessageId::from_proto(*id))
3877        .collect::<Vec<_>>();
3878    let messages = session
3879        .db()
3880        .await
3881        .get_channel_messages_by_id(session.user_id(), &message_ids)
3882        .await?;
3883    response.send(proto::GetChannelMessagesResponse {
3884        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3885        messages,
3886    })?;
3887    Ok(())
3888}
3889
3890/// Retrieve the current users notifications
3891async fn get_notifications(
3892    request: proto::GetNotifications,
3893    response: Response<proto::GetNotifications>,
3894    session: Session,
3895) -> Result<()> {
3896    let notifications = session
3897        .db()
3898        .await
3899        .get_notifications(
3900            session.user_id(),
3901            NOTIFICATION_COUNT_PER_PAGE,
3902            request.before_id.map(db::NotificationId::from_proto),
3903        )
3904        .await?;
3905    response.send(proto::GetNotificationsResponse {
3906        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3907        notifications,
3908    })?;
3909    Ok(())
3910}
3911
3912/// Mark notifications as read
3913async fn mark_notification_as_read(
3914    request: proto::MarkNotificationRead,
3915    response: Response<proto::MarkNotificationRead>,
3916    session: Session,
3917) -> Result<()> {
3918    let database = &session.db().await;
3919    let notifications = database
3920        .mark_notification_as_read_by_id(
3921            session.user_id(),
3922            NotificationId::from_proto(request.notification_id),
3923        )
3924        .await?;
3925    send_notifications(
3926        &*session.connection_pool().await,
3927        &session.peer,
3928        notifications,
3929    );
3930    response.send(proto::Ack {})?;
3931    Ok(())
3932}
3933
3934/// Get the current users information
3935async fn get_private_user_info(
3936    _request: proto::GetPrivateUserInfo,
3937    response: Response<proto::GetPrivateUserInfo>,
3938    session: Session,
3939) -> Result<()> {
3940    let db = session.db().await;
3941
3942    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3943    let user = db
3944        .get_user_by_id(session.user_id())
3945        .await?
3946        .ok_or_else(|| anyhow!("user not found"))?;
3947    let flags = db.get_user_flags(session.user_id()).await?;
3948
3949    response.send(proto::GetPrivateUserInfoResponse {
3950        metrics_id,
3951        staff: user.admin,
3952        flags,
3953        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3954    })?;
3955    Ok(())
3956}
3957
3958/// Accept the terms of service (tos) on behalf of the current user
3959async fn accept_terms_of_service(
3960    _request: proto::AcceptTermsOfService,
3961    response: Response<proto::AcceptTermsOfService>,
3962    session: Session,
3963) -> Result<()> {
3964    let db = session.db().await;
3965
3966    let accepted_tos_at = Utc::now();
3967    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3968        .await?;
3969
3970    response.send(proto::AcceptTermsOfServiceResponse {
3971        accepted_tos_at: accepted_tos_at.timestamp() as u64,
3972    })?;
3973    Ok(())
3974}
3975
/// The minimum age an account must have before it may use the LLM service
/// (30 days).
///
/// NOTE(review): `get_llm_api_token` does not compare against this value
/// directly. Presumably it is enforced where LLM token claims are created or
/// checked; confirm.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3978
3979async fn get_llm_api_token(
3980    _request: proto::GetLlmToken,
3981    response: Response<proto::GetLlmToken>,
3982    session: Session,
3983) -> Result<()> {
3984    let db = session.db().await;
3985
3986    let flags = db.get_user_flags(session.user_id()).await?;
3987    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3988
3989    if !session.is_staff() && !has_language_models_feature_flag {
3990        Err(anyhow!("permission denied"))?
3991    }
3992
3993    let user_id = session.user_id();
3994    let user = db
3995        .get_user_by_id(user_id)
3996        .await?
3997        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3998
3999    if user.accepted_tos_at.is_none() {
4000        Err(anyhow!("terms of service not accepted"))?
4001    }
4002
4003    let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4004    let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4005    let billing_preferences = db.get_billing_preferences(user.id).await?;
4006
4007    let token = LlmTokenClaims::create(
4008        &user,
4009        session.is_staff(),
4010        billing_preferences,
4011        &flags,
4012        has_legacy_llm_subscription,
4013        billing_subscription,
4014        session.system_id.clone(),
4015        &session.app_state.config,
4016    )?;
4017    response.send(proto::GetLlmTokenResponse { token })?;
4018    Ok(())
4019}
4020
4021fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4022    let message = match message {
4023        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4024        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4025        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4026        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4027        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4028            code: frame.code.into(),
4029            reason: frame.reason.as_str().to_owned().into(),
4030        })),
4031        // We should never receive a frame while reading the message, according
4032        // to the `tungstenite` maintainers:
4033        //
4034        // > It cannot occur when you read messages from the WebSocket, but it
4035        // > can be used when you want to send the raw frames (e.g. you want to
4036        // > send the frames to the WebSocket without composing the full message first).
4037        // >
4038        // > — https://github.com/snapview/tungstenite-rs/issues/268
4039        TungsteniteMessage::Frame(_) => {
4040            bail!("received an unexpected frame while reading the message")
4041        }
4042    };
4043
4044    Ok(message)
4045}
4046
4047fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4048    match message {
4049        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4050        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4051        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4052        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4053        AxumMessage::Close(frame) => {
4054            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4055                code: frame.code.into(),
4056                reason: frame.reason.as_ref().into(),
4057            }))
4058        }
4059    }
4060}
4061
4062fn notify_membership_updated(
4063    connection_pool: &mut ConnectionPool,
4064    result: MembershipUpdated,
4065    user_id: UserId,
4066    peer: &Peer,
4067) {
4068    for membership in &result.new_channels.channel_memberships {
4069        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4070    }
4071    for channel_id in &result.removed_channels {
4072        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4073    }
4074
4075    let user_channels_update = proto::UpdateUserChannels {
4076        channel_memberships: result
4077            .new_channels
4078            .channel_memberships
4079            .iter()
4080            .map(|cm| proto::ChannelMembership {
4081                channel_id: cm.channel_id.to_proto(),
4082                role: cm.role.into(),
4083            })
4084            .collect(),
4085        ..Default::default()
4086    };
4087
4088    let mut update = build_channels_update(result.new_channels);
4089    update.delete_channels = result
4090        .removed_channels
4091        .into_iter()
4092        .map(|id| id.to_proto())
4093        .collect();
4094    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4095
4096    for connection_id in connection_pool.user_connection_ids(user_id) {
4097        peer.send(connection_id, user_channels_update.clone())
4098            .trace_err();
4099        peer.send(connection_id, update.clone()).trace_err();
4100    }
4101}
4102
4103fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4104    proto::UpdateUserChannels {
4105        channel_memberships: channels
4106            .channel_memberships
4107            .iter()
4108            .map(|m| proto::ChannelMembership {
4109                channel_id: m.channel_id.to_proto(),
4110                role: m.role.into(),
4111            })
4112            .collect(),
4113        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4114        observed_channel_message_id: channels.observed_channel_messages.clone(),
4115    }
4116}
4117
4118fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4119    let mut update = proto::UpdateChannels::default();
4120
4121    for channel in channels.channels {
4122        update.channels.push(channel.to_proto());
4123    }
4124
4125    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4126    update.latest_channel_message_ids = channels.latest_channel_messages;
4127
4128    for (channel_id, participants) in channels.channel_participants {
4129        update
4130            .channel_participants
4131            .push(proto::ChannelParticipants {
4132                channel_id: channel_id.to_proto(),
4133                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4134            });
4135    }
4136
4137    for channel in channels.invited_channels {
4138        update.channel_invitations.push(channel.to_proto());
4139    }
4140
4141    update
4142}
4143
4144fn build_initial_contacts_update(
4145    contacts: Vec<db::Contact>,
4146    pool: &ConnectionPool,
4147) -> proto::UpdateContacts {
4148    let mut update = proto::UpdateContacts::default();
4149
4150    for contact in contacts {
4151        match contact {
4152            db::Contact::Accepted { user_id, busy } => {
4153                update.contacts.push(contact_for_user(user_id, busy, pool));
4154            }
4155            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4156            db::Contact::Incoming { user_id } => {
4157                update
4158                    .incoming_requests
4159                    .push(proto::IncomingContactRequest {
4160                        requester_id: user_id.to_proto(),
4161                    })
4162            }
4163        }
4164    }
4165
4166    update
4167}
4168
4169fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4170    proto::Contact {
4171        user_id: user_id.to_proto(),
4172        online: pool.is_user_online(user_id),
4173        busy,
4174    }
4175}
4176
4177fn room_updated(room: &proto::Room, peer: &Peer) {
4178    broadcast(
4179        None,
4180        room.participants
4181            .iter()
4182            .filter_map(|participant| Some(participant.peer_id?.into())),
4183        |peer_id| {
4184            peer.send(
4185                peer_id,
4186                proto::RoomUpdated {
4187                    room: Some(room.clone()),
4188                },
4189            )
4190        },
4191    );
4192}
4193
4194fn channel_updated(
4195    channel: &db::channel::Model,
4196    room: &proto::Room,
4197    peer: &Peer,
4198    pool: &ConnectionPool,
4199) {
4200    let participants = room
4201        .participants
4202        .iter()
4203        .map(|p| p.user_id)
4204        .collect::<Vec<_>>();
4205
4206    broadcast(
4207        None,
4208        pool.channel_connection_ids(channel.root_id())
4209            .filter_map(|(channel_id, role)| {
4210                role.can_see_channel(channel.visibility)
4211                    .then_some(channel_id)
4212            }),
4213        |peer_id| {
4214            peer.send(
4215                peer_id,
4216                proto::UpdateChannels {
4217                    channel_participants: vec![proto::ChannelParticipants {
4218                        channel_id: channel.id.to_proto(),
4219                        participant_user_ids: participants.clone(),
4220                    }],
4221                    ..Default::default()
4222                },
4223            )
4224        },
4225    );
4226}
4227
4228async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4229    let db = session.db().await;
4230
4231    let contacts = db.get_contacts(user_id).await?;
4232    let busy = db.is_user_busy(user_id).await?;
4233
4234    let pool = session.connection_pool().await;
4235    let updated_contact = contact_for_user(user_id, busy, &pool);
4236    for contact in contacts {
4237        if let db::Contact::Accepted {
4238            user_id: contact_user_id,
4239            ..
4240        } = contact
4241        {
4242            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4243                session
4244                    .peer
4245                    .send(
4246                        contact_conn_id,
4247                        proto::UpdateContacts {
4248                            contacts: vec![updated_contact.clone()],
4249                            remove_contacts: Default::default(),
4250                            incoming_requests: Default::default(),
4251                            remove_incoming_requests: Default::default(),
4252                            outgoing_requests: Default::default(),
4253                            remove_outgoing_requests: Default::default(),
4254                        },
4255                    )
4256                    .trace_err();
4257            }
4258        }
4259    }
4260    Ok(())
4261}
4262
4263async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4264    let mut contacts_to_update = HashSet::default();
4265
4266    let room_id;
4267    let canceled_calls_to_user_ids;
4268    let livekit_room;
4269    let delete_livekit_room;
4270    let room;
4271    let channel;
4272
4273    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4274        contacts_to_update.insert(session.user_id());
4275
4276        for project in left_room.left_projects.values() {
4277            project_left(project, session);
4278        }
4279
4280        room_id = RoomId::from_proto(left_room.room.id);
4281        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4282        livekit_room = mem::take(&mut left_room.room.livekit_room);
4283        delete_livekit_room = left_room.deleted;
4284        room = mem::take(&mut left_room.room);
4285        channel = mem::take(&mut left_room.channel);
4286
4287        room_updated(&room, &session.peer);
4288    } else {
4289        return Ok(());
4290    }
4291
4292    if let Some(channel) = channel {
4293        channel_updated(
4294            &channel,
4295            &room,
4296            &session.peer,
4297            &*session.connection_pool().await,
4298        );
4299    }
4300
4301    {
4302        let pool = session.connection_pool().await;
4303        for canceled_user_id in canceled_calls_to_user_ids {
4304            for connection_id in pool.user_connection_ids(canceled_user_id) {
4305                session
4306                    .peer
4307                    .send(
4308                        connection_id,
4309                        proto::CallCanceled {
4310                            room_id: room_id.to_proto(),
4311                        },
4312                    )
4313                    .trace_err();
4314            }
4315            contacts_to_update.insert(canceled_user_id);
4316        }
4317    }
4318
4319    for contact_user_id in contacts_to_update {
4320        update_user_contacts(contact_user_id, session).await?;
4321    }
4322
4323    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4324        live_kit
4325            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4326            .await
4327            .trace_err();
4328
4329        if delete_livekit_room {
4330            live_kit.delete_room(livekit_room).await.trace_err();
4331        }
4332    }
4333
4334    Ok(())
4335}
4336
4337async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4338    let left_channel_buffers = session
4339        .db()
4340        .await
4341        .leave_channel_buffers(session.connection_id)
4342        .await?;
4343
4344    for left_buffer in left_channel_buffers {
4345        channel_buffer_updated(
4346            session.connection_id,
4347            left_buffer.connections,
4348            &proto::UpdateChannelBufferCollaborators {
4349                channel_id: left_buffer.channel_id.to_proto(),
4350                collaborators: left_buffer.collaborators,
4351            },
4352            &session.peer,
4353        );
4354    }
4355
4356    Ok(())
4357}
4358
4359fn project_left(project: &db::LeftProject, session: &Session) {
4360    for connection_id in &project.connection_ids {
4361        if project.should_unshare {
4362            session
4363                .peer
4364                .send(
4365                    *connection_id,
4366                    proto::UnshareProject {
4367                        project_id: project.id.to_proto(),
4368                    },
4369                )
4370                .trace_err();
4371        } else {
4372            session
4373                .peer
4374                .send(
4375                    *connection_id,
4376                    proto::RemoveProjectCollaborator {
4377                        project_id: project.id.to_proto(),
4378                        peer_id: Some(session.connection_id.into()),
4379                    },
4380                )
4381                .trace_err();
4382        }
4383    }
4384}
4385
4386pub trait ResultExt {
4387    type Ok;
4388
4389    fn trace_err(self) -> Option<Self::Ok>;
4390}
4391
4392impl<T, E> ResultExt for Result<T, E>
4393where
4394    E: std::fmt::Debug,
4395{
4396    type Ok = T;
4397
4398    #[track_caller]
4399    fn trace_err(self) -> Option<T> {
4400        match self {
4401            Ok(value) => Some(value),
4402            Err(error) => {
4403                tracing::error!("{:?}", error);
4404                None
4405            }
4406        }
4407    }
4408}