Send LLM events to Snowflake too (#21091)

Conrad Irwin created

Closes #ISSUE

Release Notes:

- N/A

Change summary

crates/client/src/client.rs                |  8 ++
crates/client/src/telemetry.rs             |  4 +
crates/collab/src/api.rs                   | 33 ++++++++
crates/collab/src/api/events.rs            | 43 ++++++++++
crates/collab/src/cents.rs                 |  3 
crates/collab/src/llm.rs                   | 95 ++++++++++++++++++-----
crates/collab/src/llm/db/queries/usages.rs |  4 
crates/collab/src/llm/token.rs             |  8 ++
crates/collab/src/rpc.rs                   |  9 ++
crates/collab/src/tests/test_server.rs     |  1 
10 files changed, 183 insertions(+), 25 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -1067,6 +1067,8 @@ impl Client {
         let proxy = http.proxy().cloned();
         let credentials = credentials.clone();
         let rpc_url = self.rpc_url(http, release_channel);
+        let system_id = self.telemetry.system_id();
+        let metrics_id = self.telemetry.metrics_id();
         cx.background_executor().spawn(async move {
             use HttpOrHttps::*;
 
@@ -1118,6 +1120,12 @@ impl Client {
                 "x-zed-release-channel",
                 HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
             );
+            if let Some(system_id) = system_id {
+                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
+            }
+            if let Some(metrics_id) = metrics_id {
+                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
+            }
 
             match url_scheme {
                 Https => {

crates/client/src/telemetry.rs 🔗

@@ -533,6 +533,10 @@ impl Telemetry {
         self.state.lock().metrics_id.clone()
     }
 
+    /// Returns a clone of the `system_id` stored in the locked telemetry
+    /// state, or `None` when no system id is set.
+    pub fn system_id(self: &Arc<Self>) -> Option<Arc<str>> {
+        let state = self.state.lock();
+        state.system_id.clone()
+    }
+
     pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
         self.state.lock().installation_id.clone()
     }

crates/collab/src/api.rs 🔗

@@ -61,6 +61,39 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
     }
 }
 
+/// Typed header wrapping the raw string value of the `x-zed-system-id`
+/// request header.
+pub struct SystemIdHeader(String);
+
+impl Header for SystemIdHeader {
+    /// Returns the `x-zed-system-id` header name, initialized once on first use.
+    fn name() -> &'static HeaderName {
+        static NAME: OnceLock<HeaderName> = OnceLock::new();
+        NAME.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
+    }
+
+    /// Builds the header from the first supplied value.
+    ///
+    /// Fails with an invalid-header error when no value is present or when
+    /// the value cannot be converted with `HeaderValue::to_str`.
+    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
+    where
+        Self: Sized,
+        I: Iterator<Item = &'i axum::http::HeaderValue>,
+    {
+        let Some(raw) = values.next() else {
+            return Err(axum::headers::Error::invalid());
+        };
+
+        match raw.to_str() {
+            Ok(text) => Ok(Self(text.to_owned())),
+            Err(_) => Err(axum::headers::Error::invalid()),
+        }
+    }
+
+    /// Encoding is not implemented; calling this panics via `unimplemented!()`.
+    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
+        unimplemented!()
+    }
+}
+
+impl std::fmt::Display for SystemIdHeader {
+    /// Writes the wrapped header value verbatim.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
 pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
     Router::new()
         .route("/user", get(get_authenticated_user))

crates/collab/src/api/events.rs 🔗

@@ -1578,8 +1578,8 @@ fn for_snowflake(
     })
 }
 
