Detailed changes
@@ -0,0 +1,7 @@
+create table revoked_access_tokens (
+ id serial primary key,
+ jti text not null,
+ revoked_at timestamp without time zone not null default now()
+);
+
+create unique index uix_revoked_access_tokens_on_jti on revoked_access_tokens (jti);
@@ -131,6 +131,15 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
match LlmTokenClaims::validate(&token, &state.config) {
Ok(claims) => {
+ if state.db.is_access_token_revoked(&claims.jti).await? {
+ return Err(Error::http(
+ StatusCode::UNAUTHORIZED,
+ "unauthorized".to_string(),
+ ));
+ }
+
+ tracing::Span::current().record("authn.jti", &claims.jti);
+
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
}
@@ -7,3 +7,4 @@ id_type!(ModelId);
id_type!(ProviderId);
id_type!(UsageId);
id_type!(UsageMeasureId);
+id_type!(RevokedAccessTokenId);
@@ -1,4 +1,5 @@
use super::*;
pub mod providers;
+pub mod revoked_access_tokens;
pub mod usages;
@@ -0,0 +1,15 @@
+use super::*;
+
+impl LlmDatabase {
+ /// Returns whether the access token with the given `jti` has been revoked.
+ pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
+ self.transaction(|tx| async move {
+ Ok(revoked_access_token::Entity::find()
+ .filter(revoked_access_token::Column::Jti.eq(jti))
+ .one(&*tx)
+ .await?
+ .is_some())
+ })
+ .await
+ }
+}
@@ -1,5 +1,6 @@
pub mod lifetime_usage;
pub mod model;
pub mod provider;
+pub mod revoked_access_token;
pub mod usage;
pub mod usage_measure;
@@ -0,0 +1,19 @@
+use chrono::NaiveDateTime;
+use sea_orm::entity::prelude::*;
+
+use crate::llm::db::RevokedAccessTokenId;
+
+/// A revoked access token.
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "revoked_access_tokens")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: RevokedAccessTokenId,
+ pub jti: String,
+ pub revoked_at: NaiveDateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -150,6 +150,7 @@ async fn main() -> Result<()> {
"http_request",
method = ?request.method(),
matched_path,
+ authn.jti = tracing::field::Empty
)
})
.on_response(