token.rs

 1use crate::{db::UserId, Config};
 2use anyhow::{anyhow, Result};
 3use chrono::Utc;
 4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
 5use serde::{Deserialize, Serialize};
 6use std::time::Duration;
 7use thiserror::Error;
 8
/// Claims embedded in an LLM API token, which is signed with the configured
/// LLM API secret (see `LlmTokenClaims::create` and `LlmTokenClaims::validate`).
///
/// Field names are serialized in camelCase (e.g. `user_id` becomes `userId`).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
    /// Issued-at time, as a Unix timestamp in seconds.
    pub iat: u64,
    /// Expiration time, as a Unix timestamp in seconds.
    pub exp: u64,
    /// Token identifier; `create` fills this with a random UUID v4 string.
    pub jti: String,
    /// The user's id in its `UserId::to_proto` representation.
    pub user_id: u64,
    // This field is temporarily optional so it can be added
    // in a backwards-compatible way. We can make it required
    // once all of the LLM tokens have cycled (~1 hour after
    // this change has been deployed).
    /// The user's GitHub login; deserializes to `None` when absent from the token.
    #[serde(default)]
    pub github_user_login: Option<String>,
    /// Whether the user is a staff member.
    pub is_staff: bool,
    /// The user's plan at the time the token was issued.
    pub plan: rpc::proto::Plan,
}
25
/// How long a freshly issued LLM token remains valid: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
27
28impl LlmTokenClaims {
29    pub fn create(
30        user_id: UserId,
31        github_user_login: String,
32        is_staff: bool,
33        plan: rpc::proto::Plan,
34        config: &Config,
35    ) -> Result<String> {
36        let secret = config
37            .llm_api_secret
38            .as_ref()
39            .ok_or_else(|| anyhow!("no LLM API secret"))?;
40
41        let now = Utc::now();
42        let claims = Self {
43            iat: now.timestamp() as u64,
44            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
45            jti: uuid::Uuid::new_v4().to_string(),
46            user_id: user_id.to_proto(),
47            github_user_login: Some(github_user_login),
48            is_staff,
49            plan,
50        };
51
52        Ok(jsonwebtoken::encode(
53            &Header::default(),
54            &claims,
55            &EncodingKey::from_secret(secret.as_ref()),
56        )?)
57    }
58
59    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
60        let secret = config
61            .llm_api_secret
62            .as_ref()
63            .ok_or_else(|| anyhow!("no LLM API secret"))?;
64
65        match jsonwebtoken::decode::<Self>(
66            token,
67            &DecodingKey::from_secret(secret.as_ref()),
68            &Validation::default(),
69        ) {
70            Ok(token) => Ok(token.claims),
71            Err(e) => {
72                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
73                    Err(ValidateLlmTokenError::Expired)
74                } else {
75                    Err(ValidateLlmTokenError::JwtError(e))
76                }
77            }
78        }
79    }
80}
81
/// Errors returned by `LlmTokenClaims::validate`.
#[derive(Error, Debug)]
pub enum ValidateLlmTokenError {
    /// jsonwebtoken reported the token as expired (`ErrorKind::ExpiredSignature`).
    #[error("access token is expired")]
    Expired,
    /// Any other failure reported by jsonwebtoken while decoding or validating.
    #[error("access token validation error: {0}")]
    JwtError(#[from] jsonwebtoken::errors::Error),
    /// Any other error, e.g. the LLM API secret is not configured.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}