1use crate::db::billing_subscription::SubscriptionKind;
2use crate::db::{billing_customer, billing_subscription, user};
3use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
4use crate::{Config, db::billing_preference};
5use anyhow::{Context as _, Result};
6use chrono::{NaiveDateTime, Utc};
7use cloud_llm_client::Plan;
8use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
9use serde::{Deserialize, Serialize};
10use std::time::Duration;
11use thiserror::Error;
12use uuid::Uuid;
13
14#[derive(Clone, Debug, Default, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct LlmTokenClaims {
17 pub iat: u64,
18 pub exp: u64,
19 pub jti: String,
20 pub user_id: u64,
21 pub system_id: Option<String>,
22 pub metrics_id: Uuid,
23 pub github_user_login: String,
24 pub account_created_at: NaiveDateTime,
25 pub is_staff: bool,
26 pub has_llm_closed_beta_feature_flag: bool,
27 pub bypass_account_age_check: bool,
28 pub use_llm_request_queue: bool,
29 pub plan: Plan,
30 pub has_extended_trial: bool,
31 pub subscription_period: (NaiveDateTime, NaiveDateTime),
32 pub enable_model_request_overages: bool,
33 pub model_request_overages_spend_limit_in_cents: u32,
34 pub can_use_web_search_tool: bool,
35 #[serde(default)]
36 pub has_overdue_invoices: bool,
37}
38
/// How long a freshly minted LLM token remains valid: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
40
41impl LlmTokenClaims {
42 pub fn create(
43 user: &user::Model,
44 is_staff: bool,
45 billing_customer: billing_customer::Model,
46 billing_preferences: Option<billing_preference::Model>,
47 feature_flags: &Vec<String>,
48 subscription: billing_subscription::Model,
49 system_id: Option<String>,
50 config: &Config,
51 ) -> Result<String> {
52 let secret = config
53 .llm_api_secret
54 .as_ref()
55 .context("no LLM API secret")?;
56
57 let plan = if is_staff {
58 Plan::ZedPro
59 } else {
60 subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
61 SubscriptionKind::ZedFree => Plan::ZedFree,
62 SubscriptionKind::ZedPro => Plan::ZedPro,
63 SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
64 })
65 };
66 let subscription_period =
67 billing_subscription::Model::current_period(Some(subscription), is_staff)
68 .map(|(start, end)| (start.naive_utc(), end.naive_utc()))
69 .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
70
71 let now = Utc::now();
72 let claims = Self {
73 iat: now.timestamp() as u64,
74 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
75 jti: uuid::Uuid::new_v4().to_string(),
76 user_id: user.id.to_proto(),
77 system_id,
78 metrics_id: user.metrics_id,
79 github_user_login: user.github_login.clone(),
80 account_created_at: user.account_created_at(),
81 is_staff,
82 has_llm_closed_beta_feature_flag: feature_flags
83 .iter()
84 .any(|flag| flag == "llm-closed-beta"),
85 bypass_account_age_check: feature_flags
86 .iter()
87 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
88 can_use_web_search_tool: true,
89 use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
90 plan,
91 has_extended_trial: feature_flags
92 .iter()
93 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
94 subscription_period,
95 enable_model_request_overages: billing_preferences
96 .as_ref()
97 .map_or(false, |preferences| {
98 preferences.model_request_overages_enabled
99 }),
100 model_request_overages_spend_limit_in_cents: billing_preferences
101 .as_ref()
102 .map_or(0, |preferences| {
103 preferences.model_request_overages_spend_limit_in_cents as u32
104 }),
105 has_overdue_invoices: billing_customer.has_overdue_invoices,
106 };
107
108 Ok(jsonwebtoken::encode(
109 &Header::default(),
110 &claims,
111 &EncodingKey::from_secret(secret.as_ref()),
112 )?)
113 }
114
115 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
116 let secret = config
117 .llm_api_secret
118 .as_ref()
119 .context("no LLM API secret")?;
120
121 match jsonwebtoken::decode::<Self>(
122 token,
123 &DecodingKey::from_secret(secret.as_ref()),
124 &Validation::default(),
125 ) {
126 Ok(token) => Ok(token.claims),
127 Err(e) => {
128 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
129 Err(ValidateLlmTokenError::Expired)
130 } else {
131 Err(ValidateLlmTokenError::JwtError(e))
132 }
133 }
134 }
135 }
136}
137
138#[derive(Error, Debug)]
139pub enum ValidateLlmTokenError {
140 #[error("access token is expired")]
141 Expired,
142 #[error("access token validation error: {0}")]
143 JwtError(#[from] jsonwebtoken::errors::Error),
144 #[error("{0}")]
145 Other(#[from] anyhow::Error),
146}