1use crate::{db::UserId, Config};
2use anyhow::{anyhow, Result};
3use chrono::Utc;
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use thiserror::Error;
8
9#[derive(Clone, Debug, Default, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct LlmTokenClaims {
12 pub iat: u64,
13 pub exp: u64,
14 pub jti: String,
15 pub user_id: u64,
16 pub plan: rpc::proto::Plan,
17}
18
/// How long an issued LLM token stays valid: one hour (3600 seconds).
/// Added to the issue time to compute the `exp` claim.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3600);
20
21impl LlmTokenClaims {
22 pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
23 let secret = config
24 .llm_api_secret
25 .as_ref()
26 .ok_or_else(|| anyhow!("no LLM API secret"))?;
27
28 let now = Utc::now();
29 let claims = Self {
30 iat: now.timestamp() as u64,
31 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
32 jti: uuid::Uuid::new_v4().to_string(),
33 user_id: user_id.to_proto(),
34 plan,
35 };
36
37 Ok(jsonwebtoken::encode(
38 &Header::default(),
39 &claims,
40 &EncodingKey::from_secret(secret.as_ref()),
41 )?)
42 }
43
44 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
45 let secret = config
46 .llm_api_secret
47 .as_ref()
48 .ok_or_else(|| anyhow!("no LLM API secret"))?;
49
50 match jsonwebtoken::decode::<Self>(
51 token,
52 &DecodingKey::from_secret(secret.as_ref()),
53 &Validation::default(),
54 ) {
55 Ok(token) => Ok(token.claims),
56 Err(e) => {
57 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
58 Err(ValidateLlmTokenError::Expired)
59 } else {
60 Err(ValidateLlmTokenError::JwtError(e))
61 }
62 }
63 }
64 }
65}
66
67#[derive(Error, Debug)]
68pub enum ValidateLlmTokenError {
69 #[error("access token is expired")]
70 Expired,
71 #[error("access token validation error: {0}")]
72 JwtError(#[from] jsonwebtoken::errors::Error),
73 #[error("{0}")]
74 Other(#[from] anyhow::Error),
75}