1use crate::{db::UserId, Config};
2use anyhow::{anyhow, Result};
3use chrono::Utc;
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use thiserror::Error;
8
9#[derive(Clone, Debug, Default, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct LlmTokenClaims {
12 pub iat: u64,
13 pub exp: u64,
14 pub jti: String,
15 pub user_id: u64,
16 // This field is temporarily optional so it can be added
17 // in a backwards-compatible way. We can make it required
18 // once all of the LLM tokens have cycled (~1 hour after
19 // this change has been deployed).
20 #[serde(default)]
21 pub github_user_login: Option<String>,
22 pub is_staff: bool,
23 #[serde(default)]
24 pub has_llm_closed_beta_feature_flag: bool,
25 pub plan: rpc::proto::Plan,
26}
27
/// Lifetime used to compute the `exp` claim of a newly minted LLM token: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
29
30impl LlmTokenClaims {
31 pub fn create(
32 user_id: UserId,
33 github_user_login: String,
34 is_staff: bool,
35 has_llm_closed_beta_feature_flag: bool,
36 plan: rpc::proto::Plan,
37 config: &Config,
38 ) -> Result<String> {
39 let secret = config
40 .llm_api_secret
41 .as_ref()
42 .ok_or_else(|| anyhow!("no LLM API secret"))?;
43
44 let now = Utc::now();
45 let claims = Self {
46 iat: now.timestamp() as u64,
47 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
48 jti: uuid::Uuid::new_v4().to_string(),
49 user_id: user_id.to_proto(),
50 github_user_login: Some(github_user_login),
51 is_staff,
52 has_llm_closed_beta_feature_flag,
53 plan,
54 };
55
56 Ok(jsonwebtoken::encode(
57 &Header::default(),
58 &claims,
59 &EncodingKey::from_secret(secret.as_ref()),
60 )?)
61 }
62
63 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
64 let secret = config
65 .llm_api_secret
66 .as_ref()
67 .ok_or_else(|| anyhow!("no LLM API secret"))?;
68
69 match jsonwebtoken::decode::<Self>(
70 token,
71 &DecodingKey::from_secret(secret.as_ref()),
72 &Validation::default(),
73 ) {
74 Ok(token) => Ok(token.claims),
75 Err(e) => {
76 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
77 Err(ValidateLlmTokenError::Expired)
78 } else {
79 Err(ValidateLlmTokenError::JwtError(e))
80 }
81 }
82 }
83 }
84}
85
86#[derive(Error, Debug)]
87pub enum ValidateLlmTokenError {
88 #[error("access token is expired")]
89 Expired,
90 #[error("access token validation error: {0}")]
91 JwtError(#[from] jsonwebtoken::errors::Error),
92 #[error("{0}")]
93 Other(#[from] anyhow::Error),
94}