1use crate::db::billing_subscription::SubscriptionKind;
2use crate::db::{billing_subscription, user};
3use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
4use crate::{Config, db::billing_preference};
5use anyhow::{Result, anyhow};
6use chrono::{NaiveDateTime, Utc};
7use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
8use serde::{Deserialize, Serialize};
9use std::time::Duration;
10use thiserror::Error;
11use uuid::Uuid;
12use zed_llm_client::Plan;
13
/// Claims carried in the JWT that grants a user access to the hosted LLM service.
///
/// Serialized with camelCase field names (e.g. `user_id` -> `userId`) as the
/// token payload; see `LlmTokenClaims::create` and `LlmTokenClaims::validate`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
    /// Issued-at time, as Unix seconds.
    pub iat: u64,
    /// Expiry time, as Unix seconds (issued-at plus `LLM_TOKEN_LIFETIME`).
    pub exp: u64,
    /// Unique token id (a freshly generated UUID v4 string).
    pub jti: String,
    /// The user's id, in its protocol (`to_proto`) representation.
    pub user_id: u64,
    /// Optional client system id, passed through unchanged by `create`.
    pub system_id: Option<String>,
    /// The user's metrics id.
    pub metrics_id: Uuid,
    /// The user's GitHub login.
    pub github_user_login: String,
    /// When the user's account was created.
    pub account_created_at: NaiveDateTime,
    /// Whether the user is staff.
    pub is_staff: bool,
    /// Whether the user has the `llm-closed-beta` feature flag.
    pub has_llm_closed_beta_feature_flag: bool,
    /// Whether the user has the `bypass-account-age-check` feature flag.
    pub bypass_account_age_check: bool,
    /// Whether the user has the `llm-request-queue` feature flag.
    pub use_llm_request_queue: bool,
    /// The user's plan; staff are always `Plan::ZedPro`, others default to `Plan::ZedFree`.
    pub plan: Plan,
    /// Whether the user has the agent extended-trial feature flag.
    pub has_extended_trial: bool,
    /// Start and end of the current subscription period, as naive UTC timestamps.
    pub subscription_period: (NaiveDateTime, NaiveDateTime),
    /// Whether model-request overages are enabled; false when the user has no billing preferences.
    pub enable_model_request_overages: bool,
    /// Spend limit for model-request overages, in cents; 0 when the user has no billing preferences.
    pub model_request_overages_spend_limit_in_cents: u32,
    /// Whether the web search tool may be used (always set to `true` by `create`).
    pub can_use_web_search_tool: bool,
}
36
/// How long a freshly issued LLM token stays valid: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
38
39impl LlmTokenClaims {
40 pub fn create(
41 user: &user::Model,
42 is_staff: bool,
43 billing_preferences: Option<billing_preference::Model>,
44 feature_flags: &Vec<String>,
45 subscription: Option<billing_subscription::Model>,
46 system_id: Option<String>,
47 config: &Config,
48 ) -> Result<String> {
49 let secret = config
50 .llm_api_secret
51 .as_ref()
52 .ok_or_else(|| anyhow!("no LLM API secret"))?;
53
54 let plan = if is_staff {
55 Plan::ZedPro
56 } else {
57 subscription
58 .as_ref()
59 .and_then(|subscription| subscription.kind)
60 .map_or(Plan::ZedFree, |kind| match kind {
61 SubscriptionKind::ZedFree => Plan::ZedFree,
62 SubscriptionKind::ZedPro => Plan::ZedPro,
63 SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
64 })
65 };
66 let subscription_period =
67 billing_subscription::Model::current_period(subscription, is_staff)
68 .map(|(start, end)| (start.naive_utc(), end.naive_utc()))
69 .ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
70
71 let now = Utc::now();
72 let claims = Self {
73 iat: now.timestamp() as u64,
74 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
75 jti: uuid::Uuid::new_v4().to_string(),
76 user_id: user.id.to_proto(),
77 system_id,
78 metrics_id: user.metrics_id,
79 github_user_login: user.github_login.clone(),
80 account_created_at: user.account_created_at(),
81 is_staff,
82 has_llm_closed_beta_feature_flag: feature_flags
83 .iter()
84 .any(|flag| flag == "llm-closed-beta"),
85 bypass_account_age_check: feature_flags
86 .iter()
87 .any(|flag| flag == "bypass-account-age-check"),
88 can_use_web_search_tool: true,
89 use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
90 plan,
91 has_extended_trial: feature_flags
92 .iter()
93 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
94 subscription_period,
95 enable_model_request_overages: billing_preferences
96 .as_ref()
97 .map_or(false, |preferences| {
98 preferences.model_request_overages_enabled
99 }),
100 model_request_overages_spend_limit_in_cents: billing_preferences
101 .as_ref()
102 .map_or(0, |preferences| {
103 preferences.model_request_overages_spend_limit_in_cents as u32
104 }),
105 };
106
107 Ok(jsonwebtoken::encode(
108 &Header::default(),
109 &claims,
110 &EncodingKey::from_secret(secret.as_ref()),
111 )?)
112 }
113
114 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
115 let secret = config
116 .llm_api_secret
117 .as_ref()
118 .ok_or_else(|| anyhow!("no LLM API secret"))?;
119
120 match jsonwebtoken::decode::<Self>(
121 token,
122 &DecodingKey::from_secret(secret.as_ref()),
123 &Validation::default(),
124 ) {
125 Ok(token) => Ok(token.claims),
126 Err(e) => {
127 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
128 Err(ValidateLlmTokenError::Expired)
129 } else {
130 Err(ValidateLlmTokenError::JwtError(e))
131 }
132 }
133 }
134 }
135}
136
137#[derive(Error, Debug)]
138pub enum ValidateLlmTokenError {
139 #[error("access token is expired")]
140 Expired,
141 #[error("access token validation error: {0}")]
142 JwtError(#[from] jsonwebtoken::errors::Error),
143 #[error("{0}")]
144 Other(#[from] anyhow::Error),
145}