@@ -1,8 +1,12 @@
mod authorization;
pub mod db;
+mod telemetry;
mod token;
-use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
+use crate::{
+ api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error,
+ Result,
+};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
@@ -17,12 +21,15 @@ use chrono::{DateTime, Duration, Utc};
use db::{ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
-use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use rpc::{
+ proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
+};
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
+use telemetry::{report_llm_usage, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;
@@ -33,6 +40,7 @@ pub struct LlmState {
pub executor: Executor,
pub db: Arc<LlmDatabase>,
pub http_client: IsahcHttpClient,
+ pub clickhouse_client: Option<clickhouse::Client>,
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
@@ -65,11 +73,15 @@ impl LlmState {
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
let this = Self {
- config,
executor,
db,
http_client,
+ clickhouse_client: config
+ .clickhouse_url
+ .as_ref()
+ .and_then(|_| build_clickhouse_client(&config).log_err()),
active_user_count: RwLock::new(initial_active_user_count),
+ config,
};
Ok(Arc::new(this))
@@ -155,8 +167,6 @@ async fn perform_completion(
&model,
)?;
- let user_id = claims.user_id as i32;
-
check_usage_limit(&state, params.provider, &model, &claims).await?;
let stream = match params.provider {
@@ -310,9 +320,8 @@ async fn perform_completion(
};
Ok(Response::new(Body::wrap_stream(TokenCountingStream {
- db: state.db.clone(),
- executor: state.executor.clone(),
- user_id,
+ state,
+ claims,
provider: params.provider,
model,
input_tokens: 0,
@@ -403,9 +412,8 @@ async fn check_usage_limit(
}
struct TokenCountingStream<S> {
- db: Arc<LlmDatabase>,
- executor: Executor,
- user_id: i32,
+ state: Arc<LlmState>,
+ claims: LlmTokenClaims,
provider: LanguageModelProvider,
model: String,
input_tokens: usize,
@@ -436,15 +444,49 @@ where
impl<S> Drop for TokenCountingStream<S> {
fn drop(&mut self) {
- let db = self.db.clone();
- let user_id = self.user_id;
+ let state = self.state.clone(); // cloned so the `async move` block below owns its own handle
+ let claims = self.claims.clone(); // supplies user_id, is_staff and plan inside the block
let provider = self.provider;
let model = std::mem::take(&mut self.model);
- let token_count = self.input_tokens + self.output_tokens;
- self.executor.spawn_detached(async move {
- db.record_usage(user_id, provider, &model, token_count, Utc::now())
+ let input_token_count = self.input_tokens; // kept separate: summed for the DB, reported individually
+ let output_token_count = self.output_tokens;
+ self.state.executor.spawn_detached(async move { // `drop` is not async, so the `.await`s live in a future handed to the executor
+ let usage = state // ends up an `Option` (it is passed to `Option::zip` below) once `.log_err()` is applied
+ .db
+ .record_usage(
+ claims.user_id as i32, // NOTE(review): `as` cast; would truncate if user_id can exceed i32 — confirm its type
+ provider,
+ &model,
+ input_token_count + output_token_count, // the database is given the combined total
+ Utc::now(),
+ )
.await
.log_err();
+ // Build and send a telemetry row only if a ClickHouse client exists AND `record_usage` yielded a value.
+ if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
+ report_llm_usage(
+ clickhouse_client,
+ LlmUsageEventRow {
+ time: Utc::now().timestamp_millis(), // milliseconds since the Unix epoch
+ user_id: claims.user_id as i32,
+ is_staff: claims.is_staff,
+ plan: match claims.plan { // plan enum -> string label stored in the row
+ Plan::Free => "free".to_string(),
+ Plan::ZedPro => "zed_pro".to_string(),
+ },
+ model,
+ provider: provider.to_string(),
+ input_token_count: input_token_count as u64,
+ output_token_count: output_token_count as u64,
+ requests_this_minute: usage.requests_this_minute as u64, // this and the next three: fields of the value `record_usage` returned
+ tokens_this_minute: usage.tokens_this_minute as u64,
+ tokens_this_day: usage.tokens_this_day as u64,
+ tokens_this_month: usage.tokens_this_month as u64,
+ },
+ )
+ .await
+ .log_err(); // result is discarded; `log_err` presumably logs any failure — confirm in `util::ResultExt`
+ }
})
}
}
@@ -107,7 +107,7 @@ impl LlmDatabase {
model_name: &str,
token_count: usize,
now: DateTimeUtc,
- ) -> Result<()> {
+ ) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self.model(provider, model_name)?;
@@ -120,48 +120,57 @@ impl LlmDatabase {
.all(&*tx)
.await?;
- self.update_usage_for_measure(
- user_id,
- model.id,
- &usages,
- UsageMeasure::RequestsPerMinute,
- now,
- 1,
- &tx,
- )
- .await?;
- self.update_usage_for_measure(
- user_id,
- model.id,
- &usages,
- UsageMeasure::TokensPerMinute,
- now,
- token_count,
- &tx,
- )
- .await?;
- self.update_usage_for_measure(
- user_id,
- model.id,
- &usages,
- UsageMeasure::TokensPerDay,
- now,
- token_count,
- &tx,
- )
- .await?;
- self.update_usage_for_measure(
- user_id,
- model.id,
- &usages,
- UsageMeasure::TokensPerMonth,
- now,
- token_count,
- &tx,
- )
- .await?;
+ let requests_this_minute = self
+ .update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::RequestsPerMinute,
+ now,
+ 1,
+ &tx,
+ )
+ .await?;
+ let tokens_this_minute = self
+ .update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::TokensPerMinute,
+ now,
+ token_count,
+ &tx,
+ )
+ .await?;
+ let tokens_this_day = self
+ .update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::TokensPerDay,
+ now,
+ token_count,
+ &tx,
+ )
+ .await?;
+ let tokens_this_month = self
+ .update_usage_for_measure(
+ user_id,
+ model.id,
+ &usages,
+ UsageMeasure::TokensPerMonth,
+ now,
+ token_count,
+ &tx,
+ )
+ .await?;
- Ok(())
+ Ok(Usage {
+ requests_this_minute,
+ tokens_this_minute,
+ tokens_this_day,
+ tokens_this_month,
+ })
})
.await
}
@@ -205,7 +214,7 @@ impl LlmDatabase {
now: DateTimeUtc,
usage_to_add: usize,
tx: &DatabaseTransaction,
- ) -> Result<()> {
+ ) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
@@ -230,6 +239,7 @@ impl LlmDatabase {
}
*buckets.last_mut().unwrap() += usage_to_add as i64;
+ let total_usage = buckets.iter().sum::<i64>() as usize;
let mut model = usage::ActiveModel {
user_id: ActiveValue::set(user_id),
@@ -249,7 +259,7 @@ impl LlmDatabase {
.await?;
}
- Ok(())
+ Ok(total_usage)
}
fn get_usage_for_measure(
@@ -0,0 +1,25 @@
+use anyhow::Result;
+use serde::Serialize;
+
+#[derive(Serialize, Debug, clickhouse::Row)]
+pub struct LlmUsageEventRow { // one row written to the "llm_usage_events" table by `report_llm_usage`
+ pub time: i64, // the visible caller passes `Utc::now().timestamp_millis()`
+ pub user_id: i32,
+ pub is_staff: bool,
+ pub plan: String, // the visible caller passes "free" or "zed_pro"
+ pub model: String,
+ pub provider: String, // the visible caller passes `provider.to_string()`
+ pub input_token_count: u64,
+ pub output_token_count: u64,
+ pub requests_this_minute: u64, // this and the fields below: copied from the `Usage` returned by `record_usage`
+ pub tokens_this_minute: u64,
+ pub tokens_this_day: u64,
+ pub tokens_this_month: u64,
+}
+
+pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> { // inserts one row; each step's error is propagated via `?`
+ let mut insert = client.insert("llm_usage_events")?; // destination table name is fixed here
+ insert.write(&row).await?;
+ insert.end().await?; // finishes the INSERT; per the clickhouse crate docs an insert that is never ended is aborted — TODO confirm for the pinned version
+ Ok(())
+}