Detailed changes
@@ -1067,6 +1067,8 @@ impl Client {
let proxy = http.proxy().cloned();
let credentials = credentials.clone();
let rpc_url = self.rpc_url(http, release_channel);
+ let system_id = self.telemetry.system_id();
+ let metrics_id = self.telemetry.metrics_id();
cx.background_executor().spawn(async move {
use HttpOrHttps::*;
@@ -1118,6 +1120,12 @@ impl Client {
"x-zed-release-channel",
HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
);
+ if let Some(system_id) = system_id {
+ request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
+ }
+ if let Some(metrics_id) = metrics_id {
+ request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
+ }
match url_scheme {
Https => {
@@ -533,6 +533,10 @@ impl Telemetry {
self.state.lock().metrics_id.clone()
}
+ /// Returns a clone of the `system_id` held in the locked telemetry state.
+ pub fn system_id(self: &Arc<Self>) -> Option<Arc<str>> {
+     let state = self.state.lock();
+     state.system_id.clone()
+ }
+
pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().installation_id.clone()
}
@@ -61,6 +61,39 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
}
}
+/// Typed header wrapping the string value of the `x-zed-system-id` header.
+pub struct SystemIdHeader(String);
+
+impl Header for SystemIdHeader {
+    /// Header name (`x-zed-system-id`), constructed on first use and then reused.
+    fn name() -> &'static HeaderName {
+        static NAME: OnceLock<HeaderName> = OnceLock::new();
+        NAME.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
+    }
+
+    /// Decodes the header from the first value yielded by `values`.
+    ///
+    /// Returns an "invalid" header error when no value is present or when the
+    /// value cannot be converted with `to_str`. Later values are not inspected.
+    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
+    where
+        Self: Sized,
+        I: Iterator<Item = &'i axum::http::HeaderValue>,
+    {
+        let Some(raw) = values.next() else {
+            return Err(axum::headers::Error::invalid());
+        };
+        match raw.to_str() {
+            Ok(text) => Ok(Self(text.to_owned())),
+            Err(_) => Err(axum::headers::Error::invalid()),
+        }
+    }
+
+    /// Encoding is not supported: calling this panics via `unimplemented!()`.
+    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
+        unimplemented!()
+    }
+}
+
+impl std::fmt::Display for SystemIdHeader {
+    /// Writes the wrapped system id string as-is.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new()
.route("/user", get(get_authenticated_user))
@@ -1578,8 +1578,8 @@ fn for_snowflake(
})
}
-#[derive(Serialize, Deserialize)]
-struct SnowflakeRow {
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SnowflakeRow {
pub time: chrono::DateTime<chrono::Utc>,
pub user_id: Option<String>,
pub device_id: Option<String>,
@@ -1588,3 +1588,42 @@ struct SnowflakeRow {
pub user_properties: Option<serde_json::Value>,
pub insert_id: Option<String>,
}
+
+impl SnowflakeRow {
+    /// Builds a row for `event_type`, timestamped with the current UTC time.
+    ///
+    /// `system_id` is stored as `device_id`, the stringified `metrics_id` as
+    /// `user_id`, and `is_staff` is recorded inside `user_properties`. Every
+    /// row gets a freshly generated UUID as its `insert_id`.
+    pub fn new(
+        event_type: impl Into<String>,
+        metrics_id: Option<Uuid>,
+        is_staff: bool,
+        system_id: Option<String>,
+        event_properties: serde_json::Value,
+    ) -> Self {
+        Self {
+            time: chrono::Utc::now(),
+            event_type: event_type.into(),
+            device_id: system_id,
+            user_id: metrics_id.map(|id| id.to_string()),
+            insert_id: Some(uuid::Uuid::new_v4().to_string()),
+            event_properties,
+            user_properties: Some(json!({"is_staff": is_staff})),
+        }
+    }
+
+    /// Serializes the row as JSON and puts it on the given Kinesis stream.
+    ///
+    /// Does nothing and returns `Ok(())` when either `client` or `stream` is
+    /// `None`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the row cannot be serialized or if the
+    /// `put_record` request fails.
+    pub async fn write(
+        self,
+        client: &Option<aws_sdk_kinesis::Client>,
+        stream: &Option<String>,
+    ) -> anyhow::Result<()> {
+        let Some((client, stream)) = client.as_ref().zip(stream.as_ref()) else {
+            return Ok(());
+        };
+        let row = serde_json::to_vec(&self)?;
+
+        // Kinesis rejects an empty partition key, so a row without a `user_id`
+        // must not fall back to "". Use the first non-empty identifier on the
+        // row, and a random UUID if none is available.
+        let partition_key = [self.user_id, self.device_id, self.insert_id]
+            .into_iter()
+            .flatten()
+            .find(|key| !key.is_empty())
+            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
+
+        client
+            .put_record()
+            .stream_name(stream)
+            .partition_key(partition_key)
+            .data(row.into())
+            .send()
+            .await?;
+        Ok(())
+    }
+}
@@ -1,3 +1,5 @@
+use serde::Serialize;
+
/// A number of cents.
#[derive(
Debug,
@@ -12,6 +14,7 @@
derive_more::AddAssign,
derive_more::Sub,
derive_more::SubAssign,
+ Serialize,
)]
pub struct Cents(pub u32);
@@ -3,9 +3,11 @@ pub mod db;
mod telemetry;
mod token;
+use crate::api::events::SnowflakeRow;
+use crate::api::CloudflareIpCountryHeader;
+use crate::build_kinesis_client;
use crate::{
- api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
- Config, Error, Result,
+ build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
@@ -28,6 +30,7 @@ use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
+use serde_json::json;
use std::{
pin::Pin,
sync::Arc,
@@ -45,6 +48,7 @@ pub struct LlmState {
pub executor: Executor,
pub db: Arc<LlmDatabase>,
pub http_client: ReqwestClient,
+ pub kinesis_client: Option<aws_sdk_kinesis::Client>,
pub clickhouse_client: Option<clickhouse::Client>,
active_user_count_by_model:
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
@@ -77,6 +81,11 @@ impl LlmState {
executor,
db,
http_client,
+ kinesis_client: if config.kinesis_access_key.is_some() {
+ build_kinesis_client(&config).await.log_err()
+ } else {
+ None
+ },
clickhouse_client: config
.clickhouse_url
.as_ref()
@@ -521,25 +530,50 @@ async fn check_usage_limit(
UsageMeasure::TokensPerDay => "tokens_per_day",
};
- if let Some(client) = state.clickhouse_client.as_ref() {
- tracing::info!(
- target: "user rate limit",
- user_id = claims.user_id,
- login = claims.github_user_login,
- authn.jti = claims.jti,
- is_staff = claims.is_staff,
- provider = provider.to_string(),
- model = model.name,
- requests_this_minute = usage.requests_this_minute,
- tokens_this_minute = usage.tokens_this_minute,
- tokens_this_day = usage.tokens_this_day,
- users_in_recent_minutes = users_in_recent_minutes,
- users_in_recent_days = users_in_recent_days,
- max_requests_per_minute = per_user_max_requests_per_minute,
- max_tokens_per_minute = per_user_max_tokens_per_minute,
- max_tokens_per_day = per_user_max_tokens_per_day,
- );
+ tracing::info!(
+ target: "user rate limit",
+ user_id = claims.user_id,
+ login = claims.github_user_login,
+ authn.jti = claims.jti,
+ is_staff = claims.is_staff,
+ provider = provider.to_string(),
+ model = model.name,
+ requests_this_minute = usage.requests_this_minute,
+ tokens_this_minute = usage.tokens_this_minute,
+ tokens_this_day = usage.tokens_this_day,
+ users_in_recent_minutes = users_in_recent_minutes,
+ users_in_recent_days = users_in_recent_days,
+ max_requests_per_minute = per_user_max_requests_per_minute,
+ max_tokens_per_minute = per_user_max_tokens_per_minute,
+ max_tokens_per_day = per_user_max_tokens_per_day,
+ );
+
+ SnowflakeRow::new(
+ "Language Model Rate Limited",
+ claims.metrics_id,
+ claims.is_staff,
+ claims.system_id.clone(),
+ json!({
+ "usage": usage,
+ "users_in_recent_minutes": users_in_recent_minutes,
+ "users_in_recent_days": users_in_recent_days,
+ "max_requests_per_minute": per_user_max_requests_per_minute,
+ "max_tokens_per_minute": per_user_max_tokens_per_minute,
+ "max_tokens_per_day": per_user_max_tokens_per_day,
+ "plan": match claims.plan {
+ Plan::Free => "free".to_string(),
+ Plan::ZedPro => "zed_pro".to_string(),
+ },
+ "model": model.name.clone(),
+ "provider": provider.to_string(),
+ "usage_measure": resource.to_string(),
+ }),
+ )
+ .write(&state.kinesis_client, &state.config.kinesis_stream)
+ .await
+ .log_err();
+ if let Some(client) = state.clickhouse_client.as_ref() {
report_llm_rate_limit(
client,
LlmRateLimitEventRow {
@@ -652,6 +686,27 @@ impl<S> Drop for TokenCountingStream<S> {
tokens_this_minute = usage.tokens_this_minute,
);
+ let properties = json!({
+ "plan": match claims.plan {
+ Plan::Free => "free".to_string(),
+ Plan::ZedPro => "zed_pro".to_string(),
+ },
+ "model": model,
+ "provider": provider,
+ "usage": usage,
+ "tokens": tokens
+ });
+ SnowflakeRow::new(
+ "Language Model Used",
+ claims.metrics_id,
+ claims.is_staff,
+ claims.system_id.clone(),
+ properties,
+ )
+ .write(&state.kinesis_client, &state.config.kinesis_stream)
+ .await
+ .log_err();
+
if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
report_llm_usage(
clickhouse_client,
@@ -9,7 +9,7 @@ use strum::IntoEnumIterator as _;
use super::*;
-#[derive(Debug, PartialEq, Clone, Copy, Default)]
+#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
pub struct TokenUsage {
pub input: usize,
pub input_cache_creation: usize,
@@ -23,7 +23,7 @@ impl TokenUsage {
}
}
-#[derive(Debug, PartialEq, Clone, Copy)]
+#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
@@ -8,6 +8,7 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
+use uuid::Uuid;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -16,6 +17,10 @@ pub struct LlmTokenClaims {
pub exp: u64,
pub jti: String,
pub user_id: u64,
+ #[serde(default)]
+ pub system_id: Option<String>,
+ #[serde(default)]
+ pub metrics_id: Option<Uuid>,
pub github_user_login: String,
pub is_staff: bool,
pub has_llm_closed_beta_feature_flag: bool,
@@ -36,6 +41,7 @@ impl LlmTokenClaims {
has_llm_closed_beta_feature_flag: bool,
has_llm_subscription: bool,
plan: rpc::proto::Plan,
+ system_id: Option<String>,
config: &Config,
) -> Result<String> {
let secret = config
@@ -49,6 +55,8 @@ impl LlmTokenClaims {
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user.id.to_proto(),
+ system_id,
+ metrics_id: Some(user.metrics_id),
github_user_login: user.github_login.clone(),
is_staff,
has_llm_closed_beta_feature_flag,
@@ -1,6 +1,6 @@
mod connection_pool;
-use crate::api::CloudflareIpCountryHeader;
+use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::llm::LlmTokenClaims;
use crate::{
auth,
@@ -137,6 +137,7 @@ struct Session {
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
+ system_id: Option<String>,
_executor: Executor,
}
@@ -682,6 +683,7 @@ impl Server {
principal: Principal,
zed_version: ZedVersion,
geoip_country_code: Option<String>,
+ system_id: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
) -> impl Future<Output = ()> {
@@ -737,6 +739,7 @@ impl Server {
app_state: this.app_state.clone(),
http_client,
geoip_country_code,
+ system_id,
_executor: executor.clone(),
supermaven_client,
};
@@ -1056,6 +1059,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
.layer(Extension(server))
}
+#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request(
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
app_version_header: Option<TypedHeader<AppVersionHeader>>,
@@ -1063,6 +1067,7 @@ pub async fn handle_websocket_request(
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
+ system_id_header: Option<TypedHeader<SystemIdHeader>>,
ws: WebSocketUpgrade,
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
@@ -1104,6 +1109,7 @@ pub async fn handle_websocket_request(
principal,
version,
country_code_header.map(|header| header.to_string()),
+ system_id_header.map(|header| header.to_string()),
None,
Executor::Production,
)
@@ -4053,6 +4059,7 @@ async fn get_llm_api_token(
has_llm_closed_beta_feature_flag,
has_llm_subscription,
session.current_plan(&db).await?,
+ session.system_id.clone(),
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;
@@ -244,6 +244,7 @@ impl TestServer {
Principal::User(user),
ZedVersion(SemanticVersion::new(1, 0, 0)),
None,
+ None,
Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()),
))