token.rs

 1use crate::{db::UserId, Config};
 2use anyhow::{anyhow, Result};
 3use chrono::Utc;
 4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
 5use serde::{Deserialize, Serialize};
 6use std::time::Duration;
 7use thiserror::Error;
 8
 9#[derive(Clone, Debug, Default, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct LlmTokenClaims {
12    pub iat: u64,
13    pub exp: u64,
14    pub jti: String,
15    pub user_id: u64,
16    pub github_user_login: String,
17    pub is_staff: bool,
18    pub has_llm_closed_beta_feature_flag: bool,
19    // This field is temporarily optional so it can be added
20    // in a backwards-compatible way. We can make it required
21    // once all of the LLM tokens have cycled (~1 hour after
22    // this change has been deployed).
23    #[serde(default)]
24    pub has_llm_subscription: Option<bool>,
25    pub plan: rpc::proto::Plan,
26}
27
28const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
29
30impl LlmTokenClaims {
31    pub fn create(
32        user_id: UserId,
33        github_user_login: String,
34        is_staff: bool,
35        has_llm_closed_beta_feature_flag: bool,
36        has_llm_subscription: bool,
37        plan: rpc::proto::Plan,
38        config: &Config,
39    ) -> Result<String> {
40        let secret = config
41            .llm_api_secret
42            .as_ref()
43            .ok_or_else(|| anyhow!("no LLM API secret"))?;
44
45        let now = Utc::now();
46        let claims = Self {
47            iat: now.timestamp() as u64,
48            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
49            jti: uuid::Uuid::new_v4().to_string(),
50            user_id: user_id.to_proto(),
51            github_user_login,
52            is_staff,
53            has_llm_closed_beta_feature_flag,
54            has_llm_subscription: Some(has_llm_subscription),
55            plan,
56        };
57
58        Ok(jsonwebtoken::encode(
59            &Header::default(),
60            &claims,
61            &EncodingKey::from_secret(secret.as_ref()),
62        )?)
63    }
64
65    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
66        let secret = config
67            .llm_api_secret
68            .as_ref()
69            .ok_or_else(|| anyhow!("no LLM API secret"))?;
70
71        match jsonwebtoken::decode::<Self>(
72            token,
73            &DecodingKey::from_secret(secret.as_ref()),
74            &Validation::default(),
75        ) {
76            Ok(token) => Ok(token.claims),
77            Err(e) => {
78                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
79                    Err(ValidateLlmTokenError::Expired)
80                } else {
81                    Err(ValidateLlmTokenError::JwtError(e))
82                }
83            }
84        }
85    }
86}
87
88#[derive(Error, Debug)]
89pub enum ValidateLlmTokenError {
90    #[error("access token is expired")]
91    Expired,
92    #[error("access token validation error: {0}")]
93    JwtError(#[from] jsonwebtoken::errors::Error),
94    #[error("{0}")]
95    Other(#[from] anyhow::Error),
96}