1use crate::{db::UserId, Config};
2use anyhow::{anyhow, Result};
3use chrono::Utc;
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use thiserror::Error;
8
9#[derive(Clone, Debug, Default, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct LlmTokenClaims {
12 pub iat: u64,
13 pub exp: u64,
14 pub jti: String,
15 pub user_id: u64,
16 pub github_user_login: String,
17 pub is_staff: bool,
18 pub has_llm_closed_beta_feature_flag: bool,
19 // This field is temporarily optional so it can be added
20 // in a backwards-compatible way. We can make it required
21 // once all of the LLM tokens have cycled (~1 hour after
22 // this change has been deployed).
23 #[serde(default)]
24 pub has_llm_subscription: Option<bool>,
25 pub plan: rpc::proto::Plan,
26}
27
/// How long a freshly issued LLM token stays valid: one hour (3,600 seconds).
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
29
30impl LlmTokenClaims {
31 pub fn create(
32 user_id: UserId,
33 github_user_login: String,
34 is_staff: bool,
35 has_llm_closed_beta_feature_flag: bool,
36 has_llm_subscription: bool,
37 plan: rpc::proto::Plan,
38 config: &Config,
39 ) -> Result<String> {
40 let secret = config
41 .llm_api_secret
42 .as_ref()
43 .ok_or_else(|| anyhow!("no LLM API secret"))?;
44
45 let now = Utc::now();
46 let claims = Self {
47 iat: now.timestamp() as u64,
48 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
49 jti: uuid::Uuid::new_v4().to_string(),
50 user_id: user_id.to_proto(),
51 github_user_login,
52 is_staff,
53 has_llm_closed_beta_feature_flag,
54 has_llm_subscription: Some(has_llm_subscription),
55 plan,
56 };
57
58 Ok(jsonwebtoken::encode(
59 &Header::default(),
60 &claims,
61 &EncodingKey::from_secret(secret.as_ref()),
62 )?)
63 }
64
65 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
66 let secret = config
67 .llm_api_secret
68 .as_ref()
69 .ok_or_else(|| anyhow!("no LLM API secret"))?;
70
71 match jsonwebtoken::decode::<Self>(
72 token,
73 &DecodingKey::from_secret(secret.as_ref()),
74 &Validation::default(),
75 ) {
76 Ok(token) => Ok(token.claims),
77 Err(e) => {
78 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
79 Err(ValidateLlmTokenError::Expired)
80 } else {
81 Err(ValidateLlmTokenError::JwtError(e))
82 }
83 }
84 }
85 }
86}
87
88#[derive(Error, Debug)]
89pub enum ValidateLlmTokenError {
90 #[error("access token is expired")]
91 Expired,
92 #[error("access token validation error: {0}")]
93 JwtError(#[from] jsonwebtoken::errors::Error),
94 #[error("{0}")]
95 Other(#[from] anyhow::Error),
96}