Apply rate limits in LLM service (#15997)

Max Brunsfeld , Marshall , and Marshall Bowers created

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

crates/collab/k8s/collab.template.yml                                       |   5 
crates/collab/k8s/postgrest.template.yml                                    | 126 
crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql          |  32 
crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql |   5 
crates/collab/migrations_llm/20240806213401_create_usages.sql               |  18 
crates/collab/src/llm.rs                                                    | 178 
crates/collab/src/llm/db.rs                                                 |  26 
crates/collab/src/llm/db/ids.rs                                             |   1 
crates/collab/src/llm/db/queries/providers.rs                               | 141 
crates/collab/src/llm/db/queries/usages.rs                                  | 325 
crates/collab/src/llm/db/seed.rs                                            |  45 
crates/collab/src/llm/db/tables.rs                                          |   1 
crates/collab/src/llm/db/tables/model.rs                                    |   3 
crates/collab/src/llm/db/tables/provider.rs                                 |   3 
crates/collab/src/llm/db/tables/usage.rs                                    |  26 
crates/collab/src/llm/db/tables/usage_measure.rs                            |  35 
crates/collab/src/llm/db/tests.rs                                           |  60 
crates/collab/src/llm/db/tests/provider_tests.rs                            |  26 
crates/collab/src/llm/db/tests/usage_tests.rs                               | 124 
crates/collab/src/main.rs                                                   |  10 
crates/rpc/src/llm.rs                                                       |   6 
21 files changed, 976 insertions(+), 220 deletions(-)

Detailed changes

crates/collab/k8s/collab.template.yml 🔗

@@ -85,6 +85,11 @@ spec:
                 secretKeyRef:
                   name: database
                   key: url
+            - name: LLM_DATABASE_URL
+              valueFrom:
+                secretKeyRef:
+                  name: llm-database
+                  key: url
             - name: DATABASE_MAX_CONNECTIONS
               value: "${DATABASE_MAX_CONNECTIONS}"
             - name: API_TOKEN

crates/collab/k8s/postgrest.template.yml 🔗

@@ -12,7 +12,7 @@ metadata:
 spec:
   type: LoadBalancer
   selector:
-    app: postgrest
+    app: nginx
   ports:
     - name: web
       protocol: TCP
@@ -24,17 +24,99 @@ apiVersion: apps/v1
 kind: Deployment
 metadata:
   namespace: ${ZED_KUBE_NAMESPACE}
-  name: postgrest
+  name: nginx
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: nginx
+  template:
+    metadata:
+      labels:
+        app: nginx
+    spec:
+      containers:
+        - name: nginx
+          image: nginx:latest
+          ports:
+            - containerPort: 8080
+              protocol: TCP
+          volumeMounts:
+            - name: nginx-config
+              mountPath: /etc/nginx/nginx.conf
+              subPath: nginx.conf
+      volumes:
+        - name: nginx-config
+          configMap:
+            name: nginx-config
+
+---
+apiVersion: v1
+kind: ConfigMap
+metadata:
+  namespace: ${ZED_KUBE_NAMESPACE}
+  name: nginx-config
+data:
+  nginx.conf: |
+    events {}
+
+    http {
+      server {
+        listen 8080;
+
+        location /app/ {
+          proxy_pass http://postgrest-app:8080/;
+        }
+
+        location /llm/ {
+          proxy_pass http://postgrest-llm:8080/;
+        }
+      }
+    }
+
+---
+apiVersion: v1
+kind: Service
+metadata:
+  namespace: ${ZED_KUBE_NAMESPACE}
+  name: postgrest-app
+spec:
+  selector:
+    app: postgrest-app
+  ports:
+    - protocol: TCP
+      port: 8080
+      targetPort: 8080
+
+---
+apiVersion: v1
+kind: Service
+metadata:
+  namespace: ${ZED_KUBE_NAMESPACE}
+  name: postgrest-llm
+spec:
+  selector:
+    app: postgrest-llm
+  ports:
+    - protocol: TCP
+      port: 8080
+      targetPort: 8080
 
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  namespace: ${ZED_KUBE_NAMESPACE}
+  name: postgrest-app
 spec:
   replicas: 1
   selector:
     matchLabels:
-      app: postgrest
+      app: postgrest-app
   template:
     metadata:
       labels:
-        app: postgrest
+        app: postgrest-app
     spec:
       containers:
         - name: postgrest
@@ -55,3 +137,39 @@ spec:
                 secretKeyRef:
                   name: postgrest
                   key: jwt_secret
+
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  namespace: ${ZED_KUBE_NAMESPACE}
+  name: postgrest-llm
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: postgrest-llm
+  template:
+    metadata:
+      labels:
+        app: postgrest-llm
+    spec:
+      containers:
+        - name: postgrest
+          image: "postgrest/postgrest"
+          ports:
+            - containerPort: 8080
+              protocol: TCP
+          env:
+            - name: PGRST_SERVER_PORT
+              value: "8080"
+            - name: PGRST_DB_URI
+              valueFrom:
+                secretKeyRef:
+                  name: llm-database
+                  key: url
+            - name: PGRST_JWT_SECRET
+              valueFrom:
+                secretKeyRef:
+                  name: postgrest
+                  key: jwt_secret

crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql 🔗

@@ -1,32 +0,0 @@
-create table providers (
-    id integer primary key autoincrement,
-    name text not null
-);
-
-create unique index uix_providers_on_name on providers (name);
-
-create table models (
-    id integer primary key autoincrement,
-    provider_id integer not null references providers (id) on delete cascade,
-    name text not null
-);
-
-create unique index uix_models_on_provider_id_name on models (provider_id, name);
-create index ix_models_on_provider_id on models (provider_id);
-create index ix_models_on_name on models (name);
-
-create table if not exists usages (
-    id integer primary key autoincrement,
-    user_id integer not null,
-    model_id integer not null references models (id) on delete cascade,
-    requests_this_minute integer not null default 0,
-    tokens_this_minute integer not null default 0,
-    requests_this_day integer not null default 0,
-    tokens_this_day integer not null default 0,
-    requests_this_month integer not null default 0,
-    tokens_this_month integer not null default 0
-);
-
-create index ix_usages_on_user_id on usages (user_id);
-create index ix_usages_on_model_id on usages (model_id);
-create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);

crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql 🔗

@@ -8,7 +8,10 @@ create unique index uix_providers_on_name on providers (name);
 create table if not exists models (
     id serial primary key,
     provider_id integer not null references providers (id) on delete cascade,
-    name text not null
+    name text not null,
+    max_requests_per_minute integer not null,
+    max_tokens_per_minute integer not null,
+    max_tokens_per_day integer not null
 );
 
 create unique index uix_models_on_provider_id_name on models (provider_id, name);

crates/collab/migrations_llm/20240806213401_create_usages.sql 🔗

@@ -1,15 +1,19 @@
+create table usage_measures (
+    id serial primary key,
+    name text not null
+);
+
+create unique index uix_usage_measures_on_name on usage_measures (name);
+
 create table if not exists usages (
     id serial primary key,
     user_id integer not null,
     model_id integer not null references models (id) on delete cascade,
-    requests_this_minute integer not null default 0,
-    tokens_this_minute bigint not null default 0,
-    requests_this_day integer not null default 0,
-    tokens_this_day bigint not null default 0,
-    requests_this_month integer not null default 0,
-    tokens_this_month bigint not null default 0
+    measure_id integer not null references usage_measures (id) on delete cascade,
+    timestamp timestamp without time zone not null,
+    buckets bigint[] not null
 );
 
 create index ix_usages_on_user_id on usages (user_id);
 create index ix_usages_on_model_id on usages (model_id);
-create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);
+create unique index uix_usages_on_user_id_model_id_measure_id on usages (user_id, model_id, measure_id);

crates/collab/src/llm.rs 🔗

@@ -2,24 +2,25 @@ mod authorization;
 pub mod db;
 mod token;
 
-use crate::api::CloudflareIpCountryHeader;
-use crate::llm::authorization::authorize_access_to_language_model;
-use crate::llm::db::LlmDatabase;
-use crate::{executor::Executor, Config, Error, Result};
+use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
 use anyhow::{anyhow, Context as _};
-use axum::TypedHeader;
+use authorization::authorize_access_to_language_model;
 use axum::{
     body::Body,
     http::{self, HeaderName, HeaderValue, Request, StatusCode},
     middleware::{self, Next},
     response::{IntoResponse, Response},
     routing::post,
-    Extension, Json, Router,
+    Extension, Json, Router, TypedHeader,
 };
+use chrono::{DateTime, Duration, Utc};
+use db::{ActiveUserCount, LlmDatabase};
 use futures::StreamExt as _;
 use http_client::IsahcHttpClient;
 use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::sync::Arc;
+use tokio::sync::RwLock;
+use util::ResultExt;
 
 pub use token::*;
 
@@ -28,8 +29,11 @@ pub struct LlmState {
     pub executor: Executor,
     pub db: Option<Arc<LlmDatabase>>,
     pub http_client: IsahcHttpClient,
+    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
 }
 