-#[derive(Serialize, Deserialize)]
-struct SnowflakeRow {
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SnowflakeRow {
     pub time: chrono::DateTime<chrono::Utc>,
     pub user_id: Option<String>,
     pub device_id: Option<String>,
@@ -1588,3 +1588,42 @@ struct SnowflakeRow {
     pub user_properties: Option<serde_json::Value>,
     pub insert_id: Option<String>,
 }
+
+impl SnowflakeRow {
+    /// Builds a row for `event_type`, stamped with the current UTC time and a
+    /// freshly generated random `insert_id`.
+    ///
+    /// `metrics_id` is stored (stringified) as `user_id`, `system_id` as
+    /// `device_id`, and `is_staff` is recorded under `user_properties`.
+    pub fn new(
+        event_type: impl Into<String>,
+        metrics_id: Option<Uuid>,
+        is_staff: bool,
+        system_id: Option<String>,
+        event_properties: serde_json::Value,
+    ) -> Self {
+        Self {
+            time: chrono::Utc::now(),
+            event_type: event_type.into(),
+            device_id: system_id,
+            user_id: metrics_id.map(|id| id.to_string()),
+            insert_id: Some(uuid::Uuid::new_v4().to_string()),
+            event_properties,
+            user_properties: Some(json!({"is_staff": is_staff})),
+        }
+    }
+
+    /// Serializes the row to JSON and puts it on the given Kinesis stream.
+    ///
+    /// Does nothing and returns `Ok(())` unless both `client` and `stream`
+    /// are set.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if JSON serialization or the `put_record` call fails.
+    pub async fn write(
+        self,
+        client: &Option<aws_sdk_kinesis::Client>,
+        stream: &Option<String>,
+    ) -> anyhow::Result<()> {
+        let Some((client, stream)) = client.as_ref().zip(stream.as_ref()) else {
+            return Ok(());
+        };
+        let row = serde_json::to_vec(&self)?;
+        // Kinesis rejects empty partition keys, so rows without a user id are
+        // keyed by their insert id (or a fresh UUID when that is absent too)
+        // instead of an empty string.
+        let partition_key = self
+            .user_id
+            .filter(|id| !id.is_empty())
+            .or(self.insert_id)
+            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
+        client
+            .put_record()
+            .stream_name(stream)
+            .partition_key(partition_key)
+            .data(row.into())
+            .send()
+            .await?;
+        Ok(())
+    }
+}

crates/collab/src/cents.rs 🔗

@@ -1,3 +1,5 @@
+use serde::Serialize;
+
 /// A number of cents.
 #[derive(
     Debug,
@@ -12,6 +14,7 @@
     derive_more::AddAssign,
     derive_more::Sub,
     derive_more::SubAssign,
+    Serialize,
 )]
 pub struct Cents(pub u32);
 

crates/collab/src/llm.rs 🔗

@@ -3,9 +3,11 @@ pub mod db;
 mod telemetry;
 mod token;
 
+use crate::api::events::SnowflakeRow;
+use crate::api::CloudflareIpCountryHeader;
+use crate::build_kinesis_client;
 use crate::{
-    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
-    Config, Error, Result,
+    build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result,
 };
 use anyhow::{anyhow, Context as _};
 use authorization::authorize_access_to_language_model;
@@ -28,6 +30,7 @@ use rpc::{
     proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
 };
 use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
