token.rs

  1use crate::db::billing_subscription::SubscriptionKind;
  2use crate::db::{billing_subscription, user};
  3use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
  4use crate::{Config, db::billing_preference};
  5use anyhow::{Result, anyhow};
  6use chrono::{NaiveDateTime, Utc};
  7use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
  8use serde::{Deserialize, Serialize};
  9use std::time::Duration;
 10use thiserror::Error;
 11use uuid::Uuid;
 12use zed_llm_client::Plan;
 13
 14#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 15#[serde(rename_all = "camelCase")]
 16pub struct LlmTokenClaims {
 17    pub iat: u64,
 18    pub exp: u64,
 19    pub jti: String,
 20    pub user_id: u64,
 21    pub system_id: Option<String>,
 22    pub metrics_id: Uuid,
 23    pub github_user_login: String,
 24    pub account_created_at: NaiveDateTime,
 25    pub is_staff: bool,
 26    pub has_llm_closed_beta_feature_flag: bool,
 27    pub bypass_account_age_check: bool,
 28    #[serde(default)]
 29    pub use_llm_request_queue: bool,
 30    pub plan: Plan,
 31    #[serde(default)]
 32    pub has_extended_trial: bool,
 33    #[serde(default)]
 34    pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
 35    #[serde(default)]
 36    pub enable_model_request_overages: bool,
 37    #[serde(default)]
 38    pub model_request_overages_spend_limit_in_cents: u32,
 39    #[serde(default)]
 40    pub can_use_web_search_tool: bool,
 41}
 42
 43const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
 44
 45impl LlmTokenClaims {
 46    pub fn create(
 47        user: &user::Model,
 48        is_staff: bool,
 49        billing_preferences: Option<billing_preference::Model>,
 50        feature_flags: &Vec<String>,
 51        subscription: Option<billing_subscription::Model>,
 52        system_id: Option<String>,
 53        config: &Config,
 54    ) -> Result<String> {
 55        let secret = config
 56            .llm_api_secret
 57            .as_ref()
 58            .ok_or_else(|| anyhow!("no LLM API secret"))?;
 59
 60        let now = Utc::now();
 61        let claims = Self {
 62            iat: now.timestamp() as u64,
 63            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
 64            jti: uuid::Uuid::new_v4().to_string(),
 65            user_id: user.id.to_proto(),
 66            system_id,
 67            metrics_id: user.metrics_id,
 68            github_user_login: user.github_login.clone(),
 69            account_created_at: user.account_created_at(),
 70            is_staff,
 71            has_llm_closed_beta_feature_flag: feature_flags
 72                .iter()
 73                .any(|flag| flag == "llm-closed-beta"),
 74            bypass_account_age_check: feature_flags
 75                .iter()
 76                .any(|flag| flag == "bypass-account-age-check"),
 77            can_use_web_search_tool: true,
 78            use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
 79            plan: if is_staff {
 80                Plan::ZedPro
 81            } else {
 82                subscription
 83                    .as_ref()
 84                    .and_then(|subscription| subscription.kind)
 85                    .map_or(Plan::Free, |kind| match kind {
 86                        SubscriptionKind::ZedFree => Plan::Free,
 87                        SubscriptionKind::ZedPro => Plan::ZedPro,
 88                        SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
 89                    })
 90            },
 91            has_extended_trial: feature_flags
 92                .iter()
 93                .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
 94            subscription_period: billing_subscription::Model::current_period(
 95                subscription,
 96                is_staff,
 97            )
 98            .map(|(start, end)| (start.naive_utc(), end.naive_utc())),
 99            enable_model_request_overages: billing_preferences
100                .as_ref()
101                .map_or(false, |preferences| {
102                    preferences.model_request_overages_enabled
103                }),
104            model_request_overages_spend_limit_in_cents: billing_preferences
105                .as_ref()
106                .map_or(0, |preferences| {
107                    preferences.model_request_overages_spend_limit_in_cents as u32
108                }),
109        };
110
111        Ok(jsonwebtoken::encode(
112            &Header::default(),
113            &claims,
114            &EncodingKey::from_secret(secret.as_ref()),
115        )?)
116    }
117
118    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
119        let secret = config
120            .llm_api_secret
121            .as_ref()
122            .ok_or_else(|| anyhow!("no LLM API secret"))?;
123
124        match jsonwebtoken::decode::<Self>(
125            token,
126            &DecodingKey::from_secret(secret.as_ref()),
127            &Validation::default(),
128        ) {
129            Ok(token) => Ok(token.claims),
130            Err(e) => {
131                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
132                    Err(ValidateLlmTokenError::Expired)
133                } else {
134                    Err(ValidateLlmTokenError::JwtError(e))
135                }
136            }
137        }
138    }
139}
140
141#[derive(Error, Debug)]
142pub enum ValidateLlmTokenError {
143    #[error("access token is expired")]
144    Expired,
145    #[error("access token validation error: {0}")]
146    JwtError(#[from] jsonwebtoken::errors::Error),
147    #[error("{0}")]
148    Other(#[from] anyhow::Error),
149}