+const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
+
 impl LlmState {
     pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
         // TODO: This is temporary until we have the LLM database stood up.
@@ -44,7 +48,8 @@ impl LlmState {
 
             let mut db_options = db::ConnectOptions::new(database_url);
             db_options.max_connections(max_connections);
-            let db = LlmDatabase::new(db_options, executor.clone()).await?;
+            let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
+            db.initialize().await?;
 
             Some(Arc::new(db))
         } else {
@@ -57,15 +62,41 @@ impl LlmState {
             .build()
             .context("failed to construct http client")?;
 
+        let initial_active_user_count = if let Some(db) = &db {
+            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
+        } else {
+            None
+        };
+
         let this = Self {
             config,
             executor,
             db,
             http_client,
+            active_user_count: RwLock::new(initial_active_user_count),
         };
 
         Ok(Arc::new(this))
     }
+
+    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
+        let now = Utc::now();
+
+        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
+            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
+                return Ok(*count);
+            }
+        }
+
+        if let Some(db) = &self.db {
+            let mut cache = self.active_user_count.write().await;
+            let new_count = db.get_active_user_count(now).await?;
+            *cache = Some((now, new_count));
+            Ok(new_count)
+        } else {
+            Ok(ActiveUserCount::default())
+        }
+    }
 }
 
 pub fn routes() -> Router<(), Body> {
@@ -122,14 +153,22 @@ async fn perform_completion(
     country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
+    let model = normalize_model_name(params.provider, params.model);
+
     authorize_access_to_language_model(
         &state.config,
         &claims,
         country_code_header.map(|header| header.to_string()),
         params.provider,
-        &params.model,
+        &model,
     )?;
 
+    let user_id = claims.user_id as i32;
+
+    if state.db.is_some() {
+        check_usage_limit(&state, params.provider, &model, &claims).await?;
+    }
+
     match params.provider {
         LanguageModelProvider::Anthropic => {
             let api_key = state
@@ -160,9 +199,31 @@ async fn perform_completion(
             )
             .await?;
 
-            let stream = chunks.map(|event| {
+            let mut recorder = state.db.clone().map(|db| UsageRecorder {
+                db,
+                executor: state.executor.clone(),
+                user_id,
+                provider: params.provider,
+                model,
+                token_count: 0,
+            });
+
+            let stream = chunks.map(move |event| {
                 let mut buffer = Vec::new();
                 event.map(|chunk| {
+                    match &chunk {
+                        anthropic::Event::MessageStart {
+                            message: anthropic::Response { usage, .. },
+                        }
+                        | anthropic::Event::MessageDelta { usage, .. } => {
+                            if let Some(recorder) = &mut recorder {
+                                recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
+                                recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
+                            }
+                        }
+                        _ => {}
+                    }
+
                     buffer.clear();
                     serde_json::to_writer(&mut buffer, &chunk).unwrap();
                     buffer.push(b'\n');
@@ -259,3 +320,102 @@ async fn perform_completion(
         }
     }
 }
+
+fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
+    match provider {
+        LanguageModelProvider::Anthropic => {
+            for prefix in &[
+                "claude-3-5-sonnet",
+                "claude-3-haiku",
+                "claude-3-opus",
+                "claude-3-sonnet",
+            ] {
+                if name.starts_with(prefix) {
+                    return prefix.to_string();
+                }
+            }
+        }
+        LanguageModelProvider::OpenAi => {}
+        LanguageModelProvider::Google => {}
+        LanguageModelProvider::Zed => {}
+    }
+
+    name
+}
+
+async fn check_usage_limit(
+    state: &Arc<LlmState>,
+    provider: LanguageModelProvider,
+    model_name: &str,
+    claims: &LlmTokenClaims,
+) -> Result<()> {
+    let db = state
+        .db
+        .as_ref()
+        .ok_or_else(|| anyhow!("LLM database not configured"))?;
+    let model = db.model(provider, model_name)?;
+    let usage = db
+        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
+        .await?;
+
+    let active_users = state.get_active_user_count().await?;
+
+    let per_user_max_requests_per_minute =
+        model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
+    let per_user_max_tokens_per_minute =
+        model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
+    let per_user_max_tokens_per_day =
+        model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
+
+    let checks = [
+        (
+            usage.requests_this_minute,
+            per_user_max_requests_per_minute,
+            "requests per minute",
+        ),
+        (
+            usage.tokens_this_minute,
+            per_user_max_tokens_per_minute,
+            "tokens per minute",
+        ),
+        (
+            usage.tokens_this_day,
+            per_user_max_tokens_per_day,
+            "tokens per day",
+        ),
+    ];
+
+    for (usage, limit, resource) in checks {
+        if usage > limit {
+            return Err(Error::http(
+                StatusCode::TOO_MANY_REQUESTS,
+                format!("Rate limit exceeded. Maximum {} reached.", resource),
+            ));
+        }
+    }
+
+    Ok(())
+}
+struct UsageRecorder {
+    db: Arc<LlmDatabase>,
+    executor: Executor,
+    user_id: i32,
+    provider: LanguageModelProvider,
+    model: String,
+    token_count: usize,
+}
+
+impl Drop for UsageRecorder {
+    fn drop(&mut self) {
+        let db = self.db.clone();
+        let user_id = self.user_id;
+        let provider = self.provider;
+        let model = std::mem::take(&mut self.model);
+        let token_count = self.token_count;
+        self.executor.spawn_detached(async move {
+            db.record_usage(user_id, provider, &model, token_count, Utc::now())
+                .await
+                .log_err();
+        })
+    }
+}

crates/collab/src/llm/db.rs 🔗

@@ -1,20 +1,26 @@
 mod ids;
 mod queries;
+mod seed;
 mod tables;
 
 #[cfg(test)]
 mod tests;
 
+use collections::HashMap;
 pub use ids::*;
+use rpc::LanguageModelProvider;
+pub use seed::*;
 pub use tables::*;
 
 #[cfg(test)]
 pub use tests::TestLlmDb;
+use usage_measure::UsageMeasure;
 
 use std::future::Future;
 use std::sync::Arc;
 
 use anyhow::anyhow;
+pub use queries::usages::ActiveUserCount;
 use sea_orm::prelude::*;
 pub use sea_orm::ConnectOptions;
 use sea_orm::{
@@ -31,6 +37,9 @@ pub struct LlmDatabase {
     pool: DatabaseConnection,
     #[allow(unused)]
     executor: Executor,
+    provider_ids: HashMap<LanguageModelProvider, ProviderId>,
+    models: HashMap<(LanguageModelProvider, String), model::Model>,
+    usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
     #[cfg(test)]
     runtime: Option<tokio::runtime::Runtime>,
 }
@@ -43,11 +52,28 @@ impl LlmDatabase {
             options: options.clone(),
             pool: sea_orm::Database::connect(options).await?,
             executor,
+            provider_ids: HashMap::default(),
+            models: HashMap::default(),
+            usage_measure_ids: HashMap::default(),
             #[cfg(test)]
             runtime: None,
         })
     }
 
+    pub async fn initialize(&mut self) -> Result<()> {
+        self.initialize_providers().await?;
+        self.initialize_models().await?;
+        self.initialize_usage_measures().await?;
+        Ok(())
+    }
+
+    pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
+        Ok(self
+            .models
+            .get(&(provider, name.to_string()))
+            .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
+    }
+
     pub fn options(&self) -> &ConnectOptions {
         &self.options
     }

crates/collab/src/llm/db/queries/providers.rs 🔗

@@ -1,66 +1,115 @@
-use sea_orm::sea_query::OnConflict;
+use super::*;
 use sea_orm::QueryOrder;
+use std::str::FromStr;
+use strum::IntoEnumIterator as _;
 
-use super::*;
+pub struct ModelRateLimits {
+    pub max_requests_per_minute: i32,
+    pub max_tokens_per_minute: i32,
+    pub max_tokens_per_day: i32,
+}
 
 impl LlmDatabase {
-    pub async fn initialize_providers(&self) -> Result<()> {
-        self.transaction(|tx| async move {
-            let providers_and_models = vec![
-                ("anthropic", "claude-3-5-sonnet"),
-                ("anthropic", "claude-3-opus"),
-                ("anthropic", "claude-3-sonnet"),
-                ("anthropic", "claude-3-haiku"),
-            ];
+    pub async fn initialize_providers(&mut self) -> Result<()> {
+        self.provider_ids = self
+            .transaction(|tx| async move {
+                let existing_providers = provider::Entity::find().all(&*tx).await?;
 
-            for (provider_name, model_name) in providers_and_models {
-                let insert_provider = provider::Entity::insert(provider::ActiveModel {
-                    name: ActiveValue::set(provider_name.to_owned()),
-                    ..Default::default()
-                })
-                .on_conflict(
-                    OnConflict::columns([provider::Column::Name])
-                        .update_column(provider::Column::Name)
-                        .to_owned(),
-                );
+                let mut new_providers = LanguageModelProvider::iter()
+                    .filter(|provider| {
+                        !existing_providers
+                            .iter()
+                            .any(|p| p.name == provider.to_string())
+                    })
+                    .map(|provider| provider::ActiveModel {
+                        name: ActiveValue::set(provider.to_string()),
+                        ..Default::default()
+                    })
+                    .peekable();
 
-                let provider = if tx.support_returning() {
-                    insert_provider.exec_with_returning(&*tx).await?
-                } else {
-                    insert_provider.exec_without_returning(&*tx).await?;
-                    provider::Entity::find()
-                        .filter(provider::Column::Name.eq(provider_name))
-                        .one(&*tx)
-                        .await?
-                        .ok_or_else(|| anyhow!("failed to insert provider"))?
-                };
+                if new_providers.peek().is_some() {
+                    provider::Entity::insert_many(new_providers)
+                        .exec(&*tx)
+                        .await?;
+                }
 
-                model::Entity::insert(model::ActiveModel {
-                    provider_id: ActiveValue::set(provider.id),
-                    name: ActiveValue::set(model_name.to_owned()),
-                    ..Default::default()
-                })
-                .on_conflict(
-                    OnConflict::columns([model::Column::ProviderId, model::Column::Name])
-                        .update_column(model::Column::Name)
-                        .to_owned(),
-                )
-                .exec_without_returning(&*tx)
-                .await?;
-            }
+                let all_providers: HashMap<_, _> = provider::Entity::find()
+                    .all(&*tx)
+                    .await?
+                    .iter()
+                    .filter_map(|provider| {
+                        LanguageModelProvider::from_str(&provider.name)
+                            .ok()
+                            .map(|p| (p, provider.id))
+                    })
+                    .collect();
 
+                Ok(all_providers)
+            })
+            .await?;
+        Ok(())
+    }
+
+    pub async fn initialize_models(&mut self) -> Result<()> {
+        let all_provider_ids = &self.provider_ids;
+        self.models = self
+            .transaction(|tx| async move {
+                let all_models: HashMap<_, _> = model::Entity::find()
+                    .all(&*tx)
+                    .await?
+                    .into_iter()
+                    .filter_map(|model| {
+                        let provider = all_provider_ids.iter().find_map(|(provider, id)| {
+                            if *id == model.provider_id {
+                                Some(provider)
+                            } else {
+                                None
+                            }
+                        })?;
+                        Some(((*provider, model.name.clone()), model))
+                    })
+                    .collect();
+                Ok(all_models)
+            })
+            .await?;
+        Ok(())
+    }
+
+    pub async fn insert_models(
+        &mut self,
+        models: &[(LanguageModelProvider, String, ModelRateLimits)],
+    ) -> Result<()> {
+        let all_provider_ids = &self.provider_ids;
+        self.transaction(|tx| async move {
+            model::Entity::insert_many(models.into_iter().map(|(provider, name, rate_limits)| {
+                let provider_id = all_provider_ids[&provider];
+                model::ActiveModel {
+                    provider_id: ActiveValue::set(provider_id),
+                    name: ActiveValue::set(name.clone()),
+                    max_requests_per_minute: ActiveValue::set(rate_limits.max_requests_per_minute),
+                    max_tokens_per_minute: ActiveValue::set(rate_limits.max_tokens_per_minute),
+                    max_tokens_per_day: ActiveValue::set(rate_limits.max_tokens_per_day),
+                    ..Default::default()
+                }
+            }))
+            .exec_without_returning(&*tx)
+            .await?;
             Ok(())
         })
-        .await
+        .await?;
+        self.initialize_models().await
     }
 
     /// Returns the list of LLM providers.
-    pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
+    pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
         self.transaction(|tx| async move {
             Ok(provider::Entity::find()
                 .order_by_asc(provider::Column::Name)
                 .all(&*tx)
-                .await?)
+                .await?
+                .into_iter()
+                .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
+                .collect())
         })
         .await
     }

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -1,57 +1,318 @@
+use chrono::Duration;
 use rpc::LanguageModelProvider;
+use sea_orm::QuerySelect;
+use std::{iter, str::FromStr};
+use strum::IntoEnumIterator as _;
 
 use super::*;
 
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub struct Usage {
+    pub requests_this_minute: usize,
+    pub tokens_this_minute: usize,
+    pub tokens_this_day: usize,
+    pub tokens_this_month: usize,
+}
+
+#[derive(Clone, Copy, Debug, Default)]
+pub struct ActiveUserCount {
+    pub users_in_recent_minutes: usize,
+    pub users_in_recent_days: usize,
+}
+
 impl LlmDatabase {
-    pub async fn find_or_create_usage(
+    pub async fn initialize_usage_measures(&mut self) -> Result<()> {
+        let all_measures = self
+            .transaction(|tx| async move {
+                let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
+
+                let new_measures = UsageMeasure::iter()
+                    .filter(|measure| {
+                        !existing_measures
+                            .iter()
+                            .any(|m| m.name == measure.to_string())
+                    })
+                    .map(|measure| usage_measure::ActiveModel {
+                        name: ActiveValue::set(measure.to_string()),
+                        ..Default::default()
+                    })
+                    .collect::<Vec<_>>();
+
+                if !new_measures.is_empty() {
+                    usage_measure::Entity::insert_many(new_measures)
+                        .exec(&*tx)
+                        .await?;
+                }
+
+                Ok(usage_measure::Entity::find().all(&*tx).await?)
+            })
+            .await?;
+
+        self.usage_measure_ids = all_measures
+            .into_iter()
+            .filter_map(|measure| {
+                UsageMeasure::from_str(&measure.name)
+                    .ok()
+                    .map(|um| (um, measure.id))
+            })
+            .collect();
+        Ok(())
+    }
+
+    pub async fn get_usage(
         &self,
         user_id: i32,
         provider: LanguageModelProvider,
         model_name: &str,
-    ) -> Result<usage::Model> {
+        now: DateTimeUtc,
+    ) -> Result<Usage> {
         self.transaction(|tx| async move {
-            let provider_name = match provider {
-                LanguageModelProvider::Anthropic => "anthropic",
-                LanguageModelProvider::OpenAi => "open_ai",
-                LanguageModelProvider::Google => "google",
-                LanguageModelProvider::Zed => "zed",
-            };
-
-            let model = model::Entity::find()
-                .inner_join(provider::Entity)
+            let model = self
+                .models
+                .get(&(provider, model_name.to_string()))
+                .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
+
+            let usages = usage::Entity::find()
                 .filter(
-                    provider::Column::Name
-                        .eq(provider_name)
-                        .and(model::Column::Name.eq(model_name)),
+                    usage::Column::UserId
+                        .eq(user_id)
+                        .and(usage::Column::ModelId.eq(model.id)),
                 )
-                .one(&*tx)
-                .await?
-                // TODO: Create the model, if one doesn't exist.
-                .ok_or_else(|| anyhow!("no model found for {provider_name}:{model_name}"))?;
-            let model_id = model.id;
+                .all(&*tx)
+                .await?;
+
+            let requests_this_minute =
+                self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
+            let tokens_this_minute =
+                self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
+            let tokens_this_day =
+                self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
+            let tokens_this_month =
+                self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMonth)?;
 
-            let existing_usage = usage::Entity::find()
+            Ok(Usage {
+                requests_this_minute,
+                tokens_this_minute,
+                tokens_this_day,
+                tokens_this_month,
+            })
+        })
+        .await
+    }
+
+    pub async fn record_usage(
+        &self,
+        user_id: i32,
+        provider: LanguageModelProvider,
+        model_name: &str,
+        token_count: usize,
+        now: DateTimeUtc,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            let model = self.model(provider, model_name)?;
+
+            let usages = usage::Entity::find()
                 .filter(
                     usage::Column::UserId
                         .eq(user_id)
-                        .and(usage::Column::ModelId.eq(model_id)),
+                        .and(usage::Column::ModelId.eq(model.id)),
                 )
-                .one(&*tx)
+                .all(&*tx)
                 .await?;
-            if let Some(usage) = existing_usage {
-                return Ok(usage);
-            }
 
-            let usage = usage::Entity::insert(usage::ActiveModel {
-                user_id: ActiveValue::set(user_id),
-                model_id: ActiveValue::set(model_id),
-                ..Default::default()
-            })
-            .exec_with_returning(&*tx)
+            self.update_usage_for_measure(
+                user_id,
+                model.id,
+                &usages,
+                UsageMeasure::RequestsPerMinute,
+                now,
+                1,
+                &tx,
+            )
             .await?;
+            self.update_usage_for_measure(
+                user_id,
+                model.id,
+                &usages,
+                UsageMeasure::TokensPerMinute,
+                now,
+                token_count,
+                &tx,
+            )
+            .await?;
+            self.update_usage_for_measure(
+                user_id,
+                model.id,
+                &usages,
+                UsageMeasure::TokensPerDay,
+                now,
+                token_count,
+                &tx,
+            )
+            .await?;
+            self.update_usage_for_measure(
+                user_id,
+                model.id,
+                &usages,
+                UsageMeasure::TokensPerMonth,
+                now,
+                token_count,
+                &tx,
+            )
+            .await?;
+
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
+        self.transaction(|tx| async move {
+            let minute_since = now - Duration::minutes(5);
+            let day_since = now - Duration::days(5);
 
-            Ok(usage)
+            let users_in_recent_minutes = usage::Entity::find()
+                .filter(usage::Column::Timestamp.gte(minute_since.naive_utc()))
+                .group_by(usage::Column::UserId)
+                .count(&*tx)
+                .await? as usize;
+
+            let users_in_recent_days = usage::Entity::find()
+                .filter(usage::Column::Timestamp.gte(day_since.naive_utc()))
+                .group_by(usage::Column::UserId)
+                .count(&*tx)
+                .await? as usize;
+
+            Ok(ActiveUserCount {
+                users_in_recent_minutes,
+                users_in_recent_days,
+            })
         })
         .await
     }
+
+    #[allow(clippy::too_many_arguments)]
+    async fn update_usage_for_measure(
+        &self,
+        user_id: i32,
+        model_id: ModelId,
+        usages: &[usage::Model],
+        usage_measure: UsageMeasure,
+        now: DateTimeUtc,
+        usage_to_add: usize,
+        tx: &DatabaseTransaction,
+    ) -> Result<()> {
+        let now = now.naive_utc();
+        let measure_id = *self
+            .usage_measure_ids
+            .get(&usage_measure)
+            .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
+
+        let mut id = None;
+        let mut timestamp = now;
+        let mut buckets = vec![0_i64];
+
+        if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
+            id = Some(old_usage.id);
+            let (live_buckets, buckets_since) =
+                Self::get_live_buckets(old_usage, now, usage_measure);
+            if !live_buckets.is_empty() {
+                buckets.clear();
+                buckets.extend_from_slice(live_buckets);
+                buckets.extend(iter::repeat(0).take(buckets_since));
+                timestamp =
+                    old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
+            }
+        }
+
+        *buckets.last_mut().unwrap() += usage_to_add as i64;
+
+        let mut model = usage::ActiveModel {
+            user_id: ActiveValue::set(user_id),
+            model_id: ActiveValue::set(model_id),
+            measure_id: ActiveValue::set(measure_id),
+            timestamp: ActiveValue::set(timestamp),
+            buckets: ActiveValue::set(buckets),
+            ..Default::default()
+        };
+
+        if let Some(id) = id {
+            model.id = ActiveValue::unchanged(id);
+            model.update(tx).await?;
+        } else {
+            usage::Entity::insert(model)
+                .exec_without_returning(tx)
+                .await?;
+        }
+
+        Ok(())
+    }
+
+    fn get_usage_for_measure(
+        &self,
+        usages: &[usage::Model],
+        now: DateTimeUtc,
+        usage_measure: UsageMeasure,
+    ) -> Result<usize> {
+        let now = now.naive_utc();
+        let measure_id = *self
+            .usage_measure_ids
+            .get(&usage_measure)
+            .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
+        let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
+            return Ok(0);
+        };
+
+        let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
+        Ok(live_buckets.iter().sum::<i64>() as _)
+    }
+
+    fn get_live_buckets(
+        usage: &usage::Model,
+        now: chrono::NaiveDateTime,
+        measure: UsageMeasure,
+    ) -> (&[i64], usize) {
+        let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
+        let buckets_since_usage =
+            seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
+        let buckets_since_usage = buckets_since_usage.ceil() as usize;
+        let mut live_buckets = &[] as &[i64];
+        if buckets_since_usage < measure.bucket_count() {
+            let expired_bucket_count =
+                (usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
+            live_buckets = &usage.buckets[expired_bucket_count..];
+            while live_buckets.first() == Some(&0) {
+                live_buckets = &live_buckets[1..];
+            }
+        }
+        (live_buckets, buckets_since_usage)
+    }
+}
+
+const MINUTE_BUCKET_COUNT: usize = 12;
+const DAY_BUCKET_COUNT: usize = 48;
+const MONTH_BUCKET_COUNT: usize = 30;
+
+impl UsageMeasure {
+    fn bucket_count(&self) -> usize {
+        match self {
+            UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
+            UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
+            UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
+            UsageMeasure::TokensPerMonth => MONTH_BUCKET_COUNT,
+        }
+    }
+
+    fn total_duration(&self) -> Duration {
+        match self {
+            UsageMeasure::RequestsPerMinute => Duration::minutes(1),
+            UsageMeasure::TokensPerMinute => Duration::minutes(1),
+            UsageMeasure::TokensPerDay => Duration::hours(24),
+            UsageMeasure::TokensPerMonth => Duration::days(30),
+        }
+    }
+
+    fn bucket_duration(&self) -> Duration {
+        self.total_duration() / self.bucket_count() as i32
+    }
 }

crates/collab/src/llm/db/seed.rs 🔗

@@ -0,0 +1,45 @@
+use super::*;
+use crate::{Config, Result};
+use queries::providers::ModelRateLimits;
+
+pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
+    db.insert_models(&[
+        (
+            LanguageModelProvider::Anthropic,
+            "claude-3-5-sonnet".into(),
+            ModelRateLimits {
+                max_requests_per_minute: 5,
+                max_tokens_per_minute: 20_000,
+                max_tokens_per_day: 300_000,
+            },
+        ),
+        (
+            LanguageModelProvider::Anthropic,
+            "claude-3-opus".into(),
+            ModelRateLimits {
+                max_requests_per_minute: 5,
+                max_tokens_per_minute: 10_000,
+                max_tokens_per_day: 300_000,
+            },
+        ),
+        (
+            LanguageModelProvider::Anthropic,
+            "claude-3-sonnet".into(),
+            ModelRateLimits {
+                max_requests_per_minute: 5,
+                max_tokens_per_minute: 20_000,
+                max_tokens_per_day: 300_000,
+            },
+        ),
+        (
+            LanguageModelProvider::Anthropic,
+            "claude-3-haiku".into(),
+            ModelRateLimits {
+                max_requests_per_minute: 5,
+                max_tokens_per_minute: 25_000,
+                max_tokens_per_day: 300_000,
+            },
+        ),
+    ])
+    .await
+}

crates/collab/src/llm/db/tables/model.rs 🔗

@@ -10,6 +10,9 @@ pub struct Model {
     pub id: ModelId,
     pub provider_id: ProviderId,
     pub name: String,
+    pub max_requests_per_minute: i32,
+    pub max_tokens_per_minute: i32,
+    pub max_tokens_per_day: i32,
 }
 
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

crates/collab/src/llm/db/tables/provider.rs 🔗

@@ -1,6 +1,5 @@
-use sea_orm::entity::prelude::*;
-
 use crate::llm::db::ProviderId;
+use sea_orm::entity::prelude::*;
 
 /// An LLM provider.
 #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]

crates/collab/src/llm/db/tables/usage.rs 🔗

@@ -1,24 +1,20 @@
+use crate::llm::db::{ModelId, UsageId, UsageMeasureId};
 use sea_orm::entity::prelude::*;
 
-use crate::llm::db::ModelId;
-
 /// An LLM usage record.
 #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
 #[sea_orm(table_name = "usages")]
 pub struct Model {
     #[sea_orm(primary_key)]
-    pub id: i32,
+    pub id: UsageId,
     /// The ID of the Zed user.
     ///
     /// Corresponds to the `users` table in the primary collab database.
     pub user_id: i32,
     pub model_id: ModelId,
-    pub requests_this_minute: i32,
-    pub tokens_this_minute: i64,
-    pub requests_this_day: i32,
-    pub tokens_this_day: i64,
-    pub requests_this_month: i32,
-    pub tokens_this_month: i64,
+    pub measure_id: UsageMeasureId,
+    pub timestamp: DateTime,
+    pub buckets: Vec<i64>,
 }
 
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -29,6 +25,12 @@ pub enum Relation {
         to = "super::model::Column::Id"
     )]
     Model,
+    #[sea_orm(
+        belongs_to = "super::usage_measure::Entity",
+        from = "Column::MeasureId",
+        to = "super::usage_measure::Column::Id"
+    )]
+    UsageMeasure,
 }
 
 impl Related<super::model::Entity> for Entity {
@@ -37,4 +39,10 @@ impl Related<super::model::Entity> for Entity {
     }
 }
 
+impl Related<super::usage_measure::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::UsageMeasure.def()
+    }
+}
+
 impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tables/usage_measure.rs 🔗

