Add telemetry for LLM usage (#16049)

Created by Max Brunsfeld and Marshall

Release Notes:

- N/A

Co-authored-by: Marshall <marshall@zed.dev>

Change summary

crates/collab/src/llm.rs                   | 74 ++++++++++++++---
crates/collab/src/llm/db/queries/usages.rs | 98 +++++++++++++----------
crates/collab/src/llm/telemetry.rs         | 25 ++++++
3 files changed, 137 insertions(+), 60 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -1,8 +1,12 @@
 mod authorization;
 pub mod db;
+mod telemetry;
 mod token;
 
-use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
+use crate::{
+    api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error,
+    Result,
+};
 use anyhow::{anyhow, Context as _};
 use authorization::authorize_access_to_language_model;
 use axum::{
@@ -17,12 +21,15 @@ use chrono::{DateTime, Duration, Utc};
 use db::{ActiveUserCount, LlmDatabase};
 use futures::{Stream, StreamExt as _};
 use http_client::IsahcHttpClient;
-use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use rpc::{
+    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
+};
 use std::{
     pin::Pin,
     sync::Arc,
     task::{Context, Poll},
 };
+use telemetry::{report_llm_usage, LlmUsageEventRow};
 use tokio::sync::RwLock;
 use util::ResultExt;
 
@@ -33,6 +40,7 @@ pub struct LlmState {
     pub executor: Executor,
     pub db: Arc<LlmDatabase>,
     pub http_client: IsahcHttpClient,
+    pub clickhouse_client: Option<clickhouse::Client>,
     active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
 }
 
@@ -65,11 +73,15 @@ impl LlmState {
             Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
 
         let this = Self {
-            config,
             executor,
             db,
             http_client,
+            clickhouse_client: config
+                .clickhouse_url
+                .as_ref()
+                .and_then(|_| build_clickhouse_client(&config).log_err()),
             active_user_count: RwLock::new(initial_active_user_count),
+            config,
         };
 
         Ok(Arc::new(this))
@@ -155,8 +167,6 @@ async fn perform_completion(
         &model,
     )?;
 
-    let user_id = claims.user_id as i32;
-
     check_usage_limit(&state, params.provider, &model, &claims).await?;
 
     let stream = match params.provider {
@@ -310,9 +320,8 @@ async fn perform_completion(
     };
 
     Ok(Response::new(Body::wrap_stream(TokenCountingStream {
-        db: state.db.clone(),
-        executor: state.executor.clone(),
-        user_id,
+        state,
+        claims,
         provider: params.provider,
         model,
         input_tokens: 0,
@@ -403,9 +412,8 @@ async fn check_usage_limit(
 }
 
 struct TokenCountingStream<S> {
-    db: Arc<LlmDatabase>,
-    executor: Executor,
-    user_id: i32,
+    state: Arc<LlmState>,
+    claims: LlmTokenClaims,
     provider: LanguageModelProvider,
     model: String,
     input_tokens: usize,
@@ -436,15 +444,49 @@ where
 
 impl<S> Drop for TokenCountingStream<S> {
     fn drop(&mut self) {
-        let db = self.db.clone();
-        let user_id = self.user_id;
+        let state = self.state.clone();
+        let claims = self.claims.clone();
         let provider = self.provider;
         let model = std::mem::take(&mut self.model);
-        let token_count = self.input_tokens + self.output_tokens;
-        self.executor.spawn_detached(async move {
-            db.record_usage(user_id, provider, &model, token_count, Utc::now())
+        let input_token_count = self.input_tokens;
+        let output_token_count = self.output_tokens;
+        self.state.executor.spawn_detached(async move {
+            let usage = state
+                .db
+                .record_usage(
+                    claims.user_id as i32,
+                    provider,
+                    &model,
+                    input_token_count + output_token_count,
+                    Utc::now(),
+                )
                 .await
                 .log_err();
+
+            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
+                report_llm_usage(
+                    clickhouse_client,
+                    LlmUsageEventRow {
+                        time: Utc::now().timestamp_millis(),
+                        user_id: claims.user_id as i32,
+                        is_staff: claims.is_staff,
+                        plan: match claims.plan {
+                            Plan::Free => "free".to_string(),
+                            Plan::ZedPro => "zed_pro".to_string(),
+                        },
+                        model,
+                        provider: provider.to_string(),
+                        input_token_count: input_token_count as u64,
+                        output_token_count: output_token_count as u64,
+                        requests_this_minute: usage.requests_this_minute as u64,
+                        tokens_this_minute: usage.tokens_this_minute as u64,
+                        tokens_this_day: usage.tokens_this_day as u64,
+                        tokens_this_month: usage.tokens_this_month as u64,
+                    },
+                )
+                .await
+                .log_err();
+            }
         })
     }
 }

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -107,7 +107,7 @@ impl LlmDatabase {
         model_name: &str,
         token_count: usize,
         now: DateTimeUtc,
-    ) -> Result<()> {
+    ) -> Result<Usage> {
         self.transaction(|tx| async move {
             let model = self.model(provider, model_name)?;
 
@@ -120,48 +120,57 @@ impl LlmDatabase {
                 .all(&*tx)
                 .await?;
 
-            self.update_usage_for_measure(
-                user_id,
-                model.id,
-                &usages,
-                UsageMeasure::RequestsPerMinute,
-                now,
-                1,
-                &tx,
-            )
-            .await?;
-            self.update_usage_for_measure(
-                user_id,
-                model.id,
-                &usages,
-                UsageMeasure::TokensPerMinute,
-                now,
-                token_count,
-                &tx,
-            )
-            .await?;
-            self.update_usage_for_measure(
-                user_id,
-                model.id,
-                &usages,
-                UsageMeasure::TokensPerDay,
-                now,
-                token_count,
-                &tx,
-            )
-            .await?;
-            self.update_usage_for_measure(
-                user_id,
-                model.id,
-                &usages,
-                UsageMeasure::TokensPerMonth,
-                now,
-                token_count,
-                &tx,
-            )
-            .await?;
+            let requests_this_minute = self
+                .update_usage_for_measure(
+                    user_id,
+                    model.id,
+                    &usages,
+                    UsageMeasure::RequestsPerMinute,
+                    now,
+                    1,
+                    &tx,
+                )
+                .await?;
+            let tokens_this_minute = self
+                .update_usage_for_measure(
+                    user_id,
+                    model.id,
+                    &usages,
+                    UsageMeasure::TokensPerMinute,
+                    now,
+                    token_count,
+                    &tx,
+                )
+                .await?;
+            let tokens_this_day = self
+                .update_usage_for_measure(
+                    user_id,
+                    model.id,
+                    &usages,
+                    UsageMeasure::TokensPerDay,
+                    now,
+                    token_count,
+                    &tx,
+                )
+                .await?;
+            let tokens_this_month = self
+                .update_usage_for_measure(
+                    user_id,
+                    model.id,
+                    &usages,
+                    UsageMeasure::TokensPerMonth,
+                    now,
+                    token_count,
+                    &tx,
+                )
+                .await?;
 
-            Ok(())
+            Ok(Usage {
+                requests_this_minute,
+                tokens_this_minute,
+                tokens_this_day,
+                tokens_this_month,
+            })
         })
         .await
     }
@@ -205,7 +214,7 @@ impl LlmDatabase {
         now: DateTimeUtc,
         usage_to_add: usize,
         tx: &DatabaseTransaction,
-    ) -> Result<()> {
+    ) -> Result<usize> {
         let now = now.naive_utc();
         let measure_id = *self
             .usage_measure_ids
@@ -230,6 +239,7 @@ impl LlmDatabase {
         }
 
         *buckets.last_mut().unwrap() += usage_to_add as i64;
+        let total_usage = buckets.iter().sum::<i64>() as usize;
 
         let mut model = usage::ActiveModel {
             user_id: ActiveValue::set(user_id),
@@ -249,7 +259,7 @@ impl LlmDatabase {
                 .await?;
         }
 
-        Ok(())
+        Ok(total_usage)
     }
 
     fn get_usage_for_measure(

crates/collab/src/llm/telemetry.rs 🔗

@@ -0,0 +1,25 @@
+use anyhow::Result;
+use serde::Serialize;
+
+#[derive(Serialize, Debug, clickhouse::Row)]
+pub struct LlmUsageEventRow {
+    pub time: i64,
+    pub user_id: i32,
+    pub is_staff: bool,
+    pub plan: String,
+    pub model: String,
+    pub provider: String,
+    pub input_token_count: u64,
+    pub output_token_count: u64,
+    pub requests_this_minute: u64,
+    pub tokens_this_minute: u64,
+    pub tokens_this_day: u64,
+    pub tokens_this_month: u64,
+}
+
+pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> {
+    let mut insert = client.insert("llm_usage_events")?;
+    insert.write(&row).await?;
+    insert.end().await?;
+    Ok(())
+}