collab: Track active user counts independently for each model (#16624)

Marshall Bowers created

This PR fixes an issue where the active user count was computed across all
models combined, rather than separately for each individual model.

We now track the active user counts on a per-model basis.

Release Notes:

- N/A

Change summary

crates/collab/src/llm.rs                   | 34 ++++++++++++++---------
crates/collab/src/llm/db/queries/usages.rs | 23 ++++++++++++---
crates/collab/src/main.rs                  |  5 --
3 files changed, 40 insertions(+), 22 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -18,6 +18,7 @@ use axum::{
     Extension, Json, Router, TypedHeader,
 };
 use chrono::{DateTime, Duration, Utc};
+use collections::HashMap;
 use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
 use futures::{Stream, StreamExt as _};
 use http_client::IsahcHttpClient;
@@ -41,7 +42,8 @@ pub struct LlmState {
     pub db: Arc<LlmDatabase>,
     pub http_client: IsahcHttpClient,
     pub clickhouse_client: Option<clickhouse::Client>,
-    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
+    active_user_count_by_model:
+        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
 }
 
 const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
@@ -69,9 +71,6 @@ impl LlmState {
             .build()
             .context("failed to construct http client")?;
 
-        let initial_active_user_count =
-            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
-
         let this = Self {
             executor,
             db,
@@ -80,25 +79,34 @@ impl LlmState {
                 .clickhouse_url
                 .as_ref()
                 .and_then(|_| build_clickhouse_client(&config).log_err()),
-            active_user_count: RwLock::new(initial_active_user_count),
+            active_user_count_by_model: RwLock::new(HashMap::default()),
             config,
         };
 
         Ok(Arc::new(this))
     }
 
-    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
+    pub async fn get_active_user_count(
+        &self,
+        provider: LanguageModelProvider,
+        model: &str,
+    ) -> Result<ActiveUserCount> {
         let now = Utc::now();
 
-        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
-            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
-                return Ok(*count);
+        {
+            let active_user_count_by_model = self.active_user_count_by_model.read().await;
+            if let Some((last_updated, count)) =
+                active_user_count_by_model.get(&(provider, model.to_string()))
+            {
+                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
+                    return Ok(*count);
+                }
             }
         }
 
-        let mut cache = self.active_user_count.write().await;
-        let new_count = self.db.get_active_user_count(now).await?;
-        *cache = Some((now, new_count));
+        let mut cache = self.active_user_count_by_model.write().await;
+        let new_count = self.db.get_active_user_count(provider, model, now).await?;
+        cache.insert((provider, model.to_string()), (now, new_count));
         Ok(new_count)
     }
 }
@@ -419,7 +427,7 @@ async fn check_usage_limit(
         )
         .await?;
 
-    let active_users = state.get_active_user_count().await?;
+    let active_users = state.get_active_user_count(provider, model_name).await?;
 
     let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
     let users_in_recent_days = active_users.users_in_recent_days.max(1);

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -343,15 +343,27 @@ impl LlmDatabase {
         .await
     }
 
-    pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
+    /// Returns the active user count for the specified model.
+    pub async fn get_active_user_count(
+        &self,
+        provider: LanguageModelProvider,
+        model_name: &str,
+        now: DateTimeUtc,
+    ) -> Result<ActiveUserCount> {
         self.transaction(|tx| async move {
             let minute_since = now - Duration::minutes(5);
             let day_since = now - Duration::days(5);
 
+            let model = self
+                .models
+                .get(&(provider, model_name.to_string()))
+                .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
+
             let users_in_recent_minutes = usage::Entity::find()
                 .filter(
-                    usage::Column::Timestamp
-                        .gte(minute_since.naive_utc())
+                    usage::Column::ModelId
+                        .eq(model.id)
+                        .and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
                         .and(usage::Column::IsStaff.eq(false)),
                 )
                 .select_only()
@@ -362,8 +374,9 @@ impl LlmDatabase {
 
             let users_in_recent_days = usage::Entity::find()
                 .filter(
-                    usage::Column::Timestamp
-                        .gte(day_since.naive_utc())
+                    usage::Column::ModelId
+                        .eq(model.id)
+                        .and(usage::Column::Timestamp.gte(day_since.naive_utc()))
                         .and(usage::Column::IsStaff.eq(false)),
                 )
                 .select_only()

crates/collab/src/main.rs 🔗

@@ -302,10 +302,7 @@ async fn handle_liveness_probe(
     }
 
     if let Some(llm_state) = llm_state {
-        llm_state
-            .db
-            .get_active_user_count(chrono::Utc::now())
-            .await?;
+        llm_state.db.list_providers().await?;
     }
 
     Ok("ok".to_string())