token.rs

 1use crate::{db::UserId, Config};
 2use anyhow::{anyhow, Result};
 3use chrono::Utc;
 4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
 5use serde::{Deserialize, Serialize};
 6use std::time::Duration;
 7use thiserror::Error;
 8
 9#[derive(Clone, Debug, Default, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct LlmTokenClaims {
12    pub iat: u64,
13    pub exp: u64,
14    pub jti: String,
15    pub user_id: u64,
16    pub is_staff: bool,
17    pub plan: rpc::proto::Plan,
18}
19
/// Validity window of a freshly issued LLM token: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
21
22impl LlmTokenClaims {
23    pub fn create(
24        user_id: UserId,
25        is_staff: bool,
26        plan: rpc::proto::Plan,
27        config: &Config,
28    ) -> Result<String> {
29        let secret = config
30            .llm_api_secret
31            .as_ref()
32            .ok_or_else(|| anyhow!("no LLM API secret"))?;
33
34        let now = Utc::now();
35        let claims = Self {
36            iat: now.timestamp() as u64,
37            exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
38            jti: uuid::Uuid::new_v4().to_string(),
39            user_id: user_id.to_proto(),
40            is_staff,
41            plan,
42        };
43
44        Ok(jsonwebtoken::encode(
45            &Header::default(),
46            &claims,
47            &EncodingKey::from_secret(secret.as_ref()),
48        )?)
49    }
50
51    pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
52        let secret = config
53            .llm_api_secret
54            .as_ref()
55            .ok_or_else(|| anyhow!("no LLM API secret"))?;
56
57        match jsonwebtoken::decode::<Self>(
58            token,
59            &DecodingKey::from_secret(secret.as_ref()),
60            &Validation::default(),
61        ) {
62            Ok(token) => Ok(token.claims),
63            Err(e) => {
64                if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
65                    Err(ValidateLlmTokenError::Expired)
66                } else {
67                    Err(ValidateLlmTokenError::JwtError(e))
68                }
69            }
70        }
71    }
72}
73
74#[derive(Error, Debug)]
75pub enum ValidateLlmTokenError {
76    #[error("access token is expired")]
77    Expired,
78    #[error("access token validation error: {0}")]
79    JwtError(#[from] jsonwebtoken::errors::Error),
80    #[error("{0}")]
81    Other(#[from] anyhow::Error),
82}