diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 8e66a1c6ba6559b6f86f0ec187e3b8713538a693..e7d489e837dc0eff8b3da8a3560b3c73d51bff1a 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -18,7 +18,7 @@ use axum::{ Extension, Json, Router, TypedHeader, }; use chrono::{DateTime, Duration, Utc}; -use db::{ActiveUserCount, LlmDatabase}; +use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; use futures::{Stream, StreamExt as _}; use http_client::IsahcHttpClient; use rpc::{ @@ -29,7 +29,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use telemetry::{report_llm_usage, LlmUsageEventRow}; +use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow}; use tokio::sync::RwLock; use util::ResultExt; @@ -401,38 +401,75 @@ async fn check_usage_limit( let active_users = state.get_active_user_count().await?; + let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1); + let users_in_recent_days = active_users.users_in_recent_days.max(1); + let per_user_max_requests_per_minute = - model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1); + model.max_requests_per_minute as usize / users_in_recent_minutes; let per_user_max_tokens_per_minute = - model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1); - let per_user_max_tokens_per_day = - model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1); + model.max_tokens_per_minute as usize / users_in_recent_minutes; + let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days; let checks = [ ( usage.requests_this_minute, per_user_max_requests_per_minute, - "requests per minute", + UsageMeasure::RequestsPerMinute, ), ( usage.tokens_this_minute, per_user_max_tokens_per_minute, - "tokens per minute", + UsageMeasure::TokensPerMinute, ), ( usage.tokens_this_day, per_user_max_tokens_per_day, - "tokens per day", + UsageMeasure::TokensPerDay, ), ]; - for (usage, limit, resource) in checks { + for (used, limit, usage_measure) in checks { // Temporarily bypass rate-limiting for staff members. if claims.is_staff { continue; } - if usage > limit { + if used > limit { + let resource = match usage_measure { + UsageMeasure::RequestsPerMinute => "requests_per_minute", + UsageMeasure::TokensPerMinute => "tokens_per_minute", + UsageMeasure::TokensPerDay => "tokens_per_day", + _ => "", + }; + + if let Some(client) = state.clickhouse_client.as_ref() { + report_llm_rate_limit( + client, + LlmRateLimitEventRow { + time: Utc::now().timestamp_millis(), + user_id: claims.user_id as i32, + is_staff: claims.is_staff, + plan: match claims.plan { + Plan::Free => "free".to_string(), + Plan::ZedPro => "zed_pro".to_string(), + }, + model: model.name.clone(), + provider: provider.to_string(), + usage_measure: resource.to_string(), + requests_this_minute: usage.requests_this_minute as u64, + tokens_this_minute: usage.tokens_this_minute as u64, + tokens_this_day: usage.tokens_this_day as u64, + users_in_recent_minutes: users_in_recent_minutes as u64, + users_in_recent_days: users_in_recent_days as u64, + max_requests_per_minute: per_user_max_requests_per_minute as u64, + max_tokens_per_minute: per_user_max_tokens_per_minute as u64, + max_tokens_per_day: per_user_max_tokens_per_day as u64, + }, + ) + .await + .log_err(); + } + return Err(Error::http( StatusCode::TOO_MANY_REQUESTS, format!("Rate limit exceeded. Maximum {} reached.", resource), diff --git a/crates/collab/src/llm/telemetry.rs b/crates/collab/src/llm/telemetry.rs index f8d0cf4aacaefc7add61658055387af2e7d322fd..1cfa18e69d546c10b03f12fed7a951e0f2e187dd 100644 --- a/crates/collab/src/llm/telemetry.rs +++ b/crates/collab/src/llm/telemetry.rs @@ -19,9 +19,38 @@ pub struct LlmUsageEventRow { pub spending_this_month: u64, } +#[derive(Serialize, Debug, clickhouse::Row)] +pub struct LlmRateLimitEventRow { + pub time: i64, + pub user_id: i32, + pub is_staff: bool, + pub plan: String, + pub model: String, + pub provider: String, + pub usage_measure: String, + pub requests_this_minute: u64, + pub tokens_this_minute: u64, + pub tokens_this_day: u64, + pub users_in_recent_minutes: u64, + pub users_in_recent_days: u64, + pub max_requests_per_minute: u64, + pub max_tokens_per_minute: u64, + pub max_tokens_per_day: u64, +} + pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> { let mut insert = client.insert("llm_usage_events")?; insert.write(&row).await?; insert.end().await?; Ok(()) } + +pub async fn report_llm_rate_limit( + client: &clickhouse::Client, + row: LlmRateLimitEventRow, +) -> Result<()> { + let mut insert = client.insert("llm_rate_limits")?; + insert.write(&row).await?; + insert.end().await?; + Ok(()) +}