1use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
2use crate::{
3 db::{billing_preference, UserId},
4 Config,
5};
6use anyhow::{anyhow, Result};
7use chrono::Utc;
8use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
9use serde::{Deserialize, Serialize};
10use std::time::Duration;
11use thiserror::Error;
12
/// Claims carried inside an LLM access token (a JWT).
///
/// Tokens are minted by [`LlmTokenClaims::create`] and checked by
/// [`LlmTokenClaims::validate`]. Field names are serialized in camelCase
/// (e.g. `user_id` becomes `userId`); single-word fields such as `iat` are unchanged.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
    /// Issued-at time, in seconds since the Unix epoch (set from `Utc::now()` in `create`).
    pub iat: u64,
    /// Expiration time, in seconds since the Unix epoch (`iat` + `LLM_TOKEN_LIFETIME` in `create`).
    pub exp: u64,
    /// Token id; `create` fills it with a freshly generated random (v4) UUID string.
    pub jti: String,
    /// The user's id, as produced by `UserId::to_proto`.
    pub user_id: u64,
    /// The user's GitHub login.
    pub github_user_login: String,
    /// Whether the user is staff.
    pub is_staff: bool,
    /// Whether the user has the LLM closed-beta feature flag.
    pub has_llm_closed_beta_feature_flag: bool,
    /// Whether the user has an LLM subscription.
    pub has_llm_subscription: bool,
    /// Monthly LLM spending cap in cents: the user's billing preference when present,
    /// otherwise `DEFAULT_MAX_MONTHLY_SPEND`.
    pub max_monthly_spend_in_cents: u32,
    /// The user's plan.
    pub plan: rpc::proto::Plan,
}
27
/// How long a freshly minted LLM token stays valid: one hour (3600 seconds).
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
29
30impl LlmTokenClaims {
31 #[allow(clippy::too_many_arguments)]
32 pub fn create(
33 user_id: UserId,
34 github_user_login: String,
35 is_staff: bool,
36 billing_preferences: Option<billing_preference::Model>,
37 has_llm_closed_beta_feature_flag: bool,
38 has_llm_subscription: bool,
39 plan: rpc::proto::Plan,
40 config: &Config,
41 ) -> Result<String> {
42 let secret = config
43 .llm_api_secret
44 .as_ref()
45 .ok_or_else(|| anyhow!("no LLM API secret"))?;
46
47 let now = Utc::now();
48 let claims = Self {
49 iat: now.timestamp() as u64,
50 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
51 jti: uuid::Uuid::new_v4().to_string(),
52 user_id: user_id.to_proto(),
53 github_user_login,
54 is_staff,
55 has_llm_closed_beta_feature_flag,
56 has_llm_subscription,
57 max_monthly_spend_in_cents: billing_preferences
58 .map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
59 preferences.max_monthly_llm_usage_spending_in_cents as u32
60 }),
61 plan,
62 };
63
64 Ok(jsonwebtoken::encode(
65 &Header::default(),
66 &claims,
67 &EncodingKey::from_secret(secret.as_ref()),
68 )?)
69 }
70
71 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
72 let secret = config
73 .llm_api_secret
74 .as_ref()
75 .ok_or_else(|| anyhow!("no LLM API secret"))?;
76
77 match jsonwebtoken::decode::<Self>(
78 token,
79 &DecodingKey::from_secret(secret.as_ref()),
80 &Validation::default(),
81 ) {
82 Ok(token) => Ok(token.claims),
83 Err(e) => {
84 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
85 Err(ValidateLlmTokenError::Expired)
86 } else {
87 Err(ValidateLlmTokenError::JwtError(e))
88 }
89 }
90 }
91 }
92}
93
/// Errors returned by [`LlmTokenClaims::validate`].
#[derive(Error, Debug)]
pub enum ValidateLlmTokenError {
    /// The token was well-formed, but `jsonwebtoken` reported its signature as expired.
    #[error("access token is expired")]
    Expired,
    /// Any other `jsonwebtoken` decoding or validation failure (e.g. bad signature).
    #[error("access token validation error: {0}")]
    JwtError(#[from] jsonwebtoken::errors::Error),
    /// A non-JWT failure, such as no LLM API secret being configured.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}