1use crate::Cents;
2use crate::db::billing_subscription::SubscriptionKind;
3use crate::db::{billing_subscription, user};
4use crate::llm::{
5 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
6};
7use crate::{Config, db::billing_preference};
8use anyhow::{Result, anyhow};
9use chrono::{NaiveDateTime, Utc};
10use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13use thiserror::Error;
14use util::maybe;
15use uuid::Uuid;
16use zed_llm_client::Plan;
17
/// Claims carried inside an LLM API token (a JWT).
///
/// Instances are built and signed by [`LlmTokenClaims::create`] and recovered by
/// [`LlmTokenClaims::validate`]. Field names are serialized in camelCase, so the
/// field names and their order here define the token payload.
///
/// NOTE(review): fields marked `#[serde(default)]` take their `Default` value when
/// absent from a token — presumably so tokens issued before those fields existed
/// still deserialize; confirm before removing any of those attributes.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
    /// Issued-at time, in seconds since the Unix epoch.
    pub iat: u64,
    /// Expiry time, in seconds since the Unix epoch.
    pub exp: u64,
    /// Token ID; `create` sets this to a freshly generated UUID v4.
    pub jti: String,
    /// ID of the user the token was issued to (`user.id.to_proto()`).
    pub user_id: u64,
    /// Optional system identifier passed to `create` by its caller.
    pub system_id: Option<String>,
    /// The user's metrics ID.
    pub metrics_id: Uuid,
    /// The user's GitHub login.
    pub github_user_login: String,
    /// When the user's account was created (`user.account_created_at()`).
    pub account_created_at: NaiveDateTime,
    /// Whether the user is staff.
    pub is_staff: bool,
    /// Whether the user has the `llm-closed-beta` feature flag.
    pub has_llm_closed_beta_feature_flag: bool,
    /// Whether the user has the `bypass-account-age-check` feature flag.
    pub bypass_account_age_check: bool,
    /// Whether the user has a legacy LLM subscription.
    pub has_llm_subscription: bool,
    /// Monthly spending cap in cents: taken from the user's billing preferences,
    /// or `DEFAULT_MAX_MONTHLY_SPEND` when the user has none.
    pub max_monthly_spend_in_cents: u32,
    /// Per-user override of the free-tier monthly allowance, in cents; consumed by
    /// `free_tier_monthly_spending_limit`.
    pub custom_llm_monthly_allowance_in_cents: Option<u32>,
    /// The user's plan, derived from the subscription kind (`Plan::Free` when there
    /// is no subscription or it has no kind).
    pub plan: Plan,
    /// Whether the user has the agent extended-trial feature flag.
    #[serde(default)]
    pub has_extended_trial: bool,
    /// Start and end of the subscription's current period (naive UTC), when known.
    #[serde(default)]
    pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
    /// Whether model request overages are enabled in the user's billing
    /// preferences (`false` without preferences).
    #[serde(default)]
    pub enable_model_request_overages: bool,
    /// Spend limit for model request overages, in cents (`0` without preferences).
    #[serde(default)]
    pub model_request_overages_spend_limit_in_cents: u32,
    /// Whether the web search tool may be used; `create` derives this from the
    /// `assistant2` feature flag.
    #[serde(default)]
    pub can_use_web_search_tool: bool,
}
47
/// How long an LLM token remains valid after it is issued (one hour, i.e. 3600 seconds).
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
49
50impl LlmTokenClaims {
51 pub fn create(
52 user: &user::Model,
53 is_staff: bool,
54 billing_preferences: Option<billing_preference::Model>,
55 feature_flags: &Vec<String>,
56 has_legacy_llm_subscription: bool,
57 subscription: Option<billing_subscription::Model>,
58 system_id: Option<String>,
59 config: &Config,
60 ) -> Result<String> {
61 let secret = config
62 .llm_api_secret
63 .as_ref()
64 .ok_or_else(|| anyhow!("no LLM API secret"))?;
65
66 let now = Utc::now();
67 let claims = Self {
68 iat: now.timestamp() as u64,
69 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
70 jti: uuid::Uuid::new_v4().to_string(),
71 user_id: user.id.to_proto(),
72 system_id,
73 metrics_id: user.metrics_id,
74 github_user_login: user.github_login.clone(),
75 account_created_at: user.account_created_at(),
76 is_staff,
77 has_llm_closed_beta_feature_flag: feature_flags
78 .iter()
79 .any(|flag| flag == "llm-closed-beta"),
80 bypass_account_age_check: feature_flags
81 .iter()
82 .any(|flag| flag == "bypass-account-age-check"),
83 can_use_web_search_tool: feature_flags.iter().any(|flag| flag == "assistant2"),
84 has_llm_subscription: has_legacy_llm_subscription,
85 max_monthly_spend_in_cents: billing_preferences
86 .as_ref()
87 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
88 preferences.max_monthly_llm_usage_spending_in_cents as u32
89 }),
90 custom_llm_monthly_allowance_in_cents: user
91 .custom_llm_monthly_allowance_in_cents
92 .map(|allowance| allowance as u32),
93 plan: subscription
94 .as_ref()
95 .and_then(|subscription| subscription.kind)
96 .map_or(Plan::Free, |kind| match kind {
97 SubscriptionKind::ZedFree => Plan::Free,
98 SubscriptionKind::ZedPro => Plan::ZedPro,
99 SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
100 }),
101 has_extended_trial: feature_flags
102 .iter()
103 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
104 subscription_period: maybe!({
105 let subscription = subscription?;
106 let period_start_at = subscription.current_period_start_at()?;
107 let period_end_at = subscription.current_period_end_at()?;
108
109 Some((period_start_at.naive_utc(), period_end_at.naive_utc()))
110 }),
111 enable_model_request_overages: billing_preferences
112 .as_ref()
113 .map_or(false, |preferences| {
114 preferences.model_request_overages_enabled
115 }),
116 model_request_overages_spend_limit_in_cents: billing_preferences
117 .as_ref()
118 .map_or(0, |preferences| {
119 preferences.model_request_overages_spend_limit_in_cents as u32
120 }),
121 };
122
123 Ok(jsonwebtoken::encode(
124 &Header::default(),
125 &claims,
126 &EncodingKey::from_secret(secret.as_ref()),
127 )?)
128 }
129
130 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
131 let secret = config
132 .llm_api_secret
133 .as_ref()
134 .ok_or_else(|| anyhow!("no LLM API secret"))?;
135
136 match jsonwebtoken::decode::<Self>(
137 token,
138 &DecodingKey::from_secret(secret.as_ref()),
139 &Validation::default(),
140 ) {
141 Ok(token) => Ok(token.claims),
142 Err(e) => {
143 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
144 Err(ValidateLlmTokenError::Expired)
145 } else {
146 Err(ValidateLlmTokenError::JwtError(e))
147 }
148 }
149 }
150 }
151
152 pub fn free_tier_monthly_spending_limit(&self) -> Cents {
153 self.custom_llm_monthly_allowance_in_cents
154 .map(Cents)
155 .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT)
156 }
157}
158
159#[derive(Error, Debug)]
160pub enum ValidateLlmTokenError {
161 #[error("access token is expired")]
162 Expired,
163 #[error("access token validation error: {0}")]
164 JwtError(#[from] jsonwebtoken::errors::Error),
165 #[error("{0}")]
166 Other(#[from] anyhow::Error),
167}