Remove RPC messages pertaining to the LLM token (#36252)

Created by Marshall Bowers

This PR removes the RPC messages pertaining to the LLM token (`GetLlmToken`, `GetLlmTokenResponse`, and `RefreshLlmToken`), along with the collab-side JWT minting code that backed them.

We now retrieve the LLM token from Cloud instead of over the collab RPC connection.
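
For context, a minimal sketch of what fetching the token over HTTP from Cloud could look like, instead of issuing a `GetLlmToken` RPC on the collab connection. The endpoint path, response shape, and auth header here are illustrative assumptions, not Zed's actual Cloud API surface (the real client goes through its own cloud API client):

```rust
// Hypothetical sketch only: the `/client/llm_tokens` path and the
// `LlmTokenResponse` shape are assumptions for illustration.
use anyhow::Result;
use serde::Deserialize;

#[derive(Deserialize)]
struct LlmTokenResponse {
    token: String,
}

async fn fetch_llm_token(cloud_base_url: &str, access_token: &str) -> Result<String> {
    let client = reqwest::Client::new();
    let response = client
        .get(format!("{cloud_base_url}/client/llm_tokens"))
        // Authenticate against Cloud rather than the collab RPC session.
        .bearer_auth(access_token)
        .send()
        .await?
        .error_for_status()?
        .json::<LlmTokenResponse>()
        .await?;
    Ok(response.token)
}

#[tokio::main]
async fn main() -> Result<()> {
    // Illustrative usage; in the real client this is driven by the existing
    // Cloud session rather than a standalone request.
    let token = fetch_llm_token("https://cloud.zed.dev", "<access token>").await?;
    println!("received LLM token ({} bytes)", token.len());
    Ok(())
}
```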

Release Notes:

- N/A

Change summary

Cargo.lock                     |   2 
crates/collab/Cargo.toml       |   2 
crates/collab/src/llm.rs       |   3 
crates/collab/src/llm/token.rs | 146 ------------------------------------
crates/collab/src/rpc.rs       |  96 -----------------------
crates/proto/proto/ai.proto    |   8 -
crates/proto/proto/zed.proto   |   6 
crates/proto/src/proto.rs      |   4 
8 files changed, 4 insertions(+), 263 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3324,7 +3324,6 @@ dependencies = [
  "http_client",
  "hyper 0.14.32",
  "indoc",
- "jsonwebtoken",
  "language",
  "language_model",
  "livekit_api",
@@ -3370,7 +3369,6 @@ dependencies = [
  "telemetry_events",
  "text",
  "theme",
- "thiserror 2.0.12",
  "time",
  "tokio",
  "toml 0.8.20",

crates/collab/Cargo.toml 🔗

@@ -39,7 +39,6 @@ futures.workspace = true
 gpui.workspace = true
 hex.workspace = true
 http_client.workspace = true
-jsonwebtoken.workspace = true
 livekit_api.workspace = true
 log.workspace = true
 nanoid.workspace = true
@@ -65,7 +64,6 @@ subtle.workspace = true
 supermaven_api.workspace = true
 telemetry_events.workspace = true
 text.workspace = true
-thiserror.workspace = true
 time.workspace = true
 tokio = { workspace = true, features = ["full"] }
 toml.workspace = true

crates/collab/src/llm.rs 🔗

@@ -1,7 +1,4 @@
 pub mod db;
-mod token;
-
-pub use token::*;
 
 pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
 

crates/collab/src/llm/token.rs 🔗

@@ -1,146 +0,0 @@
-use crate::db::billing_subscription::SubscriptionKind;
-use crate::db::{billing_customer, billing_subscription, user};
-use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
-use crate::{Config, db::billing_preference};
-use anyhow::{Context as _, Result};
-use chrono::{NaiveDateTime, Utc};
-use cloud_llm_client::Plan;
-use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
-use serde::{Deserialize, Serialize};
-use std::time::Duration;
-use thiserror::Error;
-use uuid::Uuid;
-
-#[derive(Clone, Debug, Default, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct LlmTokenClaims {
-    pub iat: u64,
-    pub exp: u64,
-    pub jti: String,
-    pub user_id: u64,
-    pub system_id: Option<String>,
-    pub metrics_id: Uuid,
-    pub github_user_login: String,
-    pub account_created_at: NaiveDateTime,
-    pub is_staff: bool,
-    pub has_llm_closed_beta_feature_flag: bool,
-    pub bypass_account_age_check: bool,
-    pub use_llm_request_queue: bool,
-    pub plan: Plan,
-    pub has_extended_trial: bool,
-    pub subscription_period: (NaiveDateTime, NaiveDateTime),
-    pub enable_model_request_overages: bool,
-    pub model_request_overages_spend_limit_in_cents: u32,
-    pub can_use_web_search_tool: bool,
-    #[serde(default)]
-    pub has_overdue_invoices: bool,
-}
-
-const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
-
-impl LlmTokenClaims {
-    pub fn create(
-        user: &user::Model,
-        is_staff: bool,
-        billing_customer: billing_customer::Model,
-        billing_preferences: Option<billing_preference::Model>,
-        feature_flags: &Vec<String>,
-        subscription: billing_subscription::Model,
-        system_id: Option<String>,
-        config: &Config,
-    ) -> Result<String> {
-        let secret = config
-            .llm_api_secret
-            .as_ref()
-            .context("no LLM API secret")?;
-
-        let plan = if is_staff {
-            Plan::ZedPro
-        } else {
-            subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
-                SubscriptionKind::ZedFree => Plan::ZedFree,
-                SubscriptionKind::ZedPro => Plan::ZedPro,
-                SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
-            })
-        };
-        let subscription_period =
-            billing_subscription::Model::current_period(Some(subscription), is_staff)
-                .map(|(start, end)| (start.naive_utc(), end.naive_utc()))
-                .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
-
-        let now = Utc::now();
-        let claims = Self {
-            iat: now.timestamp() as u64,
-            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
-            jti: uuid::Uuid::new_v4().to_string(),
-            user_id: user.id.to_proto(),
-            system_id,
-            metrics_id: user.metrics_id,
-            github_user_login: user.github_login.clone(),
-            account_created_at: user.account_created_at(),
-            is_staff,
-            has_llm_closed_beta_feature_flag: feature_flags
-                .iter()
-                .any(|flag| flag == "llm-closed-beta"),
-            bypass_account_age_check: feature_flags
-                .iter()
-                .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
-            can_use_web_search_tool: true,
-            use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
-            plan,
-            has_extended_trial: feature_flags
-                .iter()
-                .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
-            subscription_period,
-            enable_model_request_overages: billing_preferences
-                .as_ref()
-                .map_or(false, |preferences| {
-                    preferences.model_request_overages_enabled
-                }),
-            model_request_overages_spend_limit_in_cents: billing_preferences
-                .as_ref()
-                .map_or(0, |preferences| {
-                    preferences.model_request_overages_spend_limit_in_cents as u32
-                }),
-            has_overdue_invoices: billing_customer.has_overdue_invoices,
-        };
-
-        Ok(jsonwebtoken::encode(
-            &Header::default(),
-            &claims,
-            &EncodingKey::from_secret(secret.as_ref()),
-        )?)
-    }
-
-    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
-        let secret = config
-            .llm_api_secret
-            .as_ref()
-            .context("no LLM API secret")?;
-
-        match jsonwebtoken::decode::<Self>(
-            token,
-            &DecodingKey::from_secret(secret.as_ref()),
-            &Validation::default(),
-        ) {
-            Ok(token) => Ok(token.claims),
-            Err(e) => {
-                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
-                    Err(ValidateLlmTokenError::Expired)
-                } else {
-                    Err(ValidateLlmTokenError::JwtError(e))
-                }
-            }
-        }
-    }
-}
-
-#[derive(Error, Debug)]
-pub enum ValidateLlmTokenError {
-    #[error("access token is expired")]
-    Expired,
-    #[error("access token validation error: {0}")]
-    JwtError(#[from] jsonwebtoken::errors::Error),
-    #[error("{0}")]
-    Other(#[from] anyhow::Error),
-}

crates/collab/src/rpc.rs 🔗

@@ -1,14 +1,12 @@
 mod connection_pool;
 
-use crate::api::billing::find_or_create_billing_customer;
 use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
 use crate::db::billing_subscription::SubscriptionKind;
 use crate::llm::db::LlmDatabase;
 use crate::llm::{
-    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
+    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
     MIN_ACCOUNT_AGE_FOR_LLM_USE,
 };
-use crate::stripe_client::StripeCustomerId;
 use crate::{
     AppState, Error, Result, auth,
     db::{
@@ -218,6 +216,7 @@ struct Session {
     /// The GeoIP country code for the user.
     #[allow(unused)]
     geoip_country_code: Option<String>,
+    #[allow(unused)]
     system_id: Option<String>,
     _executor: Executor,
 }
@@ -464,7 +463,6 @@ impl Server {
             .add_message_handler(unfollow)
             .add_message_handler(update_followers)
             .add_request_handler(get_private_user_info)
-            .add_request_handler(get_llm_api_token)
             .add_request_handler(accept_terms_of_service)
             .add_message_handler(acknowledge_channel_message)
             .add_message_handler(acknowledge_buffer_version)
@@ -4251,96 +4249,6 @@ async fn accept_terms_of_service(
         accepted_tos_at: accepted_tos_at.timestamp() as u64,
     })?;
 
-    // When the user accepts the terms of service, we want to refresh their LLM
-    // token to grant access.
-    session
-        .peer
-        .send(session.connection_id, proto::RefreshLlmToken {})?;
-
-    Ok(())
-}
-
-async fn get_llm_api_token(
-    _request: proto::GetLlmToken,
-    response: Response<proto::GetLlmToken>,
-    session: MessageContext,
-) -> Result<()> {
-    let db = session.db().await;
-
-    let flags = db.get_user_flags(session.user_id()).await?;
-
-    let user_id = session.user_id();
-    let user = db
-        .get_user_by_id(user_id)
-        .await?
-        .with_context(|| format!("user {user_id} not found"))?;
-
-    if user.accepted_tos_at.is_none() {
-        Err(anyhow!("terms of service not accepted"))?
-    }
-
-    let stripe_client = session
-        .app_state
-        .stripe_client
-        .as_ref()
-        .context("failed to retrieve Stripe client")?;
-
-    let stripe_billing = session
-        .app_state
-        .stripe_billing
-        .as_ref()
-        .context("failed to retrieve Stripe billing object")?;
-
-    let billing_customer = if let Some(billing_customer) =
-        db.get_billing_customer_by_user_id(user.id).await?
-    {
-        billing_customer
-    } else {
-        let customer_id = stripe_billing
-            .find_or_create_customer_by_email(user.email_address.as_deref())
-            .await?;
-
-        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
-            .await?
-            .context("billing customer not found")?
-    };
-
-    let billing_subscription =
-        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
-            billing_subscription
-        } else {
-            let stripe_customer_id =
-                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
-
-            let stripe_subscription = stripe_billing
-                .subscribe_to_zed_free(stripe_customer_id)
-                .await?;
-
-            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
-                billing_customer_id: billing_customer.id,
-                kind: Some(SubscriptionKind::ZedFree),
-                stripe_subscription_id: stripe_subscription.id.to_string(),
-                stripe_subscription_status: stripe_subscription.status.into(),
-                stripe_cancellation_reason: None,
-                stripe_current_period_start: Some(stripe_subscription.current_period_start),
-                stripe_current_period_end: Some(stripe_subscription.current_period_end),
-            })
-            .await?
-        };
-
-    let billing_preferences = db.get_billing_preferences(user.id).await?;
-
-    let token = LlmTokenClaims::create(
-        &user,
-        session.is_staff(),
-        billing_customer,
-        billing_preferences,
-        &flags,
-        billing_subscription,
-        session.system_id.clone(),
-        &session.app_state.config,
-    )?;
-    response.send(proto::GetLlmTokenResponse { token })?;
     Ok(())
 }
 

crates/proto/proto/ai.proto 🔗

@@ -158,14 +158,6 @@ message SynchronizeContextsResponse {
     repeated ContextVersion contexts = 1;
 }
 
-message GetLlmToken {}
-
-message GetLlmTokenResponse {
-    string token = 1;
-}
-
-message RefreshLlmToken {}
-
 enum LanguageModelRole {
     LanguageModelUser = 0;
     LanguageModelAssistant = 1;

crates/proto/proto/zed.proto 🔗

@@ -250,10 +250,6 @@ message Envelope {
         AddWorktree add_worktree = 222;
         AddWorktreeResponse add_worktree_response = 223;
 
-        GetLlmToken get_llm_token = 235;
-        GetLlmTokenResponse get_llm_token_response = 236;
-        RefreshLlmToken refresh_llm_token = 259;
-
         LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
         LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
 
@@ -419,7 +415,9 @@ message Envelope {
     reserved 221;
     reserved 224 to 229;
     reserved 230 to 231;
+    reserved 235 to 236;
     reserved 246;
+    reserved 259;
     reserved 270;
     reserved 247 to 254;
     reserved 255 to 256;

crates/proto/src/proto.rs 🔗

@@ -119,8 +119,6 @@ messages!(
     (GetTypeDefinitionResponse, Background),
     (GetImplementation, Background),
     (GetImplementationResponse, Background),
-    (GetLlmToken, Background),
-    (GetLlmTokenResponse, Background),
     (OpenUnstagedDiff, Foreground),
     (OpenUnstagedDiffResponse, Foreground),
     (OpenUncommittedDiff, Foreground),
@@ -196,7 +194,6 @@ messages!(
     (PrepareRenameResponse, Background),
     (ProjectEntryResponse, Foreground),
     (RefreshInlayHints, Foreground),
-    (RefreshLlmToken, Background),
     (RegisterBufferWithLanguageServers, Background),
     (RejoinChannelBuffers, Foreground),
     (RejoinChannelBuffersResponse, Foreground),
@@ -354,7 +351,6 @@ request_messages!(
     (GetDocumentHighlights, GetDocumentHighlightsResponse),
     (GetDocumentSymbols, GetDocumentSymbolsResponse),
     (GetHover, GetHoverResponse),
-    (GetLlmToken, GetLlmTokenResponse),
     (GetNotifications, GetNotificationsResponse),
     (GetPrivateUserInfo, GetPrivateUserInfoResponse),
     (GetProjectSymbols, GetProjectSymbolsResponse),