mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::api::events::SnowflakeRow;
use crate::api::CloudflareIpCountryHeader;
use crate::build_kinesis_client;
use crate::{
    build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::routing::get;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use reqwest_client::ReqwestClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{
    ListModelsResponse, PredictEditsParams, PredictEditsResponse,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use serde_json::json;
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use strum::IntoEnumIterator;
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

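/// Shared state for the LLM service: the server configuration, the LLM
/// database, an HTTP client for upstream providers, optional telemetry sinks
/// (Kinesis and ClickHouse), and a short-lived per-model cache of active user
/// counts.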
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: ReqwestClient,
    pub kinesis_client: Option<aws_sdk_kinesis::Client>,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count_by_model:
        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}

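/// How long a cached active user count stays fresh before it is recomputed
/// from the database.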
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client =
            ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;

        let this = Self {
            executor,
            db,
            http_client,
            kinesis_client: if config.kinesis_access_key.is_some() {
                build_kinesis_client(&config).await.log_err()
            } else {
                None
            },
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count_by_model: RwLock::new(HashMap::default()),
            config,
        };

        Ok(Arc::new(this))
    }

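    /// Returns the recent active user counts for the given model, serving a
    /// cached value while it is younger than `ACTIVE_USER_COUNT_CACHE_DURATION`
    /// and otherwise recomputing it from the database.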
    pub async fn get_active_user_count(
        &self,
        provider: LanguageModelProvider,
        model: &str,
    ) -> Result<ActiveUserCount> {
        let now = Utc::now();

        {
            let active_user_count_by_model = self.active_user_count_by_model.read().await;
            if let Some((last_updated, count)) =
                active_user_count_by_model.get(&(provider, model.to_string()))
            {
                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                    return Ok(*count);
                }
            }
        }

        let mut cache = self.active_user_count_by_model.write().await;
        let new_count = self.db.get_active_user_count(provider, model, now).await?;
        cache.insert((provider, model.to_string()), (now, new_count));
        Ok(new_count)
    }
}

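/// Builds the router for the LLM service. Every route runs behind
/// `validate_api_token`, so requests must carry a valid LLM token.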
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/models", get(list_models))
        .route("/completion", post(perform_completion))
        .route("/predict_edits", post(predict_edits))
        .layer(middleware::from_fn(validate_api_token))
}

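/// Middleware that authenticates a request from its `Authorization: Bearer`
/// header: the token is validated against the server config, rejected if its
/// JTI has been revoked, recorded on the current tracing span, and the decoded
/// claims are stored as a request extension for downstream handlers.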
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti)
                .record("is_staff", claims.is_staff);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

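/// Returns the models the authenticated user may access, filtering the full
/// model list through the same authorization check used for completions.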
async fn list_models(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
) -> Result<Json<ListModelsResponse>> {
    let country_code = country_code_header.map(|header| header.to_string());

    let mut accessible_models = Vec::new();

    for (provider, model) in state.db.all_models() {
        let authorize_result = authorize_access_to_language_model(
            &state.config,
            &claims,
            country_code.as_deref(),
            provider,
            &model.name,
        );

        if authorize_result.is_ok() {
            accessible_models.push(rpc::LanguageModel {
                provider,
                name: model.name,
            });
        }
    }

    Ok(Json(ListModelsResponse {
        models: accessible_models,
    }))
}

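/// Proxies a completion request to the upstream provider. The handler
/// normalizes the model name, authorizes access, enforces spending and rate
/// limits, then streams the provider's response back to the client wrapped in
/// a `TokenCountingStream` so that token usage is recorded once the stream
/// finishes.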
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header
            .map(|header| header.to_string())
            .as_deref(),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => {
                        tracing::info!(
                            target: "upstream rate limit exceeded",
                            user_id = claims.user_id,
                            login = claims.github_user_login,
                            authn.jti = claims.jti,
                            is_staff = claims.is_staff,
                            provider = params.provider.to_string(),
                            model = model
                        );

                        Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        )
                    }
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    is_staff = claims.is_staff,
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    ) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0, 0, 0),
                    };

                    anyhow::Ok(CompletionChunk {
                        bytes: serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    })
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO: implement token counting for Google AI.
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens: 0,
                            output_tokens: 0,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        tokens: TokenUsage::default(),
        inner_stream: stream,
    })))
}

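/// Normalizes a model name to the longest known model name that prefixes it,
/// so dated variants map onto the model the server tracks. For example, with
/// "claude-3-5-sonnet" as a known model, "claude-3-5-sonnet-20240620"
/// normalizes to "claude-3-5-sonnet"; unknown names pass through unchanged.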
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    if let Some(known_model_name) = known_models
        .iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
    {
        known_model_name.to_string()
    } else {
        name
    }
}

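/// Predicts edits for the given input by splicing the events and excerpt into
/// a prediction prompt and completing it via `open_ai::complete_text` against
/// the configured prediction endpoint. Currently staff-only; non-staff
/// requests are rejected as not found.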
async fn predict_edits(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    _country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PredictEditsParams>,
) -> Result<impl IntoResponse> {
    if !claims.is_staff {
        return Err(anyhow!("not found"))?;
    }

    let api_url = state
        .config
        .prediction_api_url
        .as_ref()
        .context("no PREDICTION_API_URL configured on the server")?;
    let api_key = state
        .config
        .prediction_api_key
        .as_ref()
        .context("no PREDICTION_API_KEY configured on the server")?;
    let model = state
        .config
        .prediction_model
        .as_ref()
        .context("no PREDICTION_MODEL configured on the server")?;
    let prompt = include_str!("./llm/prediction_prompt.md")
        .replace("<events>", &params.input_events)
        .replace("<excerpt>", &params.input_excerpt);
    let mut response = open_ai::complete_text(
        &state.http_client,
        api_url,
        api_key,
        open_ai::CompletionRequest {
            model: model.to_string(),
            prompt: prompt.clone(),
            max_tokens: 1024,
            temperature: 0.,
            prediction: Some(open_ai::Prediction::Content {
                content: params.input_excerpt,
            }),
            rewrite_speculation: Some(true),
        },
    )
    .await?;
    let choice = response
        .choices
        .pop()
        .context("no output from completion response")?;
    Ok(Json(PredictEditsResponse {
        output_excerpt: choice.text,
    }))
}

/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);

/// The default value to use for maximum spend per month if the user did not
/// explicitly set a maximum spend.
///
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);

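/// Enforces per-user spending and rate limits before a completion is allowed.
/// Staff bypass all checks. Spending is checked against the free tier limit
/// and the user's monthly cap; rate limits divide each model's global limits
/// evenly across recently active users, so per-user ceilings shrink as more
/// users are active.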
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    if claims.is_staff {
        return Ok(());
    }

    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;
    let free_tier = claims.free_tier_monthly_spending_limit();

    if usage.spending_this_month >= free_tier {
        if !claims.has_llm_subscription {
            return Err(Error::http(
                StatusCode::PAYMENT_REQUIRED,
                "Maximum spending limit reached for this month.".to_string(),
            ));
        }

        if (usage.spending_this_month - free_tier) >= Cents(claims.max_monthly_spend_in_cents) {
            return Err(Error::Http(
                StatusCode::FORBIDDEN,
                "Maximum spending limit reached for this month.".to_string(),
                [(
                    HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
                    HeaderValue::from_static("true"),
                )]
                .into_iter()
                .collect(),
            ));
        }
    }

    let active_users = state.get_active_user_count(provider, model_name).await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
            };

            tracing::info!(
                target: "user rate limit",
                user_id = claims.user_id,
                login = claims.github_user_login,
                authn.jti = claims.jti,
                is_staff = claims.is_staff,
                provider = provider.to_string(),
                model = model.name,
                requests_this_minute = usage.requests_this_minute,
                tokens_this_minute = usage.tokens_this_minute,
                tokens_this_day = usage.tokens_this_day,
                users_in_recent_minutes = users_in_recent_minutes,
                users_in_recent_days = users_in_recent_days,
                max_requests_per_minute = per_user_max_requests_per_minute,
                max_tokens_per_minute = per_user_max_tokens_per_minute,
                max_tokens_per_day = per_user_max_tokens_per_day,
            );

            SnowflakeRow::new(
                "Language Model Rate Limited",
                claims.metrics_id,
                claims.is_staff,
                claims.system_id.clone(),
                json!({
                    "usage": usage,
                    "users_in_recent_minutes": users_in_recent_minutes,
                    "users_in_recent_days": users_in_recent_days,
                    "max_requests_per_minute": per_user_max_requests_per_minute,
                    "max_tokens_per_minute": per_user_max_tokens_per_minute,
                    "max_tokens_per_day": per_user_max_tokens_per_day,
                    "plan": match claims.plan {
                        Plan::Free => "free".to_string(),
                        Plan::ZedPro => "zed_pro".to_string(),
                    },
                    "model": model.name.clone(),
                    "provider": provider.to_string(),
                    "usage_measure": resource.to_string(),
                }),
            )
            .write(&state.kinesis_client, &state.config.kinesis_stream)
            .await
            .log_err();

            if let Some(client) = state.clickhouse_client.as_ref() {
                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

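/// A single chunk of a provider's streaming response, paired with the token
/// counts extracted from it.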
struct CompletionChunk {
    bytes: Vec<u8>,
    input_tokens: usize,
    output_tokens: usize,
    cache_creation_input_tokens: usize,
    cache_read_input_tokens: usize,
}

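/// Wraps a provider's completion stream, forwarding each chunk to the client
/// as newline-delimited bytes while accumulating its token counts.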
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    tokens: TokenUsage,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok(mut chunk))) => {
                chunk.bytes.push(b'\n');
                self.tokens.input += chunk.input_tokens;
                self.tokens.output += chunk.output_tokens;
                self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
                self.tokens.input_cache_read += chunk.cache_read_input_tokens;
                Poll::Ready(Some(Ok(chunk.bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

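// Usage is recorded in `drop` rather than when the stream ends normally, so
// that tokens are still billed and reported if the stream terminates early
// (for example, if the client disconnects mid-stream).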
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let tokens = self.tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    tokens,
                    claims.has_llm_subscription,
                    Cents(claims.max_monthly_spend_in_cents),
                    claims.free_tier_monthly_spending_limit(),
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                let properties = json!({
                    "plan": match claims.plan {
                        Plan::Free => "free".to_string(),
                        Plan::ZedPro => "zed_pro".to_string(),
                    },
                    "model": model,
                    "provider": provider,
                    "usage": usage,
                    "tokens": tokens
                });
                SnowflakeRow::new(
                    "Language Model Used",
                    claims.metrics_id,
                    claims.is_staff,
                    claims.system_id.clone(),
                    properties,
                )
                .write(&state.kinesis_client, &state.config.kinesis_stream)
                .await
                .log_err();

                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                    report_llm_usage(
                        clickhouse_client,
                        LlmUsageEventRow {
                            time: Utc::now().timestamp_millis(),
                            user_id: claims.user_id as i32,
                            is_staff: claims.is_staff,
                            plan: match claims.plan {
                                Plan::Free => "free".to_string(),
                                Plan::ZedPro => "zed_pro".to_string(),
                            },
                            model,
                            provider: provider.to_string(),
                            input_token_count: tokens.input as u64,
                            cache_creation_input_token_count: tokens.input_cache_creation as u64,
                            cache_read_input_token_count: tokens.input_cache_read as u64,
                            output_token_count: tokens.output as u64,
                            requests_this_minute: usage.requests_this_minute as u64,
                            tokens_this_minute: usage.tokens_this_minute as u64,
                            tokens_this_day: usage.tokens_this_day as u64,
                            input_tokens_this_month: usage.tokens_this_month.input as u64,
                            cache_creation_input_tokens_this_month: usage
                                .tokens_this_month
                                .input_cache_creation
                                as u64,
                            cache_read_input_tokens_this_month: usage
                                .tokens_this_month
                                .input_cache_read
                                as u64,
                            output_tokens_this_month: usage.tokens_this_month.output as u64,
                            spending_this_month: usage.spending_this_month.0 as u64,
                            lifetime_spending: usage.lifetime_spending.0 as u64,
                        },
                    )
                    .await
                    .log_err();
                }
            }
        })
    }
}

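/// Spawns a detached background task that, every 30 seconds, logs the active
/// user counts per model and the application-wide usage per model, so
/// operators can observe load over time.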
pub fn log_usage_periodically(state: Arc<LlmState>) {
    state.executor.clone().spawn_detached(async move {
        loop {
            state
                .executor
                .sleep(std::time::Duration::from_secs(30))
                .await;

            for provider in LanguageModelProvider::iter() {
                for model in state.db.model_names_for_provider(provider) {
                    if let Some(active_user_count) = state
                        .get_active_user_count(provider, &model)
                        .await
                        .log_err()
                    {
                        tracing::info!(
                            target: "active user counts",
                            provider = provider.to_string(),
                            model = model,
                            users_in_recent_minutes = active_user_count.users_in_recent_minutes,
                            users_in_recent_days = active_user_count.users_in_recent_days,
                        );
                    }
                }
            }

            if let Some(usages) = state
                .db
                .get_application_wide_usages_by_model(Utc::now())
                .await
                .log_err()
            {
                for usage in usages {
                    tracing::info!(
                        target: "computed usage",
                        provider = usage.provider.to_string(),
                        model = usage.model,
                        requests_this_minute = usage.requests_this_minute,
                        tokens_this_minute = usage.tokens_this_minute,
                    );
                }
            }
        }
    })
}