diff --git a/Cargo.lock b/Cargo.lock index 402953d95fb1bf672992c9d3da9a631c53f02e38..52bf0e405989934ddb53fd2af739371e5ae33e74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,6 +223,7 @@ name = "anthropic" version = "0.1.0" dependencies = [ "anyhow", + "chrono", "futures 0.3.30", "http_client", "isahc", @@ -232,6 +233,7 @@ dependencies = [ "strum", "thiserror", "tokio", + "util", ] [[package]] diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 4628d3db809cf8f33672bfb37ad78ca723265aaf..9e48ad0e57d81d1434d3e872e84edcab7f233900 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -17,6 +17,7 @@ path = "src/anthropic.rs" [dependencies] anyhow.workspace = true +chrono.workspace = true futures.workspace = true http_client.workspace = true isahc.workspace = true @@ -25,6 +26,7 @@ serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true +util.workspace = true [dev-dependencies] tokio.workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 0339f65390c3ad3379c0b70ee25bfbf7febcf98a..38b4f5466c32c96de2ae1a0d5d31900ea5266c81 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,14 +1,17 @@ mod supported_countries; use anyhow::{anyhow, Context, Result}; +use chrono::{DateTime, Utc}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; +use isahc::http::{HeaderMap, HeaderValue}; use serde::{Deserialize, Serialize}; use std::time::Duration; use std::{pin::Pin, str::FromStr}; use strum::{EnumIter, EnumString}; use thiserror::Error; +use util::ResultExt as _; pub use supported_countries::*; @@ -195,6 +198,66 @@ pub async fn stream_completion( request: Request, low_speed_timeout: Option, ) -> Result>, AnthropicError> { + stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout) + .await + .map(|output| output.0) +} + +/// https://docs.anthropic.com/en/api/rate-limits#response-headers +#[derive(Debug)] +pub struct RateLimitInfo { + pub requests_limit: usize, + pub requests_remaining: usize, + pub requests_reset: DateTime, + pub tokens_limit: usize, + pub tokens_remaining: usize, + pub tokens_reset: DateTime, +} + +impl RateLimitInfo { + fn from_headers(headers: &HeaderMap) -> Result { + let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?; + let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?; + let tokens_remaining = + get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?; + let requests_remaining = + get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?; + let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?; + let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?; + let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc(); + let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc(); + + Ok(Self { + requests_limit, + tokens_limit, + requests_remaining, + tokens_remaining, + requests_reset, + tokens_reset, + }) + } +} + +fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> { + Ok(headers + .get(key) + .ok_or_else(|| anyhow!("missing header `{key}`"))? + .to_str()?) +} + +pub async fn stream_completion_with_rate_limit_info( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, + low_speed_timeout: Option, +) -> Result< + ( + BoxStream<'static, Result>, + Option, + ), + AnthropicError, +> { let request = StreamingRequest { base: request, stream: true, @@ -224,8 +287,9 @@ pub async fn stream_completion( .await .context("failed to send request to Anthropic")?; if response.status().is_success() { + let rate_limits = RateLimitInfo::from_headers(response.headers()); let reader = BufReader::new(response.into_body()); - Ok(reader + let stream = reader .lines() .filter_map(|line| async move { match line { @@ -239,7 +303,8 @@ pub async fn stream_completion( Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))), } }) - .boxed()) + .boxed(); + Ok((stream, rate_limits.log_err())) } else { let mut body = Vec::new(); response diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 9eb17ef976e6d7e906dcd4b5d5b0f254cd60805e..4ec8b70ac0cb11c8f5f82c9a81f1ee467d32907f 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -217,7 +217,7 @@ async fn perform_completion( _ => request.model, }; - let chunks = anthropic::stream_completion( + let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info( &state.http_client, anthropic::ANTHROPIC_API_URL, api_key, @@ -245,6 +245,18 @@ async fn perform_completion( anthropic::AnthropicError::Other(err) => Error::Internal(err), })?; + if let Some(rate_limit_info) = rate_limit_info { + tracing::info!( + target: "upstream rate limit", + provider = params.provider.to_string(), + model = model, + tokens_remaining = rate_limit_info.tokens_remaining, + requests_remaining = rate_limit_info.requests_remaining, + requests_reset = ?rate_limit_info.requests_reset, + tokens_reset = ?rate_limit_info.tokens_reset, + ); + } + chunks .map(move |event| { let chunk = event?; @@ -540,33 +552,74 @@ impl Drop for TokenCountingStream { .await .log_err(); - if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) { - report_llm_usage( - clickhouse_client, - LlmUsageEventRow { - time: Utc::now().timestamp_millis(), - user_id: claims.user_id as i32, - is_staff: claims.is_staff, - plan: match claims.plan { - Plan::Free => "free".to_string(), - Plan::ZedPro => "zed_pro".to_string(), + if let Some(usage) = usage { + tracing::info!( + target: "user usage", + user_id = claims.user_id, + login = claims.github_user_login, + authn.jti = claims.jti, + requests_this_minute = usage.requests_this_minute, + tokens_this_minute = usage.tokens_this_minute, + ); + + if let Some(clickhouse_client) = state.clickhouse_client.as_ref() { + report_llm_usage( + clickhouse_client, + LlmUsageEventRow { + time: Utc::now().timestamp_millis(), + user_id: claims.user_id as i32, + is_staff: claims.is_staff, + plan: match claims.plan { + Plan::Free => "free".to_string(), + Plan::ZedPro => "zed_pro".to_string(), + }, + model, + provider: provider.to_string(), + input_token_count: input_token_count as u64, + output_token_count: output_token_count as u64, + requests_this_minute: usage.requests_this_minute as u64, + tokens_this_minute: usage.tokens_this_minute as u64, + tokens_this_day: usage.tokens_this_day as u64, + input_tokens_this_month: usage.input_tokens_this_month as u64, + output_tokens_this_month: usage.output_tokens_this_month as u64, + spending_this_month: usage.spending_this_month as u64, + lifetime_spending: usage.lifetime_spending as u64, }, - model, - provider: provider.to_string(), - input_token_count: input_token_count as u64, - output_token_count: output_token_count as u64, - requests_this_minute: usage.requests_this_minute as u64, - tokens_this_minute: usage.tokens_this_minute as u64, - tokens_this_day: usage.tokens_this_day as u64, - input_tokens_this_month: usage.input_tokens_this_month as u64, - output_tokens_this_month: usage.output_tokens_this_month as u64, - spending_this_month: usage.spending_this_month as u64, - lifetime_spending: usage.lifetime_spending as u64, - }, - ) - .await - .log_err(); + ) + .await + .log_err(); + } } }) } } + +pub fn log_usage_periodically(state: Arc) { + state.executor.clone().spawn_detached(async move { + loop { + state + .executor + .sleep(std::time::Duration::from_secs(30)) + .await; + + let Some(usages) = state + .db + .get_application_wide_usages_by_model(Utc::now()) + .await + .log_err() + else { + continue; + }; + + for usage in usages { + tracing::info!( + target: "computed usage", + provider = usage.provider.to_string(), + model = usage.model, + requests_this_minute = usage.requests_this_minute, + tokens_this_minute = usage.tokens_this_minute, + ); + } + } + }) +} diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index adfd55088fdc5aa69e96744633b03bc4f70f5dd7..0bfbb4c1b1ba96eaac2e81164e462de19899cae4 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,5 +1,6 @@ use crate::db::UserId; use chrono::Duration; +use futures::StreamExt as _; use rpc::LanguageModelProvider; use sea_orm::QuerySelect; use std::{iter, str::FromStr}; @@ -18,6 +19,14 @@ pub struct Usage { pub lifetime_spending: usize, } +#[derive(Debug, PartialEq, Clone)] +pub struct ApplicationWideUsage { + pub provider: LanguageModelProvider, + pub model: String, + pub requests_this_minute: usize, + pub tokens_this_minute: usize, +} + #[derive(Clone, Copy, Debug, Default)] pub struct ActiveUserCount { pub users_in_recent_minutes: usize, @@ -63,6 +72,71 @@ impl LlmDatabase { Ok(()) } + pub async fn get_application_wide_usages_by_model( + &self, + now: DateTimeUtc, + ) -> Result> { + self.transaction(|tx| async move { + let past_minute = now - Duration::minutes(1); + let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute]; + let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute]; + + let mut results = Vec::new(); + for (provider, model) in self.models.keys().cloned() { + let mut usages = usage::Entity::find() + .filter( + usage::Column::Timestamp + .gte(past_minute.naive_utc()) + .and(usage::Column::IsStaff.eq(false)) + .and( + usage::Column::MeasureId + .eq(requests_per_minute) + .or(usage::Column::MeasureId.eq(tokens_per_minute)), + ), + ) + .stream(&*tx) + .await?; + + let mut requests_this_minute = 0; + let mut tokens_this_minute = 0; + while let Some(usage) = usages.next().await { + let usage = usage?; + if usage.measure_id == requests_per_minute { + requests_this_minute += Self::get_live_buckets( + &usage, + now.naive_utc(), + UsageMeasure::RequestsPerMinute, + ) + .0 + .iter() + .copied() + .sum::() as usize; + } else if usage.measure_id == tokens_per_minute { + tokens_this_minute += Self::get_live_buckets( + &usage, + now.naive_utc(), + UsageMeasure::TokensPerMinute, + ) + .0 + .iter() + .copied() + .sum::() as usize; + } + } + + results.push(ApplicationWideUsage { + provider, + model, + requests_this_minute, + tokens_this_minute, + }) + } + + Ok(results) + }) + .await + } + pub async fn get_usage( &self, user_id: UserId, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 5d4fc2abe06c915bba6cd03fa0625ebd2d52eed1..35a80b702e0c24cef48df24c7269ca8a1bdceaca 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -5,7 +5,7 @@ use axum::{ routing::get, Extension, Router, }; -use collab::llm::db::LlmDatabase; +use collab::llm::{db::LlmDatabase, log_usage_periodically}; use collab::migrations::run_database_migrations; use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode}; use collab::{ @@ -95,6 +95,8 @@ async fn main() -> Result<()> { let state = LlmState::new(config.clone(), Executor::Production).await?; + log_usage_periodically(state.clone()); + app = app .merge(collab::llm::routes()) .layer(Extension(state.clone()));