token.rs

  1use crate::db::user;
  2use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
  3use crate::Cents;
  4use crate::{db::billing_preference, Config};
  5use anyhow::{anyhow, Result};
  6use chrono::Utc;
  7use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
  8use serde::{Deserialize, Serialize};
  9use std::time::Duration;
 10use thiserror::Error;
 11use uuid::Uuid;
 12
/// Claims carried by the JWT that [`LlmTokenClaims::create`] issues and
/// [`LlmTokenClaims::validate`] decodes.
///
/// Field names are serialized in camelCase (see the `serde` attribute below).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
    /// Issued-at time; `create` sets it to the current UTC Unix timestamp.
    pub iat: u64,
    /// Expiry time; `create` sets it to the issue time plus `LLM_TOKEN_LIFETIME`.
    pub exp: u64,
    /// Token identifier; `create` fills it with a freshly generated v4 UUID string.
    pub jti: String,
    /// Identifier of the user the token was issued for (`user.id.to_proto()`).
    pub user_id: u64,
    /// Optional system identifier, passed through unchanged from the caller of `create`.
    pub system_id: Option<String>,
    /// The user's metrics id, copied from the user record.
    pub metrics_id: Uuid,
    /// The user's GitHub login, copied from the user record.
    pub github_user_login: String,
    /// Staff flag, supplied by the caller of `create`.
    pub is_staff: bool,
    /// LLM closed-beta feature flag, supplied by the caller of `create`.
    pub has_llm_closed_beta_feature_flag: bool,
    /// LLM subscription flag, supplied by the caller of `create`.
    pub has_llm_subscription: bool,
    /// Maximum monthly spend in cents. `create` takes it from the billing
    /// preferences when present and otherwise uses `DEFAULT_MAX_MONTHLY_SPEND`.
    pub max_monthly_spend_in_cents: u32,
    /// Per-user allowance in cents, copied from the user record. When set,
    /// `free_tier_monthly_spending_limit` returns it instead of
    /// `FREE_TIER_MONTHLY_SPENDING_LIMIT`.
    pub custom_llm_monthly_allowance_in_cents: Option<u32>,
    /// The plan supplied by the caller of `create`.
    pub plan: rpc::proto::Plan,
}
 30
 31const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
 32
 33impl LlmTokenClaims {
 34    #[allow(clippy::too_many_arguments)]
 35    pub fn create(
 36        user: &user::Model,
 37        is_staff: bool,
 38        billing_preferences: Option<billing_preference::Model>,
 39        has_llm_closed_beta_feature_flag: bool,
 40        has_llm_subscription: bool,
 41        plan: rpc::proto::Plan,
 42        system_id: Option<String>,
 43        config: &Config,
 44    ) -> Result<String> {
 45        let secret = config
 46            .llm_api_secret
 47            .as_ref()
 48            .ok_or_else(|| anyhow!("no LLM API secret"))?;
 49
 50        let now = Utc::now();
 51        let claims = Self {
 52            iat: now.timestamp() as u64,
 53            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
 54            jti: uuid::Uuid::new_v4().to_string(),
 55            user_id: user.id.to_proto(),
 56            system_id,
 57            metrics_id: user.metrics_id,
 58            github_user_login: user.github_login.clone(),
 59            is_staff,
 60            has_llm_closed_beta_feature_flag,
 61            has_llm_subscription,
 62            max_monthly_spend_in_cents: billing_preferences
 63                .map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
 64                    preferences.max_monthly_llm_usage_spending_in_cents as u32
 65                }),
 66            custom_llm_monthly_allowance_in_cents: user
 67                .custom_llm_monthly_allowance_in_cents
 68                .map(|allowance| allowance as u32),
 69            plan,
 70        };
 71
 72        Ok(jsonwebtoken::encode(
 73            &Header::default(),
 74            &claims,
 75            &EncodingKey::from_secret(secret.as_ref()),
 76        )?)
 77    }
 78
 79    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
 80        let secret = config
 81            .llm_api_secret
 82            .as_ref()
 83            .ok_or_else(|| anyhow!("no LLM API secret"))?;
 84
 85        match jsonwebtoken::decode::<Self>(
 86            token,
 87            &DecodingKey::from_secret(secret.as_ref()),
 88            &Validation::default(),
 89        ) {
 90            Ok(token) => Ok(token.claims),
 91            Err(e) => {
 92                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
 93                    Err(ValidateLlmTokenError::Expired)
 94                } else {
 95                    Err(ValidateLlmTokenError::JwtError(e))
 96                }
 97            }
 98        }
 99    }
100
101    pub fn free_tier_monthly_spending_limit(&self) -> Cents {
102        self.custom_llm_monthly_allowance_in_cents
103            .map(Cents)
104            .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT)
105    }
106}
107
/// Errors returned by [`LlmTokenClaims::validate`].
#[derive(Error, Debug)]
pub enum ValidateLlmTokenError {
    /// Decoding failed with `jsonwebtoken`'s `ErrorKind::ExpiredSignature`.
    #[error("access token is expired")]
    Expired,
    /// Any other error reported by `jsonwebtoken` while decoding the token.
    ///
    /// Because of `#[from]`, converting a `jsonwebtoken` error with `?` yields
    /// this variant even for an expired signature; `validate` maps expiry to
    /// `Expired` explicitly instead.
    #[error("access token validation error: {0}")]
    JwtError(#[from] jsonwebtoken::errors::Error),
    /// A failure that did not come from JWT decoding; `validate` produces it
    /// when no LLM API secret is configured.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}