@@ -0,0 +1,35 @@
+use crate::llm::db::UsageMeasureId;
+use sea_orm::entity::prelude::*;
+
+#[derive(
+    Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
+)]
+#[strum(serialize_all = "snake_case")]
+pub enum UsageMeasure {
+    RequestsPerMinute,
+    TokensPerMinute,
+    TokensPerDay,
+    TokensPerMonth,
+}
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "usage_measures")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: UsageMeasureId,
+    pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(has_many = "super::usage::Entity")]
+    Usages,
+}
+
+impl Related<super::usage::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Usages.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/llm/db/tests.rs 🔗

@@ -6,7 +6,6 @@ use parking_lot::Mutex;
 use rand::prelude::*;
 use sea_orm::ConnectionTrait;
 use sqlx::migrate::MigrateDatabase;
-use std::sync::Arc;
 use std::time::Duration;
 
 use crate::migrations::run_database_migrations;
@@ -14,47 +13,11 @@ use crate::migrations::run_database_migrations;
 use super::*;
 
 pub struct TestLlmDb {
-    pub db: Option<Arc<LlmDatabase>>,
+    pub db: Option<LlmDatabase>,
     pub connection: Option<sqlx::AnyConnection>,
 }
 
 impl TestLlmDb {
-    pub fn sqlite(background: BackgroundExecutor) -> Self {
-        let url = "sqlite::memory:";
-        let runtime = tokio::runtime::Builder::new_current_thread()
-            .enable_io()
-            .enable_time()
-            .build()
-            .unwrap();
-
-        let mut db = runtime.block_on(async {
-            let mut options = ConnectOptions::new(url);
-            options.max_connections(5);
-            let db = LlmDatabase::new(options, Executor::Deterministic(background))
-                .await
-                .unwrap();
-            let sql = include_str!(concat!(
-                env!("CARGO_MANIFEST_DIR"),
-                "/migrations_llm.sqlite/20240806182921_test_schema.sql"
-            ));
-            db.pool
-                .execute(sea_orm::Statement::from_string(
-                    db.pool.get_database_backend(),
-                    sql,
-                ))
-                .await
-                .unwrap();
-            db
-        });
-
-        db.runtime = Some(runtime);
-
-        Self {
-            db: Some(Arc::new(db)),
-            connection: None,
-        }
-    }
-
     pub fn postgres(background: BackgroundExecutor) -> Self {
         static LOCK: Mutex<()> = Mutex::new(());
 
@@ -91,29 +54,26 @@ impl TestLlmDb {
         db.runtime = Some(runtime);
 
         Self {
-            db: Some(Arc::new(db)),
+            db: Some(db),
             connection: None,
         }
     }
 
-    pub fn db(&self) -> &Arc<LlmDatabase> {
-        self.db.as_ref().unwrap()
+    pub fn db(&mut self) -> &mut LlmDatabase {
+        self.db.as_mut().unwrap()
     }
 }
 
 #[macro_export]
-macro_rules! test_both_llm_dbs {
-    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
-        #[cfg(target_os = "macos")]
+macro_rules! test_llm_db {
+    ($test_name:ident, $postgres_test_name:ident) => {
         #[gpui::test]
         async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
-            let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
-            $test_name(test_db.db()).await;
-        }
+            if !cfg!(target_os = "macos") {
+                return;
+            }
 
-        #[gpui::test]
-        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
-            let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
+            let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
             $test_name(test_db.db()).await;
         }
     };

crates/collab/src/llm/db/tests/provider_tests.rs 🔗

@@ -1,17 +1,15 @@
-use std::sync::Arc;
-
 use pretty_assertions::assert_eq;
+use rpc::LanguageModelProvider;
 
 use crate::llm::db::LlmDatabase;
-use crate::test_both_llm_dbs;
+use crate::test_llm_db;
 
-test_both_llm_dbs!(
+test_llm_db!(
     test_initialize_providers,
-    test_initialize_providers_postgres,
-    test_initialize_providers_sqlite
+    test_initialize_providers_postgres
 );
 
-async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
+async fn test_initialize_providers(db: &mut LlmDatabase) {
     let initial_providers = db.list_providers().await.unwrap();
     assert_eq!(initial_providers, vec![]);
 
@@ -22,9 +20,13 @@ async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
 
     let providers = db.list_providers().await.unwrap();
 
-    let provider_names = providers
-        .into_iter()
-        .map(|provider| provider.name)
-        .collect::<Vec<_>>();
-    assert_eq!(provider_names, vec!["anthropic".to_string()]);
+    assert_eq!(
+        providers,
+        &[
+            LanguageModelProvider::Anthropic,
+            LanguageModelProvider::Google,
+            LanguageModelProvider::OpenAi,
+            LanguageModelProvider::Zed
+        ]
+    )
 }

crates/collab/src/llm/db/tests/usage_tests.rs 🔗

@@ -1,24 +1,120 @@
-use std::sync::Arc;
-
+use crate::{
+    llm::db::{queries::providers::ModelRateLimits, queries::usages::Usage, LlmDatabase},
+    test_llm_db,
+};
+use chrono::{Duration, Utc};
 use pretty_assertions::assert_eq;
 use rpc::LanguageModelProvider;
 
-use crate::llm::db::LlmDatabase;
-use crate::test_both_llm_dbs;
+test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
+
+async fn test_tracking_usage(db: &mut LlmDatabase) {
+    let provider = LanguageModelProvider::Anthropic;
+    let model = "claude-3-5-sonnet";
+
+    db.initialize().await.unwrap();
+    db.insert_models(&[(
+        provider,
+        model.to_string(),
+        ModelRateLimits {
+            max_requests_per_minute: 5,
+            max_tokens_per_minute: 10_000,
+            max_tokens_per_day: 50_000,
+        },
+    )])
+    .await
+    .unwrap();
+
+    let t0 = Utc::now();
+    let user_id = 123;
+
+    let now = t0;
+    db.record_usage(user_id, provider, model, 1000, now)
+        .await
+        .unwrap();
+
+    let now = t0 + Duration::seconds(10);
+    db.record_usage(user_id, provider, model, 2000, now)
+        .await
+        .unwrap();
+
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 2,
+            tokens_this_minute: 3000,
+            tokens_this_day: 3000,
+            tokens_this_month: 3000,
+        }
+    );
+
+    let now = t0 + Duration::seconds(60);
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 1,
+            tokens_this_minute: 2000,
+            tokens_this_day: 3000,
+            tokens_this_month: 3000,
+        }
+    );
 
