1use crate::{db::UserId, Config};
2use anyhow::{anyhow, Result};
3use chrono::Utc;
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use thiserror::Error;
8
9#[derive(Clone, Debug, Default, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct LlmTokenClaims {
12 pub iat: u64,
13 pub exp: u64,
14 pub jti: String,
15 pub user_id: u64,
16 pub is_staff: bool,
17 pub plan: rpc::proto::Plan,
18}
19
/// Validity window of a freshly minted LLM token: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
21
22impl LlmTokenClaims {
23 pub fn create(
24 user_id: UserId,
25 is_staff: bool,
26 plan: rpc::proto::Plan,
27 config: &Config,
28 ) -> Result<String> {
29 let secret = config
30 .llm_api_secret
31 .as_ref()
32 .ok_or_else(|| anyhow!("no LLM API secret"))?;
33
34 let now = Utc::now();
35 let claims = Self {
36 iat: now.timestamp() as u64,
37 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
38 jti: uuid::Uuid::new_v4().to_string(),
39 user_id: user_id.to_proto(),
40 is_staff,
41 plan,
42 };
43
44 Ok(jsonwebtoken::encode(
45 &Header::default(),
46 &claims,
47 &EncodingKey::from_secret(secret.as_ref()),
48 )?)
49 }
50
51 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
52 let secret = config
53 .llm_api_secret
54 .as_ref()
55 .ok_or_else(|| anyhow!("no LLM API secret"))?;
56
57 match jsonwebtoken::decode::<Self>(
58 token,
59 &DecodingKey::from_secret(secret.as_ref()),
60 &Validation::default(),
61 ) {
62 Ok(token) => Ok(token.claims),
63 Err(e) => {
64 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
65 Err(ValidateLlmTokenError::Expired)
66 } else {
67 Err(ValidateLlmTokenError::JwtError(e))
68 }
69 }
70 }
71 }
72}
73
74#[derive(Error, Debug)]
75pub enum ValidateLlmTokenError {
76 #[error("access token is expired")]
77 Expired,
78 #[error("access token validation error: {0}")]
79 JwtError(#[from] jsonwebtoken::errors::Error),
80 #[error("{0}")]
81 Other(#[from] anyhow::Error),
82}