1use crate::db::billing_subscription::SubscriptionKind;
2use crate::db::{billing_subscription, user};
3use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
4use crate::{Config, db::billing_preference};
5use anyhow::{Result, anyhow};
6use chrono::{NaiveDateTime, Utc};
7use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
8use serde::{Deserialize, Serialize};
9use std::time::Duration;
10use thiserror::Error;
11use uuid::Uuid;
12use zed_llm_client::Plan;
13
14#[derive(Clone, Debug, Default, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct LlmTokenClaims {
17 pub iat: u64,
18 pub exp: u64,
19 pub jti: String,
20 pub user_id: u64,
21 pub system_id: Option<String>,
22 pub metrics_id: Uuid,
23 pub github_user_login: String,
24 pub account_created_at: NaiveDateTime,
25 pub is_staff: bool,
26 pub has_llm_closed_beta_feature_flag: bool,
27 pub bypass_account_age_check: bool,
28 #[serde(default)]
29 pub use_llm_request_queue: bool,
30 pub plan: Plan,
31 #[serde(default)]
32 pub has_extended_trial: bool,
33 #[serde(default)]
34 pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
35 #[serde(default)]
36 pub enable_model_request_overages: bool,
37 #[serde(default)]
38 pub model_request_overages_spend_limit_in_cents: u32,
39 #[serde(default)]
40 pub can_use_web_search_tool: bool,
41}
42
/// How long a freshly minted LLM token remains valid: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
44
45impl LlmTokenClaims {
46 pub fn create(
47 user: &user::Model,
48 is_staff: bool,
49 billing_preferences: Option<billing_preference::Model>,
50 feature_flags: &Vec<String>,
51 subscription: Option<billing_subscription::Model>,
52 system_id: Option<String>,
53 config: &Config,
54 ) -> Result<String> {
55 let secret = config
56 .llm_api_secret
57 .as_ref()
58 .ok_or_else(|| anyhow!("no LLM API secret"))?;
59
60 let now = Utc::now();
61 let claims = Self {
62 iat: now.timestamp() as u64,
63 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
64 jti: uuid::Uuid::new_v4().to_string(),
65 user_id: user.id.to_proto(),
66 system_id,
67 metrics_id: user.metrics_id,
68 github_user_login: user.github_login.clone(),
69 account_created_at: user.account_created_at(),
70 is_staff,
71 has_llm_closed_beta_feature_flag: feature_flags
72 .iter()
73 .any(|flag| flag == "llm-closed-beta"),
74 bypass_account_age_check: feature_flags
75 .iter()
76 .any(|flag| flag == "bypass-account-age-check"),
77 can_use_web_search_tool: true,
78 use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
79 plan: if is_staff {
80 Plan::ZedPro
81 } else {
82 subscription
83 .as_ref()
84 .and_then(|subscription| subscription.kind)
85 .map_or(Plan::Free, |kind| match kind {
86 SubscriptionKind::ZedFree => Plan::Free,
87 SubscriptionKind::ZedPro => Plan::ZedPro,
88 SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
89 })
90 },
91 has_extended_trial: feature_flags
92 .iter()
93 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
94 subscription_period: billing_subscription::Model::current_period(
95 subscription,
96 is_staff,
97 )
98 .map(|(start, end)| (start.naive_utc(), end.naive_utc())),
99 enable_model_request_overages: billing_preferences
100 .as_ref()
101 .map_or(false, |preferences| {
102 preferences.model_request_overages_enabled
103 }),
104 model_request_overages_spend_limit_in_cents: billing_preferences
105 .as_ref()
106 .map_or(0, |preferences| {
107 preferences.model_request_overages_spend_limit_in_cents as u32
108 }),
109 };
110
111 Ok(jsonwebtoken::encode(
112 &Header::default(),
113 &claims,
114 &EncodingKey::from_secret(secret.as_ref()),
115 )?)
116 }
117
118 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
119 let secret = config
120 .llm_api_secret
121 .as_ref()
122 .ok_or_else(|| anyhow!("no LLM API secret"))?;
123
124 match jsonwebtoken::decode::<Self>(
125 token,
126 &DecodingKey::from_secret(secret.as_ref()),
127 &Validation::default(),
128 ) {
129 Ok(token) => Ok(token.claims),
130 Err(e) => {
131 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
132 Err(ValidateLlmTokenError::Expired)
133 } else {
134 Err(ValidateLlmTokenError::JwtError(e))
135 }
136 }
137 }
138 }
139}
140
141#[derive(Error, Debug)]
142pub enum ValidateLlmTokenError {
143 #[error("access token is expired")]
144 Expired,
145 #[error("access token validation error: {0}")]
146 JwtError(#[from] jsonwebtoken::errors::Error),
147 #[error("{0}")]
148 Other(#[from] anyhow::Error),
149}