+use serde_json::json;
 use std::{
     pin::Pin,
     sync::Arc,
@@ -45,6 +48,7 @@ pub struct LlmState {
     pub executor: Executor,
     pub db: Arc<LlmDatabase>,
     pub http_client: ReqwestClient,
+    pub kinesis_client: Option<aws_sdk_kinesis::Client>,
     pub clickhouse_client: Option<clickhouse::Client>,
     active_user_count_by_model:
         RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
@@ -77,6 +81,11 @@ impl LlmState {
             executor,
             db,
             http_client,
+            kinesis_client: if config.kinesis_access_key.is_some() {
+                build_kinesis_client(&config).await.log_err()
+            } else {
+                None
+            },
             clickhouse_client: config
                 .clickhouse_url
                 .as_ref()
@@ -521,25 +530,50 @@ async fn check_usage_limit(
                 UsageMeasure::TokensPerDay => "tokens_per_day",
             };
 
-            if let Some(client) = state.clickhouse_client.as_ref() {
-                tracing::info!(
-                    target: "user rate limit",
-                    user_id = claims.user_id,
-                    login = claims.github_user_login,
-                    authn.jti = claims.jti,
-                    is_staff = claims.is_staff,
-                    provider = provider.to_string(),
-                    model = model.name,
-                    requests_this_minute = usage.requests_this_minute,
-                    tokens_this_minute = usage.tokens_this_minute,
-                    tokens_this_day = usage.tokens_this_day,
-                    users_in_recent_minutes = users_in_recent_minutes,
-                    users_in_recent_days = users_in_recent_days,
-                    max_requests_per_minute = per_user_max_requests_per_minute,
-                    max_tokens_per_minute = per_user_max_tokens_per_minute,
-                    max_tokens_per_day = per_user_max_tokens_per_day,
-                );
+            tracing::info!(
+                target: "user rate limit",
+                user_id = claims.user_id,
+                login = claims.github_user_login,
+                authn.jti = claims.jti,
+                is_staff = claims.is_staff,
+                provider = provider.to_string(),
+                model = model.name,
+                requests_this_minute = usage.requests_this_minute,
+                tokens_this_minute = usage.tokens_this_minute,
+                tokens_this_day = usage.tokens_this_day,
+                users_in_recent_minutes = users_in_recent_minutes,
+                users_in_recent_days = users_in_recent_days,
+                max_requests_per_minute = per_user_max_requests_per_minute,
+                max_tokens_per_minute = per_user_max_tokens_per_minute,
+                max_tokens_per_day = per_user_max_tokens_per_day,
+            );
+
+            SnowflakeRow::new(
+                "Language Model Rate Limited",
+                claims.metrics_id,
+                claims.is_staff,
+                claims.system_id.clone(),
+                json!({
+                    "usage": usage,
+                    "users_in_recent_minutes": users_in_recent_minutes,
+                    "users_in_recent_days": users_in_recent_days,
+                    "max_requests_per_minute": per_user_max_requests_per_minute,
+                    "max_tokens_per_minute": per_user_max_tokens_per_minute,
+                    "max_tokens_per_day": per_user_max_tokens_per_day,
+                    "plan": match claims.plan {
+                        Plan::Free => "free".to_string(),
+                        Plan::ZedPro => "zed_pro".to_string(),
+                    },
+                    "model": model.name.clone(),
+                    "provider": provider.to_string(),
+                    "usage_measure": resource.to_string(),
+                }),
+            )
+            .write(&state.kinesis_client, &state.config.kinesis_stream)
+            .await
+            .log_err();
 
+            if let Some(client) = state.clickhouse_client.as_ref() {
                 report_llm_rate_limit(
                     client,
                     LlmRateLimitEventRow {
@@ -652,6 +686,27 @@ impl<S> Drop for TokenCountingStream<S> {
                     tokens_this_minute = usage.tokens_this_minute,
                 );
 
+                let properties = json!({
+                    "plan": match claims.plan {
+                        Plan::Free => "free".to_string(),
+                        Plan::ZedPro => "zed_pro".to_string(),
+                    },
+                    "model": model,
+                    "provider": provider,
+                    "usage": usage,
+                    "tokens": tokens
+                });
+                SnowflakeRow::new(
+                    "Language Model Used",
+                    claims.metrics_id,
+                    claims.is_staff,
+                    claims.system_id.clone(),
+                    properties,
+                )
+                .write(&state.kinesis_client, &state.config.kinesis_stream)
+                .await
+                .log_err();
+
                 if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                     report_llm_usage(
                         clickhouse_client,

crates/collab/src/llm/db/queries/usages.rs 🔗

@@ -9,7 +9,7 @@ use strum::IntoEnumIterator as _;
 
 use super::*;
 
-#[derive(Debug, PartialEq, Clone, Copy, Default)]
+#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
 pub struct TokenUsage {
     pub input: usize,
     pub input_cache_creation: usize,
@@ -23,7 +23,7 @@ impl TokenUsage {
     }
 }
 
-#[derive(Debug, PartialEq, Clone, Copy)]
+#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
 pub struct Usage {
     pub requests_this_minute: usize,
     pub tokens_this_minute: usize,

crates/collab/src/llm/token.rs 🔗

@@ -8,6 +8,7 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
 use serde::{Deserialize, Serialize};
 use std::time::Duration;
 use thiserror::Error;
+use uuid::Uuid;
 
 #[derive(Clone, Debug, Default, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
@@ -16,6 +17,10 @@ pub struct LlmTokenClaims {
     pub exp: u64,
     pub jti: String,
     pub user_id: u64,
+    #[serde(default)]
+    pub system_id: Option<String>,
+    #[serde(default)]
+    pub metrics_id: Option<Uuid>,
     pub github_user_login: String,
     pub is_staff: bool,
     pub has_llm_closed_beta_feature_flag: bool,
@@ -36,6 +41,7 @@ impl LlmTokenClaims {
         has_llm_closed_beta_feature_flag: bool,
         has_llm_subscription: bool,
         plan: rpc::proto::Plan,
+        system_id: Option<String>,
         config: &Config,
     ) -> Result<String> {
         let secret = config
@@ -49,6 +55,8 @@ impl LlmTokenClaims {
             exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
             jti: uuid::Uuid::new_v4().to_string(),
             user_id: user.id.to_proto(),
+            system_id,
+            metrics_id: Some(user.metrics_id),
             github_user_login: user.github_login.clone(),
             is_staff,
             has_llm_closed_beta_feature_flag,

crates/collab/src/rpc.rs 🔗

@@ -1,6 +1,6 @@
 mod connection_pool;
 
-use crate::api::CloudflareIpCountryHeader;
+use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
 use crate::llm::LlmTokenClaims;
 use crate::{
     auth,
@@ -137,6 +137,7 @@ struct Session {
     /// The GeoIP country code for the user.
     #[allow(unused)]
     geoip_country_code: Option<String>,
+    system_id: Option<String>,
     _executor: Executor,
 }
 
@@ -682,6 +683,7 @@ impl Server {
         principal: Principal,
         zed_version: ZedVersion,
         geoip_country_code: Option<String>,
+        system_id: Option<String>,
         send_connection_id: Option<oneshot::Sender<ConnectionId>>,
         executor: Executor,
     ) -> impl Future<Output = ()> {
@@ -737,6 +739,7 @@ impl Server {
                 app_state: this.app_state.clone(),
                 http_client,
                 geoip_country_code,
+                system_id,
                 _executor: executor.clone(),
                 supermaven_client,
             };
@@ -1056,6 +1059,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
         .layer(Extension(server))
 }
 
+#[allow(clippy::too_many_arguments)]
 pub async fn handle_websocket_request(
     TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
     app_version_header: Option<TypedHeader<AppVersionHeader>>,
@@ -1063,6 +1067,7 @@ pub async fn handle_websocket_request(
     Extension(server): Extension<Arc<Server>>,
     Extension(principal): Extension<Principal>,
     country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
+    system_id_header: Option<TypedHeader<SystemIdHeader>>,
     ws: WebSocketUpgrade,
 ) -> axum::response::Response {
     if protocol_version != rpc::PROTOCOL_VERSION {
@@ -1104,6 +1109,7 @@ pub async fn handle_websocket_request(
                     principal,
                     version,
                     country_code_header.map(|header| header.to_string()),
+                    system_id_header.map(|header| header.to_string()),
                     None,
                     Executor::Production,
                 )
@@ -4053,6 +4059,7 @@ async fn get_llm_api_token(
         has_llm_closed_beta_feature_flag,
         has_llm_subscription,
         session.current_plan(&db).await?,
+        session.system_id.clone(),
         &session.app_state.config,
     )?;
     response.send(proto::GetLlmTokenResponse { token })?;

crates/collab/src/tests/test_server.rs 🔗

@@ -244,6 +244,7 @@ impl TestServer {
                                 Principal::User(user),
                                 ZedVersion(SemanticVersion::new(1, 0, 0)),
                                 None,
+                                None,
                                 Some(connection_id_tx),
                                 Executor::Deterministic(cx.background_executor().clone()),
                             ))