Detailed changes
@@ -85,6 +85,11 @@ spec:
secretKeyRef:
name: database
key: url
+ - name: LLM_DATABASE_URL
+ valueFrom:
+ secretKeyRef:
+ name: llm-database
+ key: url
- name: DATABASE_MAX_CONNECTIONS
value: "${DATABASE_MAX_CONNECTIONS}"
- name: API_TOKEN
@@ -12,7 +12,7 @@ metadata:
spec:
type: LoadBalancer
selector:
- app: postgrest
+ app: nginx
ports:
- name: web
protocol: TCP
@@ -24,17 +24,99 @@ apiVersion: apps/v1
kind: Deployment
metadata:
namespace: ${ZED_KUBE_NAMESPACE}
- name: postgrest
+ name: nginx
+spec:
+ replicas: 1
+ selector:
+ matchLabels:
+ app: nginx
+ template:
+ metadata:
+ labels:
+ app: nginx
+ spec:
+ containers:
+ - name: nginx
+ image: nginx:latest
+ ports:
+ - containerPort: 8080
+ protocol: TCP
+ volumeMounts:
+ - name: nginx-config
+ mountPath: /etc/nginx/nginx.conf
+ subPath: nginx.conf
+ volumes:
+ - name: nginx-config
+ configMap:
+ name: nginx-config
+
+---
+apiVersion: v1
+kind: ConfigMap
+metadata:
+ namespace: ${ZED_KUBE_NAMESPACE}
+ name: nginx-config
+data:
+ nginx.conf: |
+ events {}
+
+ http {
+ server {
+ listen 8080;
+
+ location /app/ {
+ proxy_pass http://postgrest-app:8080/;
+ }
+
+ location /llm/ {
+ proxy_pass http://postgrest-llm:8080/;
+ }
+ }
+ }
+
+---
+apiVersion: v1
+kind: Service
+metadata:
+ namespace: ${ZED_KUBE_NAMESPACE}
+ name: postgrest-app
+spec:
+ selector:
+ app: postgrest-app
+ ports:
+ - protocol: TCP
+ port: 8080
+ targetPort: 8080
+
+---
+apiVersion: v1
+kind: Service
+metadata:
+ namespace: ${ZED_KUBE_NAMESPACE}
+ name: postgrest-llm
+spec:
+ selector:
+ app: postgrest-llm
+ ports:
+ - protocol: TCP
+ port: 8080
+ targetPort: 8080
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ namespace: ${ZED_KUBE_NAMESPACE}
+ name: postgrest-app
spec:
replicas: 1
selector:
matchLabels:
- app: postgrest
+ app: postgrest-app
template:
metadata:
labels:
- app: postgrest
+ app: postgrest-app
spec:
containers:
- name: postgrest
@@ -55,3 +137,39 @@ spec:
secretKeyRef:
name: postgrest
key: jwt_secret
+
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ namespace: ${ZED_KUBE_NAMESPACE}
+ name: postgrest-llm
+spec:
+ replicas: 1
+ selector:
+ matchLabels:
+ app: postgrest-llm
+ template:
+ metadata:
+ labels:
+ app: postgrest-llm
+ spec:
+ containers:
+ - name: postgrest
+ image: "postgrest/postgrest"
+ ports:
+ - containerPort: 8080
+ protocol: TCP
+ env:
+ - name: PGRST_SERVER_PORT
+ value: "8080"
+ - name: PGRST_DB_URI
+ valueFrom:
+ secretKeyRef:
+ name: llm-database
+ key: url
+ - name: PGRST_JWT_SECRET
+ valueFrom:
+ secretKeyRef:
+ name: postgrest
+ key: jwt_secret
@@ -1,32 +0,0 @@
-create table providers (
- id integer primary key autoincrement,
- name text not null
-);
-
-create unique index uix_providers_on_name on providers (name);
-
-create table models (
- id integer primary key autoincrement,
- provider_id integer not null references providers (id) on delete cascade,
- name text not null
-);
-
-create unique index uix_models_on_provider_id_name on models (provider_id, name);
-create index ix_models_on_provider_id on models (provider_id);
-create index ix_models_on_name on models (name);
-
-create table if not exists usages (
- id integer primary key autoincrement,
- user_id integer not null,
- model_id integer not null references models (id) on delete cascade,
- requests_this_minute integer not null default 0,
- tokens_this_minute integer not null default 0,
- requests_this_day integer not null default 0,
- tokens_this_day integer not null default 0,
- requests_this_month integer not null default 0,
- tokens_this_month integer not null default 0
-);
-
-create index ix_usages_on_user_id on usages (user_id);
-create index ix_usages_on_model_id on usages (model_id);
-create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);
@@ -8,7 +8,10 @@ create unique index uix_providers_on_name on providers (name);
create table if not exists models (
id serial primary key,
provider_id integer not null references providers (id) on delete cascade,
- name text not null
+ name text not null,
+ max_requests_per_minute integer not null,
+ max_tokens_per_minute integer not null,
+ max_tokens_per_day integer not null
);
create unique index uix_models_on_provider_id_name on models (provider_id, name);
@@ -1,15 +1,19 @@
+create table usage_measures (
+ id serial primary key,
+ name text not null
+);
+
+create unique index uix_usage_measures_on_name on usage_measures (name);
+
create table if not exists usages (
id serial primary key,
user_id integer not null,
model_id integer not null references models (id) on delete cascade,
- requests_this_minute integer not null default 0,
- tokens_this_minute bigint not null default 0,
- requests_this_day integer not null default 0,
- tokens_this_day bigint not null default 0,
- requests_this_month integer not null default 0,
- tokens_this_month bigint not null default 0
+ measure_id integer not null references usage_measures (id) on delete cascade,
+ timestamp timestamp without time zone not null,
+ buckets bigint[] not null
);
create index ix_usages_on_user_id on usages (user_id);
create index ix_usages_on_model_id on usages (model_id);
-create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);
+create unique index uix_usages_on_user_id_model_id_measure_id on usages (user_id, model_id, measure_id);
@@ -2,24 +2,25 @@ mod authorization;
pub mod db;
mod token;
-use crate::api::CloudflareIpCountryHeader;
-use crate::llm::authorization::authorize_access_to_language_model;
-use crate::llm::db::LlmDatabase;
-use crate::{executor::Executor, Config, Error, Result};
+use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
use anyhow::{anyhow, Context as _};
-use axum::TypedHeader;
+use authorization::authorize_access_to_language_model;
use axum::{
body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::post,
- Extension, Json, Router,
+ Extension, Json, Router, TypedHeader,
};
+use chrono::{DateTime, Duration, Utc};
+use db::{ActiveUserCount, LlmDatabase};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;
+use tokio::sync::RwLock;
+use util::ResultExt;
pub use token::*;
@@ -28,8 +29,11 @@ pub struct LlmState {
pub executor: Executor,
pub db: Option<Arc<LlmDatabase>>,
pub http_client: IsahcHttpClient,
+ active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
+const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
+
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
// TODO: This is temporary until we have the LLM database stood up.
@@ -44,7 +48,8 @@ impl LlmState {
let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
- let db = LlmDatabase::new(db_options, executor.clone()).await?;
+ let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
+ db.initialize().await?;
Some(Arc::new(db))
} else {
@@ -57,15 +62,41 @@ impl LlmState {
.build()
.context("failed to construct http client")?;
+ let initial_active_user_count = if let Some(db) = &db {
+ Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
+ } else {
+ None
+ };
+
let this = Self {
config,
executor,
db,
http_client,
+ active_user_count: RwLock::new(initial_active_user_count),
};
Ok(Arc::new(this))
}
+
+ pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
+ let now = Utc::now();
+
+ if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
+ if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
+ return Ok(*count);
+ }
+ }
+
+ if let Some(db) = &self.db {
+ let mut cache = self.active_user_count.write().await;
+ let new_count = db.get_active_user_count(now).await?;
+ *cache = Some((now, new_count));
+ Ok(new_count)
+ } else {
+ Ok(ActiveUserCount::default())
+ }
+ }
}
pub fn routes() -> Router<(), Body> {
@@ -122,14 +153,22 @@ async fn perform_completion(
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
+ let model = normalize_model_name(params.provider, params.model);
+
authorize_access_to_language_model(
&state.config,
&claims,
country_code_header.map(|header| header.to_string()),
params.provider,
- ¶ms.model,
+ &model,
)?;
+ let user_id = claims.user_id as i32;
+
+ if state.db.is_some() {
+ check_usage_limit(&state, params.provider, &model, &claims).await?;
+ }
+
match params.provider {
LanguageModelProvider::Anthropic => {
let api_key = state
@@ -160,9 +199,31 @@ async fn perform_completion(
)
.await?;
- let stream = chunks.map(|event| {
+ let mut recorder = state.db.clone().map(|db| UsageRecorder {
+ db,
+ executor: state.executor.clone(),
+ user_id,
+ provider: params.provider,
+ model,
+ token_count: 0,
+ });
+
+ let stream = chunks.map(move |event| {
let mut buffer = Vec::new();
event.map(|chunk| {
+ match &chunk {
+ anthropic::Event::MessageStart {
+ message: anthropic::Response { usage, .. },
+ }
+ | anthropic::Event::MessageDelta { usage, .. } => {
+ if let Some(recorder) = &mut recorder {
+ recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
+ recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
+ }
+ }
+ _ => {}
+ }
+
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
@@ -259,3 +320,102 @@ async fn perform_completion(
}
}
}
+
+fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
+ match provider {
+ LanguageModelProvider::Anthropic => {
+ for prefix in &[
+ "claude-3-5-sonnet",
+ "claude-3-haiku",
+ "claude-3-opus",
+ "claude-3-sonnet",
+ ] {
+ if name.starts_with(prefix) {
+ return prefix.to_string();
+ }
+ }
+ }
+ LanguageModelProvider::OpenAi => {}
+ LanguageModelProvider::Google => {}
+ LanguageModelProvider::Zed => {}
+ }
+
+ name
+}
+
+async fn check_usage_limit(
+ state: &Arc<LlmState>,
+ provider: LanguageModelProvider,
+ model_name: &str,
+ claims: &LlmTokenClaims,
+) -> Result<()> {
+ let db = state
+ .db
+ .as_ref()
+ .ok_or_else(|| anyhow!("LLM database not configured"))?;
+ let model = db.model(provider, model_name)?;
+ let usage = db
+ .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
+ .await?;
+
+ let active_users = state.get_active_user_count().await?;
+
+ let per_user_max_requests_per_minute =
+ model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
+ let per_user_max_tokens_per_minute =
+ model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
+ let per_user_max_tokens_per_day =
+ model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
+
+ let checks = [
+ (
+ usage.requests_this_minute,
+ per_user_max_requests_per_minute,
+ "requests per minute",
+ ),
+ (
+ usage.tokens_this_minute,
+ per_user_max_tokens_per_minute,
+ "tokens per minute",
+ ),
+ (
+ usage.tokens_this_day,
+ per_user_max_tokens_per_day,
+ "tokens per day",
+ ),
+ ];
+
+ for (usage, limit, resource) in checks {
+ if usage > limit {
+ return Err(Error::http(
+ StatusCode::TOO_MANY_REQUESTS,
+ format!("Rate limit exceeded. Maximum {} reached.", resource),
+ ));
+ }
+ }
+
+ Ok(())
+}
+struct UsageRecorder {
+ db: Arc<LlmDatabase>,
+ executor: Executor,
+ user_id: i32,
+ provider: LanguageModelProvider,
+ model: String,
+ token_count: usize,
+}
+
+impl Drop for UsageRecorder {
+ fn drop(&mut self) {
+ let db = self.db.clone();
+ let user_id = self.user_id;
+ let provider = self.provider;
+ let model = std::mem::take(&mut self.model);
+ let token_count = self.token_count;
+ self.executor.spawn_detached(async move {
+ db.record_usage(user_id, provider, &model, token_count, Utc::now())
+ .await
+ .log_err();
+ })
+ }
+}
@@ -1,20 +1,26 @@
mod ids;
mod queries;
+mod seed;
mod tables;
#[cfg(test)]
mod tests;
+use collections::HashMap;
pub use ids::*;
+use rpc::LanguageModelProvider;
+pub use seed::*;
pub use tables::*;
#[cfg(test)]
pub use tests::TestLlmDb;
+use usage_measure::UsageMeasure;
use std::future::Future;
use std::sync::Arc;
use anyhow::anyhow;
+pub use queries::usages::ActiveUserCount;
use sea_orm::prelude::*;
pub use sea_orm::ConnectOptions;
use sea_orm::{
@@ -31,6 +37,9 @@ pub struct LlmDatabase {
pool: DatabaseConnection,
#[allow(unused)]
executor: Executor,
+ provider_ids: HashMap<LanguageModelProvider, ProviderId>,
+ models: HashMap<(LanguageModelProvider, String), model::Model>,
+ usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
#[cfg(test)]
runtime: Option<tokio::runtime::Runtime>,
}
@@ -43,11 +52,28 @@ impl LlmDatabase {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
executor,
+ provider_ids: HashMap::default(),
+ models: HashMap::default(),
+ usage_measure_ids: HashMap::default(),
#[cfg(test)]
runtime: None,
})
}
+ pub async fn initialize(&mut self) -> Result<()> {
+ self.initialize_providers().await?;
+ self.initialize_models().await?;
+ self.initialize_usage_measures().await?;
+ Ok(())
+ }
+
+ pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
+ Ok(self
+ .models
+ .get(&(provider, name.to_string()))
+ .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
+ }
+
pub fn options(&self) -> &ConnectOptions {
&self.options
}
@@ -6,3 +6,4 @@ use crate::id_type;
id_type!(ModelId);
id_type!(ProviderId);
id_type!(UsageId);
+id_type!(UsageMeasureId);
@@ -1,66 +1,115 @@
-use sea_orm::sea_query::OnConflict;
+use super::*;
use sea_orm::QueryOrder;
+use std::str::FromStr;
+use strum::IntoEnumIterator as _;
-use super::*;
+pub struct ModelRateLimits {
+ pub max_requests_per_minute: i32,
+ pub max_tokens_per_minute: i32,
+ pub max_tokens_per_day: i32,
+}
impl LlmDatabase {
- pub async fn initialize_providers(&self) -> Result<()> {
- self.transaction(|tx| async move {
- let providers_and_models = vec![
- ("anthropic", "claude-3-5-sonnet"),
- ("anthropic", "claude-3-opus"),
- ("anthropic", "claude-3-sonnet"),
- ("anthropic", "claude-3-haiku"),
- ];
+ pub async fn initialize_providers(&mut self) -> Result<()> {
+ self.provider_ids = self
+ .transaction(|tx| async move {
+ let existing_providers = provider::Entity::find().all(&*tx).await?;
- for (provider_name, model_name) in providers_and_models {
- let insert_provider = provider::Entity::insert(provider::ActiveModel {
- name: ActiveValue::set(provider_name.to_owned()),
- ..Default::default()
- })
- .on_conflict(
- OnConflict::columns([provider::Column::Name])
- .update_column(provider::Column::Name)
- .to_owned(),
- );
+ let mut new_providers = LanguageModelProvider::iter()
+ .filter(|provider| {
+ !existing_providers
+ .iter()
+ .any(|p| p.name == provider.to_string())
+ })
+ .map(|provider| provider::ActiveModel {
+ name: ActiveValue::set(provider.to_string()),
+ ..Default::default()
+ })
+ .peekable();
- let provider = if tx.support_returning() {
- insert_provider.exec_with_returning(&*tx).await?
- } else {
- insert_provider.exec_without_returning(&*tx).await?;
- provider::Entity::find()
- .filter(provider::Column::Name.eq(provider_name))
- .one(&*tx)
- .await?
- .ok_or_else(|| anyhow!("failed to insert provider"))?
- };
+ if new_providers.peek().is_some() {
+ provider::Entity::insert_many(new_providers)
+ .exec(&*tx)
+ .await?;
+ }
- model::Entity::insert(model::ActiveModel {
- provider_id: ActiveValue::set(provider.id),
- name: ActiveValue::set(model_name.to_owned()),
- ..Default::default()
- })
- .on_conflict(
- OnConflict::columns([model::Column::ProviderId, model::Column::Name])
- .update_column(model::Column::Name)
- .to_owned(),
- )
- .exec_without_returning(&*tx)
- .await?;
- }
+ let all_providers: HashMap<_, _> = provider::Entity::find()
+ .all(&*tx)
+ .await?
+ .iter()
+ .filter_map(|provider| {
+ LanguageModelProvider::from_str(&provider.name)
+ .ok()
+ .map(|p| (p, provider.id))
+ })
+ .collect();
+ Ok(all_providers)
+ })
+ .await?;
+ Ok(())
+ }
+
+ pub async fn initialize_models(&mut self) -> Result<()> {
+ let all_provider_ids = &self.provider_ids;
+ self.models = self
+ .transaction(|tx| async move {
+ let all_models: HashMap<_, _> = model::Entity::find()
+ .all(&*tx)
+ .await?
+ .into_iter()
+ .filter_map(|model| {
+ let provider = all_provider_ids.iter().find_map(|(provider, id)| {
+ if *id == model.provider_id {
+ Some(provider)
+ } else {
+ None
+ }
+ })?;
+ Some(((*provider, model.name.clone()), model))
+ })
+ .collect();
+ Ok(all_models)
+ })
+ .await?;
+ Ok(())
+ }
+
+ pub async fn insert_models(
+ &mut self,
+ models: &[(LanguageModelProvider, String, ModelRateLimits)],
+ ) -> Result<()> {
+ let all_provider_ids = &self.provider_ids;
+ self.transaction(|tx| async move {
+ model::Entity::insert_many(models.into_iter().map(|(provider, name, rate_limits)| {
+ let provider_id = all_provider_ids[&provider];
+ model::ActiveModel {
+ provider_id: ActiveValue::set(provider_id),
+ name: ActiveValue::set(name.clone()),
+ max_requests_per_minute: ActiveValue::set(rate_limits.max_requests_per_minute),
+ max_tokens_per_minute: ActiveValue::set(rate_limits.max_tokens_per_minute),
+ max_tokens_per_day: ActiveValue::set(rate_limits.max_tokens_per_day),
+ ..Default::default()
+ }
+ }))
+ .exec_without_returning(&*tx)
+ .await?;
Ok(())
})
- .await
+ .await?;
+ self.initialize_models().await
}
/// Returns the list of LLM providers.
- pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
+ pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
self.transaction(|tx| async move {
Ok(provider::Entity::find()
.order_by_asc(provider::Column::Name)
.all(&*tx)
- .await?)
+ .await?
+ .into_iter()
+ .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
+ .collect())
})
.await
}
@@ -1,57 +1,318 @@
+use chrono::Duration;
use rpc::LanguageModelProvider;
+use sea_orm::QuerySelect;
+use std::{iter, str::FromStr};
+use strum::IntoEnumIterator as _;
use super::*;
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub struct Usage {
+ pub requests_this_minute: usize,
+ pub tokens_this_minute: usize,
+ pub tokens_this_day: usize,
+ pub tokens_this_month: usize,
+}
+
+#[derive(Clone, Copy, Debug, Default)]
+pub struct ActiveUserCount {
+ pub users_in_recent_minutes: usize,
+ pub users_in_recent_days: usize,
+}
+
impl LlmDatabase {
- pub async fn find_or_create_usage(
+ pub async fn initialize_usage_measures(&mut self) -> Result<()> {
+ let all_measures = self
+ .transaction(|tx| async move {
+ let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
+
+ let new_measures = UsageMeasure::iter()
+ .filter(|measure| {
+ !existing_measures
+ .iter()
+ .any(|m| m.name == measure.to_string())
+ })
+ .map(|measure| usage_measure::ActiveModel {
+ name: ActiveValue::set(measure.to_string()),
+ ..Default::default()
+ })
+ .collect::<Vec<_>>();
+
+ if !new_measures.is_empty() {
+ usage_measure::Entity::insert_many(new_measures)
+ .exec(&*tx)
+ .await?;
+ }
+
+ Ok(usage_measure::Entity::find().all(&*tx).await?)
+ })
+ .await?;
+
+ self.usage_measure_ids = all_measures
+ .into_iter()
+ .filter_map(|measure| {
+ UsageMeasure::from_str(&measure.name)
+ .ok()
+ .map(|um| (um, measure.id))
+ })
+ .collect();
+ Ok(())
+ }
+
+ pub async fn get_usage(
&self,
user_id: i32,
provider: LanguageModelProvider,
model_name: &str,
- ) -> Result<usage::Model> {
+ now: DateTimeUtc,
+ ) -> Result<Usage> {
self.transaction(|tx| async move {
- let provider_name = match provider {
- LanguageModelProvider::Anthropic => "anthropic",
- LanguageModelProvider::OpenAi => "open_ai",
- LanguageModelProvider::Google => "google",
- LanguageModelProvider::Zed => "zed",
- };
-
- let model = model::Entity::find()
- .inner_join(provider::Entity)
+ let model = self
+ .models
+ .get(&(provider, model_name.to_string()))
+ .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
+
+ let usages = usage::Entity::find()
.filter(
- provider::Column::Name
- .eq(provider_name)
- .and(model::Column::Name.eq(model_name)),
+ usage::Column::UserId
+ .eq(user_id)
+ .and(usage::Column::ModelId.eq(model.id)),
)
- .one(&*tx)
- .await?
- // TODO: Create the model, if one doesn't exist.
- .ok_or_else(|| anyhow!("no model found for {provider_name}:{model_name}"))?;
- let model_id = model.id;
+ .all(&*tx)
+ .await?;
+
+ let requests_this_minute =
+ self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
+ let tokens_this_minute =
+ self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
+ let tokens_this_day =
+ self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
+ let tokens_this_month =
+ self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMonth)?;
- let existing_usage = usage::Entity::find()
+ Ok(Usage {
+ requests_this_minute,
+ tokens_this_minute,
+ tokens_this_day,
+ tokens_this_month,
+ })
+ })
+ .await
+ }
+
+ pub async fn record_usage(
+ &self,
+ user_id: i32,
+ provider: LanguageModelProvider,
+ model_name: &str,
+ token_count: usize,
+ now: DateTimeUtc,
+ ) -> Result<()> {
+ self.transaction(|tx| async move {
+ let model = self.model(provider, model_name)?;
+
+ let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
- .and(usage::Column::ModelId.eq(model_id)),
+ .and(usage::Column::ModelId.eq(model.id)),
)
- .one(&*tx)
+ .all(&*tx)
.await?;
- if let Some(usage) = existing_usage {
- return Ok(usage);
- }
- let usage = usage::Entity::insert(usage::ActiveModel {
- user_id: ActiveValue::set(user_id),
- model_id: ActiveValue::set(model_id),
- ..Default::default()
- })
- .exec_with_returning(&*tx)
+ self.update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::RequestsPerMinute,
+ now,
+ 1,
+ &tx,
+ )
.await?;
+ self.update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::TokensPerMinute,
+ now,
+ token_count,
+ &tx,
+ )
+ .await?;
+ self.update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::TokensPerDay,
+ now,
+ token_count,
+ &tx,
+ )
+ .await?;
+ self.update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::TokensPerMonth,
+ now,
+ token_count,
+ &tx,
+ )
+ .await?;
+
+ Ok(())
+ })
+ .await
+ }
+
+ pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
+ self.transaction(|tx| async move {
+ let minute_since = now - Duration::minutes(5);
+ let day_since = now - Duration::days(5);
- Ok(usage)
+ let users_in_recent_minutes = usage::Entity::find()
+ .filter(usage::Column::Timestamp.gte(minute_since.naive_utc()))
+ .group_by(usage::Column::UserId)
+ .count(&*tx)
+ .await? as usize;
+
+ let users_in_recent_days = usage::Entity::find()
+ .filter(usage::Column::Timestamp.gte(day_since.naive_utc()))
+ .group_by(usage::Column::UserId)
+ .count(&*tx)
+ .await? as usize;
+
+ Ok(ActiveUserCount {
+ users_in_recent_minutes,
+ users_in_recent_days,
+ })
})
.await
}
+
+ #[allow(clippy::too_many_arguments)]
+ async fn update_usage_for_measure(
+ &self,
+ user_id: i32,
+ model_id: ModelId,
+ usages: &[usage::Model],
+ usage_measure: UsageMeasure,
+ now: DateTimeUtc,
+ usage_to_add: usize,
+ tx: &DatabaseTransaction,
+ ) -> Result<()> {
+ let now = now.naive_utc();
+ let measure_id = *self
+ .usage_measure_ids
+ .get(&usage_measure)
+ .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
+
+ let mut id = None;
+ let mut timestamp = now;
+ let mut buckets = vec![0_i64];
+
+ if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
+ id = Some(old_usage.id);
+ let (live_buckets, buckets_since) =
+ Self::get_live_buckets(old_usage, now, usage_measure);
+ if !live_buckets.is_empty() {
+ buckets.clear();
+ buckets.extend_from_slice(live_buckets);
+ buckets.extend(iter::repeat(0).take(buckets_since));
+ timestamp =
+ old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
+ }
+ }
+
+ *buckets.last_mut().unwrap() += usage_to_add as i64;
+
+ let mut model = usage::ActiveModel {
+ user_id: ActiveValue::set(user_id),
+ model_id: ActiveValue::set(model_id),
+ measure_id: ActiveValue::set(measure_id),
+ timestamp: ActiveValue::set(timestamp),
+ buckets: ActiveValue::set(buckets),
+ ..Default::default()
+ };
+
+ if let Some(id) = id {
+ model.id = ActiveValue::unchanged(id);
+ model.update(tx).await?;
+ } else {
+ usage::Entity::insert(model)
+ .exec_without_returning(tx)
+ .await?;
+ }
+
+ Ok(())
+ }
+
+ fn get_usage_for_measure(
+ &self,
+ usages: &[usage::Model],
+ now: DateTimeUtc,
+ usage_measure: UsageMeasure,
+ ) -> Result<usize> {
+ let now = now.naive_utc();
+ let measure_id = *self
+ .usage_measure_ids
+ .get(&usage_measure)
+ .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
+ let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
+ return Ok(0);
+ };
+
+ let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
+ Ok(live_buckets.iter().sum::<i64>() as _)
+ }
+
+ fn get_live_buckets(
+ usage: &usage::Model,
+ now: chrono::NaiveDateTime,
+ measure: UsageMeasure,
+ ) -> (&[i64], usize) {
+ let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
+ let buckets_since_usage =
+ seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
+ let buckets_since_usage = buckets_since_usage.ceil() as usize;
+ let mut live_buckets = &[] as &[i64];
+ if buckets_since_usage < measure.bucket_count() {
+ let expired_bucket_count =
+ (usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
+ live_buckets = &usage.buckets[expired_bucket_count..];
+ while live_buckets.first() == Some(&0) {
+ live_buckets = &live_buckets[1..];
+ }
+ }
+ (live_buckets, buckets_since_usage)
+ }
+}
+
+const MINUTE_BUCKET_COUNT: usize = 12;
+const DAY_BUCKET_COUNT: usize = 48;
+const MONTH_BUCKET_COUNT: usize = 30;
+
+impl UsageMeasure {
+ fn bucket_count(&self) -> usize {
+ match self {
+ UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
+ UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
+ UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
+ UsageMeasure::TokensPerMonth => MONTH_BUCKET_COUNT,
+ }
+ }
+
+ fn total_duration(&self) -> Duration {
+ match self {
+ UsageMeasure::RequestsPerMinute => Duration::minutes(1),
+ UsageMeasure::TokensPerMinute => Duration::minutes(1),
+ UsageMeasure::TokensPerDay => Duration::hours(24),
+ UsageMeasure::TokensPerMonth => Duration::days(30),
+ }
+ }
+
+ fn bucket_duration(&self) -> Duration {
+ self.total_duration() / self.bucket_count() as i32
+ }
}
@@ -0,0 +1,45 @@
+use super::*;
+use crate::{Config, Result};
+use queries::providers::ModelRateLimits;
+
+pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
+ db.insert_models(&[
+ (
+ LanguageModelProvider::Anthropic,
+ "claude-3-5-sonnet".into(),
+ ModelRateLimits {
+ max_requests_per_minute: 5,
+ max_tokens_per_minute: 20_000,
+ max_tokens_per_day: 300_000,
+ },
+ ),
+ (
+ LanguageModelProvider::Anthropic,
+ "claude-3-opus".into(),
+ ModelRateLimits {
+ max_requests_per_minute: 5,
+ max_tokens_per_minute: 10_000,
+ max_tokens_per_day: 300_000,
+ },
+ ),
+ (
+ LanguageModelProvider::Anthropic,
+ "claude-3-sonnet".into(),
+ ModelRateLimits {
+ max_requests_per_minute: 5,
+ max_tokens_per_minute: 20_000,
+ max_tokens_per_day: 300_000,
+ },
+ ),
+ (
+ LanguageModelProvider::Anthropic,
+ "claude-3-haiku".into(),
+ ModelRateLimits {
+ max_requests_per_minute: 5,
+ max_tokens_per_minute: 25_000,
+ max_tokens_per_day: 300_000,
+ },
+ ),
+ ])
+ .await
+}
@@ -1,3 +1,4 @@
pub mod model;
pub mod provider;
pub mod usage;
+pub mod usage_measure;
@@ -10,6 +10,9 @@ pub struct Model {
pub id: ModelId,
pub provider_id: ProviderId,
pub name: String,
+ pub max_requests_per_minute: i32,
+ pub max_tokens_per_minute: i32,
+ pub max_tokens_per_day: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -1,6 +1,5 @@
-use sea_orm::entity::prelude::*;
-
use crate::llm::db::ProviderId;
+use sea_orm::entity::prelude::*;
/// An LLM provider.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
@@ -1,24 +1,20 @@
+use crate::llm::db::{ModelId, UsageId, UsageMeasureId};
use sea_orm::entity::prelude::*;
-use crate::llm::db::ModelId;
-
/// An LLM usage record.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "usages")]
pub struct Model {
#[sea_orm(primary_key)]
- pub id: i32,
+ pub id: UsageId,
/// The ID of the Zed user.
///
/// Corresponds to the `users` table in the primary collab database.
pub user_id: i32,
pub model_id: ModelId,
- pub requests_this_minute: i32,
- pub tokens_this_minute: i64,
- pub requests_this_day: i32,
- pub tokens_this_day: i64,
- pub requests_this_month: i32,
- pub tokens_this_month: i64,
+ pub measure_id: UsageMeasureId,
+ pub timestamp: DateTime,
+ pub buckets: Vec<i64>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -29,6 +25,12 @@ pub enum Relation {
to = "super::model::Column::Id"
)]
Model,
+ #[sea_orm(
+ belongs_to = "super::usage_measure::Entity",
+ from = "Column::MeasureId",
+ to = "super::usage_measure::Column::Id"
+ )]
+ UsageMeasure,
}
impl Related<super::model::Entity> for Entity {
@@ -37,4 +39,10 @@ impl Related<super::model::Entity> for Entity {
}
}
+impl Related<super::usage_measure::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::UsageMeasure.def()
+ }
+}
+
impl ActiveModelBehavior for ActiveModel {}
@@ -0,0 +1,35 @@
+use crate::llm::db::UsageMeasureId;
+use sea_orm::entity::prelude::*;
+
+#[derive(
+ Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
+)]
+#[strum(serialize_all = "snake_case")]
+pub enum UsageMeasure {
+ RequestsPerMinute,
+ TokensPerMinute,
+ TokensPerDay,
+ TokensPerMonth,
+}
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "usage_measures")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: UsageMeasureId,
+ pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(has_many = "super::usage::Entity")]
+ Usages,
+}
+
+impl Related<super::usage::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Usages.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -6,7 +6,6 @@ use parking_lot::Mutex;
use rand::prelude::*;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
-use std::sync::Arc;
use std::time::Duration;
use crate::migrations::run_database_migrations;
@@ -14,47 +13,11 @@ use crate::migrations::run_database_migrations;
use super::*;
pub struct TestLlmDb {
- pub db: Option<Arc<LlmDatabase>>,
+ pub db: Option<LlmDatabase>,
pub connection: Option<sqlx::AnyConnection>,
}
impl TestLlmDb {
- pub fn sqlite(background: BackgroundExecutor) -> Self {
- let url = "sqlite::memory:";
- let runtime = tokio::runtime::Builder::new_current_thread()
- .enable_io()
- .enable_time()
- .build()
- .unwrap();
-
- let mut db = runtime.block_on(async {
- let mut options = ConnectOptions::new(url);
- options.max_connections(5);
- let db = LlmDatabase::new(options, Executor::Deterministic(background))
- .await
- .unwrap();
- let sql = include_str!(concat!(
- env!("CARGO_MANIFEST_DIR"),
- "/migrations_llm.sqlite/20240806182921_test_schema.sql"
- ));
- db.pool
- .execute(sea_orm::Statement::from_string(
- db.pool.get_database_backend(),
- sql,
- ))
- .await
- .unwrap();
- db
- });
-
- db.runtime = Some(runtime);
-
- Self {
- db: Some(Arc::new(db)),
- connection: None,
- }
- }
-
pub fn postgres(background: BackgroundExecutor) -> Self {
static LOCK: Mutex<()> = Mutex::new(());
@@ -91,29 +54,26 @@ impl TestLlmDb {
db.runtime = Some(runtime);
Self {
- db: Some(Arc::new(db)),
+ db: Some(db),
connection: None,
}
}
- pub fn db(&self) -> &Arc<LlmDatabase> {
- self.db.as_ref().unwrap()
+ pub fn db(&mut self) -> &mut LlmDatabase {
+ self.db.as_mut().unwrap()
}
}
#[macro_export]
-macro_rules! test_both_llm_dbs {
- ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
- #[cfg(target_os = "macos")]
+macro_rules! test_llm_db {
+ ($test_name:ident, $postgres_test_name:ident) => {
#[gpui::test]
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
- let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
- $test_name(test_db.db()).await;
- }
+ if !cfg!(target_os = "macos") {
+ return;
+ }
- #[gpui::test]
- async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
- let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
+ let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
$test_name(test_db.db()).await;
}
};
@@ -1,17 +1,15 @@
-use std::sync::Arc;
-
use pretty_assertions::assert_eq;
+use rpc::LanguageModelProvider;
use crate::llm::db::LlmDatabase;
-use crate::test_both_llm_dbs;
+use crate::test_llm_db;
-test_both_llm_dbs!(
+test_llm_db!(
test_initialize_providers,
- test_initialize_providers_postgres,
- test_initialize_providers_sqlite
+ test_initialize_providers_postgres
);
-async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
+async fn test_initialize_providers(db: &mut LlmDatabase) {
let initial_providers = db.list_providers().await.unwrap();
assert_eq!(initial_providers, vec![]);
@@ -22,9 +20,13 @@ async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
let providers = db.list_providers().await.unwrap();
- let provider_names = providers
- .into_iter()
- .map(|provider| provider.name)
- .collect::<Vec<_>>();
- assert_eq!(provider_names, vec!["anthropic".to_string()]);
+ assert_eq!(
+ providers,
+ &[
+ LanguageModelProvider::Anthropic,
+ LanguageModelProvider::Google,
+ LanguageModelProvider::OpenAi,
+ LanguageModelProvider::Zed
+ ]
+ )
}
@@ -1,24 +1,120 @@
-use std::sync::Arc;
-
+use crate::{
+ llm::db::{queries::providers::ModelRateLimits, queries::usages::Usage, LlmDatabase},
+ test_llm_db,
+};
+use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
-use crate::llm::db::LlmDatabase;
-use crate::test_both_llm_dbs;
+test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
+
+async fn test_tracking_usage(db: &mut LlmDatabase) {
+ let provider = LanguageModelProvider::Anthropic;
+ let model = "claude-3-5-sonnet";
+
+ db.initialize().await.unwrap();
+ db.insert_models(&[(
+ provider,
+ model.to_string(),
+ ModelRateLimits {
+ max_requests_per_minute: 5,
+ max_tokens_per_minute: 10_000,
+ max_tokens_per_day: 50_000,
+ },
+ )])
+ .await
+ .unwrap();
+
+ let t0 = Utc::now();
+ let user_id = 123;
+
+ let now = t0;
+ db.record_usage(user_id, provider, model, 1000, now)
+ .await
+ .unwrap();
+
+ let now = t0 + Duration::seconds(10);
+ db.record_usage(user_id, provider, model, 2000, now)
+ .await
+ .unwrap();
+
+ let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+ assert_eq!(
+ usage,
+ Usage {
+ requests_this_minute: 2,
+ tokens_this_minute: 3000,
+ tokens_this_day: 3000,
+ tokens_this_month: 3000,
+ }
+ );
+
+ let now = t0 + Duration::seconds(60);
+ let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+ assert_eq!(
+ usage,
+ Usage {
+ requests_this_minute: 1,
+ tokens_this_minute: 2000,
+ tokens_this_day: 3000,
+ tokens_this_month: 3000,
+ }
+ );
-test_both_llm_dbs!(
- test_find_or_create_usage,
- test_find_or_create_usage_postgres,
- test_find_or_create_usage_sqlite
-);
+ let now = t0 + Duration::seconds(60);
+ db.record_usage(user_id, provider, model, 3000, now)
+ .await
+ .unwrap();
-async fn test_find_or_create_usage(db: &Arc<LlmDatabase>) {
- db.initialize_providers().await.unwrap();
+ let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+ assert_eq!(
+ usage,
+ Usage {
+ requests_this_minute: 2,
+ tokens_this_minute: 5000,
+ tokens_this_day: 6000,
+ tokens_this_month: 6000,
+ }
+ );
- let usage = db
- .find_or_create_usage(123, LanguageModelProvider::Anthropic, "claude-3-5-sonnet")
+ let t1 = t0 + Duration::hours(24);
+ let now = t1;
+ let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+ assert_eq!(
+ usage,
+ Usage {
+ requests_this_minute: 0,
+ tokens_this_minute: 0,
+ tokens_this_day: 5000,
+ tokens_this_month: 6000,
+ }
+ );
+
+ db.record_usage(user_id, provider, model, 4000, now)
.await
.unwrap();
- assert_eq!(usage.user_id, 123);
+ let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+ assert_eq!(
+ usage,
+ Usage {
+ requests_this_minute: 1,
+ tokens_this_minute: 4000,
+ tokens_this_day: 9000,
+ tokens_this_month: 10000,
+ }
+ );
+
+ let t2 = t0 + Duration::days(30);
+ let now = t2;
+ let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
+ assert_eq!(
+ usage,
+ Usage {
+ requests_this_minute: 0,
+ tokens_this_minute: 0,
+ tokens_this_day: 0,
+ tokens_this_month: 9000,
+ }
+ );
}
@@ -52,10 +52,18 @@ async fn main() -> Result<()> {
Some("seed") => {
let config = envy::from_env::<Config>().expect("error loading config");
let db_options = db::ConnectOptions::new(config.database_url.clone());
+
let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
- collab::seed::seed(&config, &db, true).await?;
+ collab::seed::seed(&config, &db, false).await?;
+
+ if let Some(llm_database_url) = config.llm_database_url.clone() {
+ let db_options = db::ConnectOptions::new(llm_database_url);
+ let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
+ db.initialize().await?;
+ collab::llm::db::seed_database(&config, &mut db, true).await?;
+ }
}
Some("serve") => {
let mode = match args.next().as_deref() {
@@ -1,9 +1,13 @@
use serde::{Deserialize, Serialize};
+use strum::{Display, EnumIter, EnumString};
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
-#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+#[derive(
+ Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
+)]
#[serde(rename_all = "snake_case")]
+#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
Anthropic,
OpenAi,