-test_both_llm_dbs!(
-    test_find_or_create_usage,
-    test_find_or_create_usage_postgres,
-    test_find_or_create_usage_sqlite
-);
+    let now = t0 + Duration::seconds(60);
+    db.record_usage(user_id, provider, model, 3000, now)
+        .await
+        .unwrap();
 
-async fn test_find_or_create_usage(db: &Arc<LlmDatabase>) {
-    db.initialize_providers().await.unwrap();
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 2,
+            tokens_this_minute: 5000,
+            tokens_this_day: 6000,
+            tokens_this_month: 6000,
+        }
+    );
 
-    let usage = db
-        .find_or_create_usage(123, LanguageModelProvider::Anthropic, "claude-3-5-sonnet")
+    let t1 = t0 + Duration::hours(24);
+    let now = t1;
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 0,
+            tokens_this_minute: 0,
+            tokens_this_day: 5000,
+            tokens_this_month: 6000,
+        }
+    );
+
+    db.record_usage(user_id, provider, model, 4000, now)
         .await
         .unwrap();
 
-    assert_eq!(usage.user_id, 123);
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 1,
+            tokens_this_minute: 4000,
+            tokens_this_day: 9000,
+            tokens_this_month: 10000,
+        }
+    );
+
+    let t2 = t0 + Duration::days(30);
+    let now = t2;
+    let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+    assert_eq!(
+        usage,
+        Usage {
+            requests_this_minute: 0,
+            tokens_this_minute: 0,
+            tokens_this_day: 0,
+            tokens_this_month: 9000,
+        }
+    );
 }

crates/collab/src/main.rs 🔗

@@ -52,10 +52,18 @@ async fn main() -> Result<()> {
         Some("seed") => {
             let config = envy::from_env::<Config>().expect("error loading config");
             let db_options = db::ConnectOptions::new(config.database_url.clone());
+
             let mut db = Database::new(db_options, Executor::Production).await?;
             db.initialize_notification_kinds().await?;
 
-            collab::seed::seed(&config, &db, true).await?;
+            collab::seed::seed(&config, &db, false).await?;
+
+            if let Some(llm_database_url) = config.llm_database_url.clone() {
+                let db_options = db::ConnectOptions::new(llm_database_url);
+                let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
+                db.initialize().await?;
+                collab::llm::db::seed_database(&config, &mut db, true).await?;
+            }
         }
         Some("serve") => {
             let mode = match args.next().as_deref() {

crates/rpc/src/llm.rs 🔗

@@ -1,9 +1,13 @@
 use serde::{Deserialize, Serialize};
+use strum::{Display, EnumIter, EnumString};
 
 pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 
-#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+#[derive(
+    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
+)]
 #[serde(rename_all = "snake_case")]
+#[strum(serialize_all = "snake_case")]
 pub enum LanguageModelProvider {
     Anthropic,
     OpenAi,