@@ -18,6 +18,7 @@ use axum::{
Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
+use collections::HashMap;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
@@ -41,7 +42,8 @@ pub struct LlmState {
pub db: Arc<LlmDatabase>,
pub http_client: IsahcHttpClient,
pub clickhouse_client: Option<clickhouse::Client>,
- active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
+ active_user_count_by_model:
+ RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
@@ -69,9 +71,6 @@ impl LlmState {
.build()
.context("failed to construct http client")?;
- let initial_active_user_count =
- Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
-
let this = Self {
executor,
db,
@@ -80,25 +79,34 @@ impl LlmState {
.clickhouse_url
.as_ref()
.and_then(|_| build_clickhouse_client(&config).log_err()),
- active_user_count: RwLock::new(initial_active_user_count),
+ active_user_count_by_model: RwLock::new(HashMap::default()),
config,
};
Ok(Arc::new(this))
}
- pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
+ pub async fn get_active_user_count(
+ &self,
+ provider: LanguageModelProvider,
+ model: &str,
+ ) -> Result<ActiveUserCount> {
let now = Utc::now();
- if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
- if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
- return Ok(*count);
+ {
+ let active_user_count_by_model = self.active_user_count_by_model.read().await;
+ if let Some((last_updated, count)) =
+ active_user_count_by_model.get(&(provider, model.to_string()))
+ {
+ if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
+ return Ok(*count);
+ }
}
}
- let mut cache = self.active_user_count.write().await;
- let new_count = self.db.get_active_user_count(now).await?;
- *cache = Some((now, new_count));
+ let mut cache = self.active_user_count_by_model.write().await;
+ let new_count = self.db.get_active_user_count(provider, model, now).await?;
+ cache.insert((provider, model.to_string()), (now, new_count));
Ok(new_count)
}
}
@@ -419,7 +427,7 @@ async fn check_usage_limit(
)
.await?;
- let active_users = state.get_active_user_count().await?;
+ let active_users = state.get_active_user_count(provider, model_name).await?;
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
let users_in_recent_days = active_users.users_in_recent_days.max(1);
@@ -343,15 +343,27 @@ impl LlmDatabase {
.await
}
- pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
+ /// Returns the active user count for the specified model.
+ pub async fn get_active_user_count(
+ &self,
+ provider: LanguageModelProvider,
+ model_name: &str,
+ now: DateTimeUtc,
+ ) -> Result<ActiveUserCount> {
self.transaction(|tx| async move {
let minute_since = now - Duration::minutes(5);
let day_since = now - Duration::days(5);
+ let model = self
+ .models
+ .get(&(provider, model_name.to_string()))
+ .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
+
let users_in_recent_minutes = usage::Entity::find()
.filter(
- usage::Column::Timestamp
- .gte(minute_since.naive_utc())
+ usage::Column::ModelId
+ .eq(model.id)
+ .and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
@@ -362,8 +374,9 @@ impl LlmDatabase {
let users_in_recent_days = usage::Entity::find()
.filter(
- usage::Column::Timestamp
- .gte(day_since.naive_utc())
+ usage::Column::ModelId
+ .eq(model.id)
+ .and(usage::Column::Timestamp.gte(day_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()