diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 041973e8841024507530fd13624683c9ab5e5023..a20584fabd49db951dbbb5efbb9312f83a78d190 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1067,6 +1067,8 @@ impl Client { let proxy = http.proxy().cloned(); let credentials = credentials.clone(); let rpc_url = self.rpc_url(http, release_channel); + let system_id = self.telemetry.system_id(); + let metrics_id = self.telemetry.metrics_id(); cx.background_executor().spawn(async move { use HttpOrHttps::*; @@ -1118,6 +1120,12 @@ impl Client { "x-zed-release-channel", HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?, ); + if let Some(system_id) = system_id { + request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?); + } + if let Some(metrics_id) = metrics_id { + request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?); + } match url_scheme { Https => { diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 583f9757c49cb7a9aa5b2316935f1681fb34e756..eef2a8215fe91090a9763840cc53f706dfbe5a3f 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -533,6 +533,10 @@ impl Telemetry { self.state.lock().metrics_id.clone() } + pub fn system_id(self: &Arc) -> Option> { + self.state.lock().system_id.clone() + } + pub fn installation_id(self: &Arc) -> Option> { self.state.lock().installation_id.clone() } diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 46ca5906c5bb4f1a89ffb3427a1797606f2c013c..7adf17ac06c66f85b100b0154ddbb55368cbd91c 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -61,6 +61,39 @@ impl std::fmt::Display for CloudflareIpCountryHeader { } } +pub struct SystemIdHeader(String); + +impl Header for SystemIdHeader { + fn name() -> &'static HeaderName { + static SYSTEM_ID_HEADER: OnceLock = OnceLock::new(); + SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id")) + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let system_id = values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())?; + + Ok(Self(system_id.to_string())) + } + + fn encode>(&self, _values: &mut E) { + unimplemented!() + } +} + +impl std::fmt::Display for SystemIdHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() .route("/user", get(get_authenticated_user)) diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index 3cda6a397acd720696f95b1edda4ff313cb3a574..11137cb4e96f164410e9fe71548d64d70224bf4d 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -1578,8 +1578,8 @@ fn for_snowflake( }) } -#[derive(Serialize, Deserialize)] -struct SnowflakeRow { +#[derive(Serialize, Deserialize, Debug)] +pub struct SnowflakeRow { pub time: chrono::DateTime, pub user_id: Option, pub device_id: Option, @@ -1588,3 +1588,42 @@ struct SnowflakeRow { pub user_properties: Option, pub insert_id: Option, } + +impl SnowflakeRow { + pub fn new( + event_type: impl Into, + metrics_id: Option, + is_staff: bool, + system_id: Option, + event_properties: serde_json::Value, + ) -> Self { + Self { + time: chrono::Utc::now(), + event_type: event_type.into(), + device_id: system_id, + user_id: metrics_id.map(|id| id.to_string()), + insert_id: Some(uuid::Uuid::new_v4().to_string()), + event_properties, + user_properties: Some(json!({"is_staff": is_staff})), + } + } + + pub async fn write( + self, + client: &Option, + stream: &Option, + ) -> anyhow::Result<()> { + let Some((client, stream)) = client.as_ref().zip(stream.as_ref()) else { + return Ok(()); + }; + let row = serde_json::to_vec(&self)?; + client + .put_record() + .stream_name(stream) + .partition_key(&self.user_id.unwrap_or_default()) + .data(row.into()) + .send() + .await?; + Ok(()) + } +} diff --git a/crates/collab/src/cents.rs b/crates/collab/src/cents.rs index defbcea4e26d39a34ee214889259c6d7304bd3c3..a05971f1417339664d667665ddff63a13237f4dc 100644 --- a/crates/collab/src/cents.rs +++ b/crates/collab/src/cents.rs @@ -1,3 +1,5 @@ +use serde::Serialize; + /// A number of cents. #[derive( Debug, @@ -12,6 +14,7 @@ derive_more::AddAssign, derive_more::Sub, derive_more::SubAssign, + Serialize, )] pub struct Cents(pub u32); diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index fa48ec95ea0b6f41cfa960c6fe82e26e334c9b93..603b76db739e19159c50afe24f16598d614919c3 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -3,9 +3,11 @@ pub mod db; mod telemetry; mod token; +use crate::api::events::SnowflakeRow; +use crate::api::CloudflareIpCountryHeader; +use crate::build_kinesis_client; use crate::{ - api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents, - Config, Error, Result, + build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result, }; use anyhow::{anyhow, Context as _}; use authorization::authorize_access_to_language_model; @@ -28,6 +30,7 @@ use rpc::{ proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, }; use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME}; +use serde_json::json; use std::{ pin::Pin, sync::Arc, @@ -45,6 +48,7 @@ pub struct LlmState { pub executor: Executor, pub db: Arc, pub http_client: ReqwestClient, + pub kinesis_client: Option, pub clickhouse_client: Option, active_user_count_by_model: RwLock, ActiveUserCount)>>, @@ -77,6 +81,11 @@ impl LlmState { executor, db, http_client, + kinesis_client: if config.kinesis_access_key.is_some() { + build_kinesis_client(&config).await.log_err() + } else { + None + }, clickhouse_client: config .clickhouse_url .as_ref() @@ -521,25 +530,50 @@ async fn check_usage_limit( UsageMeasure::TokensPerDay => "tokens_per_day", }; - if let Some(client) = state.clickhouse_client.as_ref() { - tracing::info!( - target: "user rate limit", - user_id = claims.user_id, - login = claims.github_user_login, - authn.jti = claims.jti, - is_staff = claims.is_staff, - provider = provider.to_string(), - model = model.name, - requests_this_minute = usage.requests_this_minute, - tokens_this_minute = usage.tokens_this_minute, - tokens_this_day = usage.tokens_this_day, - users_in_recent_minutes = users_in_recent_minutes, - users_in_recent_days = users_in_recent_days, - max_requests_per_minute = per_user_max_requests_per_minute, - max_tokens_per_minute = per_user_max_tokens_per_minute, - max_tokens_per_day = per_user_max_tokens_per_day, - ); + tracing::info!( + target: "user rate limit", + user_id = claims.user_id, + login = claims.github_user_login, + authn.jti = claims.jti, + is_staff = claims.is_staff, + provider = provider.to_string(), + model = model.name, + requests_this_minute = usage.requests_this_minute, + tokens_this_minute = usage.tokens_this_minute, + tokens_this_day = usage.tokens_this_day, + users_in_recent_minutes = users_in_recent_minutes, + users_in_recent_days = users_in_recent_days, + max_requests_per_minute = per_user_max_requests_per_minute, + max_tokens_per_minute = per_user_max_tokens_per_minute, + max_tokens_per_day = per_user_max_tokens_per_day, + ); + + SnowflakeRow::new( + "Language Model Rate Limited", + claims.metrics_id, + claims.is_staff, + claims.system_id.clone(), + json!({ + "usage": usage, + "users_in_recent_minutes": users_in_recent_minutes, + "users_in_recent_days": users_in_recent_days, + "max_requests_per_minute": per_user_max_requests_per_minute, + "max_tokens_per_minute": per_user_max_tokens_per_minute, + "max_tokens_per_day": per_user_max_tokens_per_day, + "plan": match claims.plan { + Plan::Free => "free".to_string(), + Plan::ZedPro => "zed_pro".to_string(), + }, + "model": model.name.clone(), + "provider": provider.to_string(), + "usage_measure": resource.to_string(), + }), + ) + .write(&state.kinesis_client, &state.config.kinesis_stream) + .await + .log_err(); + if let Some(client) = state.clickhouse_client.as_ref() { report_llm_rate_limit( client, LlmRateLimitEventRow { @@ -652,6 +686,27 @@ impl Drop for TokenCountingStream { tokens_this_minute = usage.tokens_this_minute, ); + let properties = json!({ + "plan": match claims.plan { + Plan::Free => "free".to_string(), + Plan::ZedPro => "zed_pro".to_string(), + }, + "model": model, + "provider": provider, + "usage": usage, + "tokens": tokens + }); + SnowflakeRow::new( + "Language Model Used", + claims.metrics_id, + claims.is_staff, + claims.system_id.clone(), + properties, + ) + .write(&state.kinesis_client, &state.config.kinesis_stream) + .await + .log_err(); + if let Some(clickhouse_client) = state.clickhouse_client.as_ref() { report_llm_usage( clickhouse_client, diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index f2628217436edf54bf2bd4ba963a0b857fcb78a7..27e8039f54aef1d337266c545d9fab058cd6cb95 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -9,7 +9,7 @@ use strum::IntoEnumIterator as _; use super::*; -#[derive(Debug, PartialEq, Clone, Copy, Default)] +#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)] pub struct TokenUsage { pub input: usize, pub input_cache_creation: usize, @@ -23,7 +23,7 @@ impl TokenUsage { } } -#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)] pub struct Usage { pub requests_this_minute: usize, pub tokens_this_minute: usize, diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index 35f7cf26e737317a0ed72daa88e3f193e322a2c2..7e0706e2d5a12dead105dd6627cdb9c27da6053b 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -8,6 +8,7 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; use thiserror::Error; +use uuid::Uuid; #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -16,6 +17,10 @@ pub struct LlmTokenClaims { pub exp: u64, pub jti: String, pub user_id: u64, + #[serde(default)] + pub system_id: Option, + #[serde(default)] + pub metrics_id: Option, pub github_user_login: String, pub is_staff: bool, pub has_llm_closed_beta_feature_flag: bool, @@ -36,6 +41,7 @@ impl LlmTokenClaims { has_llm_closed_beta_feature_flag: bool, has_llm_subscription: bool, plan: rpc::proto::Plan, + system_id: Option, config: &Config, ) -> Result { let secret = config @@ -49,6 +55,8 @@ impl LlmTokenClaims { exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, jti: uuid::Uuid::new_v4().to_string(), user_id: user.id.to_proto(), + system_id, + metrics_id: Some(user.metrics_id), github_user_login: user.github_login.clone(), is_staff, has_llm_closed_beta_feature_flag, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 1184c4861820e788459bf58bdd5a2bb6bd2f03e2..a17d4924b72d53d67dd236486c072322b5d28279 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,6 +1,6 @@ mod connection_pool; -use crate::api::CloudflareIpCountryHeader; +use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::llm::LlmTokenClaims; use crate::{ auth, @@ -137,6 +137,7 @@ struct Session { /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, + system_id: Option, _executor: Executor, } @@ -682,6 +683,7 @@ impl Server { principal: Principal, zed_version: ZedVersion, geoip_country_code: Option, + system_id: Option, send_connection_id: Option>, executor: Executor, ) -> impl Future { @@ -737,6 +739,7 @@ impl Server { app_state: this.app_state.clone(), http_client, geoip_country_code, + system_id, _executor: executor.clone(), supermaven_client, }; @@ -1056,6 +1059,7 @@ pub fn routes(server: Arc) -> Router<(), Body> { .layer(Extension(server)) } +#[allow(clippy::too_many_arguments)] pub async fn handle_websocket_request( TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, app_version_header: Option>, @@ -1063,6 +1067,7 @@ pub async fn handle_websocket_request( Extension(server): Extension>, Extension(principal): Extension, country_code_header: Option>, + system_id_header: Option>, ws: WebSocketUpgrade, ) -> axum::response::Response { if protocol_version != rpc::PROTOCOL_VERSION { @@ -1104,6 +1109,7 @@ pub async fn handle_websocket_request( principal, version, country_code_header.map(|header| header.to_string()), + system_id_header.map(|header| header.to_string()), None, Executor::Production, ) @@ -4053,6 +4059,7 @@ async fn get_llm_api_token( has_llm_closed_beta_feature_flag, has_llm_subscription, session.current_plan(&db).await?, + session.system_id.clone(), &session.app_state.config, )?; response.send(proto::GetLlmTokenResponse { token })?; diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 8a09f06092fe72703a1901b4bdcd825a83aae8ce..c93cce9770e58fce1a3073be65fe0aeb55d556a2 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -244,6 +244,7 @@ impl TestServer { Principal::User(user), ZedVersion(SemanticVersion::new(1, 0, 0)), None, + None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), ))