Add tracing needed for LLM rate limit dashboards (#16388)

Created by Max Brunsfeld and Marshall

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>

Change summary

Cargo.lock                                 |   2 
crates/anthropic/Cargo.toml                |   2 
crates/anthropic/src/anthropic.rs          |  69 +++++++++++++++
crates/collab/src/llm.rs                   | 105 ++++++++++++++++++-----
crates/collab/src/llm/db/queries/usages.rs |  74 ++++++++++++++++
crates/collab/src/main.rs                  |   4 
6 files changed, 227 insertions(+), 29 deletions(-)

Detailed changes

Cargo.lock

@@ -223,6 +223,7 @@ name = "anthropic"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "chrono",
  "futures 0.3.30",
  "http_client",
  "isahc",
@@ -232,6 +233,7 @@ dependencies = [
  "strum",
  "thiserror",
  "tokio",
+ "util",
 ]
 
 [[package]]

crates/anthropic/Cargo.toml

@@ -17,6 +17,7 @@ path = "src/anthropic.rs"
 
 [dependencies]
 anyhow.workspace = true
+chrono.workspace = true
 futures.workspace = true
 http_client.workspace = true
 isahc.workspace = true
@@ -25,6 +26,7 @@ serde.workspace = true
 serde_json.workspace = true
 strum.workspace = true
 thiserror.workspace = true
+util.workspace = true
 
 [dev-dependencies]
 tokio.workspace = true

crates/anthropic/src/anthropic.rs

@@ -1,14 +1,17 @@
 mod supported_countries;
 
 use anyhow::{anyhow, Context, Result};
+use chrono::{DateTime, Utc};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
+use isahc::http::{HeaderMap, HeaderValue};
 use serde::{Deserialize, Serialize};
 use std::time::Duration;
 use std::{pin::Pin, str::FromStr};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
+use util::ResultExt as _;
 
 pub use supported_countries::*;
 
@@ -195,6 +198,66 @@ pub async fn stream_completion(
     request: Request,
     low_speed_timeout: Option<Duration>,
 ) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
+    stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
+        .await
+        .map(|output| output.0)
+}
+
+/// https://docs.anthropic.com/en/api/rate-limits#response-headers
+#[derive(Debug)]
+pub struct RateLimitInfo {
+    pub requests_limit: usize,
+    pub requests_remaining: usize,
+    pub requests_reset: DateTime<Utc>,
+    pub tokens_limit: usize,
+    pub tokens_remaining: usize,
+    pub tokens_reset: DateTime<Utc>,
+}
+
+impl RateLimitInfo {
+    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
+        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
+        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
+        let tokens_remaining =
+            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
+        let requests_remaining =
+            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
+        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
+        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
+        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
+        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
+
+        Ok(Self {
+            requests_limit,
+            tokens_limit,
+            requests_remaining,
+            tokens_remaining,
+            requests_reset,
+            tokens_reset,
+        })
+    }
+}
+
+fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
+    Ok(headers
+        .get(key)
+        .ok_or_else(|| anyhow!("missing header `{key}`"))?
+        .to_str()?)
+}
+
+pub async fn stream_completion_with_rate_limit_info(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+    low_speed_timeout: Option<Duration>,
+) -> Result<
+    (
+        BoxStream<'static, Result<Event, AnthropicError>>,
+        Option<RateLimitInfo>,
+    ),
+    AnthropicError,
+> {
     let request = StreamingRequest {
         base: request,
         stream: true,
@@ -224,8 +287,9 @@ pub async fn stream_completion(
         .await
         .context("failed to send request to Anthropic")?;
     if response.status().is_success() {
+        let rate_limits = RateLimitInfo::from_headers(response.headers());
         let reader = BufReader::new(response.into_body());
-        Ok(reader
+        let stream = reader
             .lines()
             .filter_map(|line| async move {
                 match line {
@@ -239,7 +303,8 @@ pub async fn stream_completion(
                     Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
                 }
             })
-            .boxed())
+            .boxed();
+        Ok((stream, rate_limits.log_err()))
     } else {
         let mut body = Vec::new();
         response
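
For orientation, here is a minimal sketch (not part of this diff) of how a caller in an async context might consume the new entry point; http_client, api_key, and request are assumed to be in scope:

    use futures::StreamExt as _;

    // Hypothetical caller; the bindings used here are placeholders.
    let (mut events, rate_limits) = anthropic::stream_completion_with_rate_limit_info(
        http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        request,
        None, // low_speed_timeout
    )
    .await?;

    // Header parsing is fallible, so the info arrives as an Option;
    // log_err() has already recorded any parse failure.
    if let Some(limits) = rate_limits {
        eprintln!(
            "requests {}/{} (reset {}), tokens {}/{} (reset {})",
            limits.requests_remaining,
            limits.requests_limit,
            limits.requests_reset,
            limits.tokens_remaining,
            limits.tokens_limit,
            limits.tokens_reset,
        );
    }

    // The event stream itself behaves exactly as with stream_completion.
    while let Some(event) = events.next().await {
        let _event = event?;
    }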

crates/collab/src/llm.rs

@@ -217,7 +217,7 @@ async fn perform_completion(
                 _ => request.model,
             };
 
-            let chunks = anthropic::stream_completion(
+            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                 &state.http_client,
                 anthropic::ANTHROPIC_API_URL,
                 api_key,
@@ -245,6 +245,18 @@ async fn perform_completion(
                 anthropic::AnthropicError::Other(err) => Error::Internal(err),
             })?;
 
+            if let Some(rate_limit_info) = rate_limit_info {
+                tracing::info!(
+                    target: "upstream rate limit",
+                    provider = params.provider.to_string(),
+                    model = model,
+                    tokens_remaining = rate_limit_info.tokens_remaining,
+                    requests_remaining = rate_limit_info.requests_remaining,
+                    requests_reset = ?rate_limit_info.requests_reset,
+                    tokens_reset = ?rate_limit_info.tokens_reset,
+                );
+            }
+
             chunks
                 .map(move |event| {
                     let chunk = event?;
@@ -540,33 +552,74 @@ impl<S> Drop for TokenCountingStream<S> {
                 .await
                 .log_err();
 
-            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
-                report_llm_usage(
-                    clickhouse_client,
-                    LlmUsageEventRow {
-                        time: Utc::now().timestamp_millis(),
-                        user_id: claims.user_id as i32,
-                        is_staff: claims.is_staff,
-                        plan: match claims.plan {
-                            Plan::Free => "free".to_string(),
-                            Plan::ZedPro => "zed_pro".to_string(),
+            if let Some(usage) = usage {
+                tracing::info!(
+                    target: "user usage",
+                    user_id = claims.user_id,
+                    login = claims.github_user_login,
+                    authn.jti = claims.jti,
+                    requests_this_minute = usage.requests_this_minute,
+                    tokens_this_minute = usage.tokens_this_minute,
+                );
+
+                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
+                    report_llm_usage(
+                        clickhouse_client,
+                        LlmUsageEventRow {
+                            time: Utc::now().timestamp_millis(),
+                            user_id: claims.user_id as i32,
+                            is_staff: claims.is_staff,
+                            plan: match claims.plan {
+                                Plan::Free => "free".to_string(),
+                                Plan::ZedPro => "zed_pro".to_string(),
+                            },
+                            model,
+                            provider: provider.to_string(),
+                            input_token_count: input_token_count as u64,
+                            output_token_count: output_token_count as u64,
+                            requests_this_minute: usage.requests_this_minute as u64,
+                            tokens_this_minute: usage.tokens_this_minute as u64,
+                            tokens_this_day: usage.tokens_this_day as u64,
+                            input_tokens_this_month: usage.input_tokens_this_month as u64,
+                            output_tokens_this_month: usage.output_tokens_this_month as u64,
+                            spending_this_month: usage.spending_this_month as u64,
+                            lifetime_spending: usage.lifetime_spending as u64,
                         },
-                        model,
-                        provider: provider.to_string(),
-                        input_token_count: input_token_count as u64,
-                        output_token_count: output_token_count as u64,
-                        requests_this_minute: usage.requests_this_minute as u64,
-                        tokens_this_minute: usage.tokens_this_minute as u64,
-                        tokens_this_day: usage.tokens_this_day as u64,
-                        input_tokens_this_month: usage.input_tokens_this_month as u64,
-                        output_tokens_this_month: usage.output_tokens_this_month as u64,
-                        spending_this_month: usage.spending_this_month as u64,
-                        lifetime_spending: usage.lifetime_spending as u64,
-                    },
-                )
-                .await
-                .log_err();
+                    )
+                    .await
+                    .log_err();
+                }
             }
         })
     }
 }
+
+pub fn log_usage_periodically(state: Arc<LlmState>) {
+    state.executor.clone().spawn_detached(async move {
+        loop {
+            state
+                .executor
+                .sleep(std::time::Duration::from_secs(30))
+                .await;
+
+            let Some(usages) = state
+                .db
+                .get_application_wide_usages_by_model(Utc::now())
+                .await
+                .log_err()
+            else {
+                continue;
+            };
+
+            for usage in usages {
+                tracing::info!(
+                    target: "computed usage",
+                    provider = usage.provider.to_string(),
+                    model = usage.model,
+                    requests_this_minute = usage.requests_this_minute,
+                    tokens_this_minute = usage.tokens_this_minute,
+                );
+            }
+        }
+    })
+}
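
The three new log streams are keyed by their tracing target strings ("upstream rate limit", "user usage", "computed usage"), which is what lets a dashboard pipeline select them without touching other collab logs. A hedged sketch, not part of this diff, of isolating exactly those targets with tracing-subscriber (collab configures its real subscriber elsewhere):

    use tracing::Level;
    use tracing_subscriber::filter::{LevelFilter, Targets};
    use tracing_subscriber::prelude::*;

    fn main() {
        // Route only the targets introduced in this PR; silence everything else.
        let dashboard_targets = Targets::new()
            .with_target("upstream rate limit", Level::INFO)
            .with_target("user usage", Level::INFO)
            .with_target("computed usage", Level::INFO)
            .with_default(LevelFilter::OFF);

        tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().with_filter(dashboard_targets))
            .init();

        // Events like the ones added above, e.g.
        //   tracing::info!(target: "computed usage", model = "...", tokens_this_minute = 0);
        // now reach the formatter; all other targets are dropped.
    }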

crates/collab/src/llm/db/queries/usages.rs

@@ -1,5 +1,6 @@
 use crate::db::UserId;
 use chrono::Duration;
+use futures::StreamExt as _;
 use rpc::LanguageModelProvider;
 use sea_orm::QuerySelect;
 use std::{iter, str::FromStr};
@@ -18,6 +19,14 @@ pub struct Usage {
     pub lifetime_spending: usize,
 }
 
+#[derive(Debug, PartialEq, Clone)]
+pub struct ApplicationWideUsage {
+    pub provider: LanguageModelProvider,
+    pub model: String,
+    pub requests_this_minute: usize,
+    pub tokens_this_minute: usize,
+}
+
 #[derive(Clone, Copy, Debug, Default)]
 pub struct ActiveUserCount {
     pub users_in_recent_minutes: usize,
@@ -63,6 +72,71 @@ impl LlmDatabase {
         Ok(())
     }
 
+    pub async fn get_application_wide_usages_by_model(
+        &self,
+        now: DateTimeUtc,
+    ) -> Result<Vec<ApplicationWideUsage>> {
+        self.transaction(|tx| async move {
+            let past_minute = now - Duration::minutes(1);
+            let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
+            let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
+
+            let mut results = Vec::new();
+            for (provider, model) in self.models.keys().cloned() {
+                let mut usages = usage::Entity::find()
+                    .filter(
+                        usage::Column::Timestamp
+                            .gte(past_minute.naive_utc())
+                            .and(usage::Column::IsStaff.eq(false))
+                            .and(
+                                usage::Column::MeasureId
+                                    .eq(requests_per_minute)
+                                    .or(usage::Column::MeasureId.eq(tokens_per_minute)),
+                            ),
+                    )
+                    .stream(&*tx)
+                    .await?;
+
+                let mut requests_this_minute = 0;
+                let mut tokens_this_minute = 0;
+                while let Some(usage) = usages.next().await {
+                    let usage = usage?;
+                    if usage.measure_id == requests_per_minute {
+                        requests_this_minute += Self::get_live_buckets(
+                            &usage,
+                            now.naive_utc(),
+                            UsageMeasure::RequestsPerMinute,
+                        )
+                        .0
+                        .iter()
+                        .copied()
+                        .sum::<i64>() as usize;
+                    } else if usage.measure_id == tokens_per_minute {
+                        tokens_this_minute += Self::get_live_buckets(
+                            &usage,
+                            now.naive_utc(),
+                            UsageMeasure::TokensPerMinute,
+                        )
+                        .0
+                        .iter()
+                        .copied()
+                        .sum::<i64>() as usize;
+                    }
+                }
+
+                results.push(ApplicationWideUsage {
+                    provider,
+                    model,
+                    requests_this_minute,
+                    tokens_this_minute,
+                })
+            }
+
+            Ok(results)
+        })
+        .await
+    }
+
     pub async fn get_usage(
         &self,
         user_id: UserId,
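
For completeness, a minimal sketch of exercising the new query directly, assuming db is an initialized LlmDatabase and an async context; this mirrors the call shape used by log_usage_periodically above:

    use chrono::Utc;

    // Hypothetical probe; `db` is assumed, not part of this diff.
    let usages = db.get_application_wide_usages_by_model(Utc::now()).await?;
    for usage in usages {
        // One entry per (provider, model) pair known to the database,
        // aggregated over non-staff usage in the past minute.
        println!(
            "{} {}: {} requests/min, {} tokens/min",
            usage.provider.to_string(),
            usage.model,
            usage.requests_this_minute,
            usage.tokens_this_minute,
        );
    }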

crates/collab/src/main.rs

@@ -5,7 +5,7 @@ use axum::{
     routing::get,
     Extension, Router,
 };
-use collab::llm::db::LlmDatabase;
+use collab::llm::{db::LlmDatabase, log_usage_periodically};
 use collab::migrations::run_database_migrations;
 use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
 use collab::{
@@ -95,6 +95,8 @@ async fn main() -> Result<()> {
 
                 let state = LlmState::new(config.clone(), Executor::Production).await?;
 
+                log_usage_periodically(state.clone());
+
                 app = app
                     .merge(collab::llm::routes())
                     .layer(Extension(state.clone()));