1use crate::{db::UserId, Config};
2use anyhow::{anyhow, Result};
3use chrono::Utc;
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use thiserror::Error;
8
9#[derive(Clone, Debug, Default, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct LlmTokenClaims {
12 pub iat: u64,
13 pub exp: u64,
14 pub jti: String,
15 pub user_id: u64,
16 // This field is temporarily optional so it can be added
17 // in a backwards-compatible way. We can make it required
18 // once all of the LLM tokens have cycled (~1 hour after
19 // this change has been deployed).
20 #[serde(default)]
21 pub github_user_login: Option<String>,
22 pub is_staff: bool,
23 pub plan: rpc::proto::Plan,
24}
25
/// Gap between a freshly minted token's `iat` and `exp` claims: one hour.
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(3_600);
27
28impl LlmTokenClaims {
29 pub fn create(
30 user_id: UserId,
31 github_user_login: String,
32 is_staff: bool,
33 plan: rpc::proto::Plan,
34 config: &Config,
35 ) -> Result<String> {
36 let secret = config
37 .llm_api_secret
38 .as_ref()
39 .ok_or_else(|| anyhow!("no LLM API secret"))?;
40
41 let now = Utc::now();
42 let claims = Self {
43 iat: now.timestamp() as u64,
44 exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
45 jti: uuid::Uuid::new_v4().to_string(),
46 user_id: user_id.to_proto(),
47 github_user_login: Some(github_user_login),
48 is_staff,
49 plan,
50 };
51
52 Ok(jsonwebtoken::encode(
53 &Header::default(),
54 &claims,
55 &EncodingKey::from_secret(secret.as_ref()),
56 )?)
57 }
58
59 pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
60 let secret = config
61 .llm_api_secret
62 .as_ref()
63 .ok_or_else(|| anyhow!("no LLM API secret"))?;
64
65 match jsonwebtoken::decode::<Self>(
66 token,
67 &DecodingKey::from_secret(secret.as_ref()),
68 &Validation::default(),
69 ) {
70 Ok(token) => Ok(token.claims),
71 Err(e) => {
72 if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
73 Err(ValidateLlmTokenError::Expired)
74 } else {
75 Err(ValidateLlmTokenError::JwtError(e))
76 }
77 }
78 }
79 }
80}
81
82#[derive(Error, Debug)]
83pub enum ValidateLlmTokenError {
84 #[error("access token is expired")]
85 Expired,
86 #[error("access token validation error: {0}")]
87 JwtError(#[from] jsonwebtoken::errors::Error),
88 #[error("{0}")]
89 Other(#[from] anyhow::Error),
90}