token.rs

  1use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
  2use crate::{
  3    db::{billing_preference, UserId},
  4    Config,
  5};
  6use anyhow::{anyhow, Result};
  7use chrono::Utc;
  8use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
  9use serde::{Deserialize, Serialize};
 10use std::time::Duration;
 11use thiserror::Error;
 12
 13#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 14#[serde(rename_all = "camelCase")]
 15pub struct LlmTokenClaims {
 16    pub iat: u64,
 17    pub exp: u64,
 18    pub jti: String,
 19    pub user_id: u64,
 20    pub github_user_login: String,
 21    pub is_staff: bool,
 22    pub has_llm_closed_beta_feature_flag: bool,
 23    pub has_llm_subscription: bool,
 24    pub max_monthly_spend_in_cents: u32,
 25    pub plan: rpc::proto::Plan,
 26}
 27
 /// Offset added to a token's issue time to obtain its `exp` claim: one hour.
 const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
 29
 30impl LlmTokenClaims {
 31    #[allow(clippy::too_many_arguments)]
 32    pub fn create(
 33        user_id: UserId,
 34        github_user_login: String,
 35        is_staff: bool,
 36        billing_preferences: Option<billing_preference::Model>,
 37        has_llm_closed_beta_feature_flag: bool,
 38        has_llm_subscription: bool,
 39        plan: rpc::proto::Plan,
 40        config: &Config,
 41    ) -> Result<String> {
 42        let secret = config
 43            .llm_api_secret
 44            .as_ref()
 45            .ok_or_else(|| anyhow!("no LLM API secret"))?;
 46
 47        let now = Utc::now();
 48        let claims = Self {
 49            iat: now.timestamp() as u64,
 50            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
 51            jti: uuid::Uuid::new_v4().to_string(),
 52            user_id: user_id.to_proto(),
 53            github_user_login,
 54            is_staff,
 55            has_llm_closed_beta_feature_flag,
 56            has_llm_subscription,
 57            max_monthly_spend_in_cents: billing_preferences
 58                .map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
 59                    preferences.max_monthly_llm_usage_spending_in_cents as u32
 60                }),
 61            plan,
 62        };
 63
 64        Ok(jsonwebtoken::encode(
 65            &Header::default(),
 66            &claims,
 67            &EncodingKey::from_secret(secret.as_ref()),
 68        )?)
 69    }
 70
 71    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
 72        let secret = config
 73            .llm_api_secret
 74            .as_ref()
 75            .ok_or_else(|| anyhow!("no LLM API secret"))?;
 76
 77        match jsonwebtoken::decode::<Self>(
 78            token,
 79            &DecodingKey::from_secret(secret.as_ref()),
 80            &Validation::default(),
 81        ) {
 82            Ok(token) => Ok(token.claims),
 83            Err(e) => {
 84                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
 85                    Err(ValidateLlmTokenError::Expired)
 86                } else {
 87                    Err(ValidateLlmTokenError::JwtError(e))
 88                }
 89            }
 90        }
 91    }
 92}
 93
 94#[derive(Error, Debug)]
 95pub enum ValidateLlmTokenError {
 96    #[error("access token is expired")]
 97    Expired,
 98    #[error("access token validation error: {0}")]
 99    JwtError(#[from] jsonwebtoken::errors::Error),
100    #[error("{0}")]
101    Other(#[from] anyhow::Error),
102}