mod authorization;
pub mod db;
mod token;

use crate::api::events::SnowflakeRow;
use crate::api::CloudflareIpCountryHeader;
use crate::build_kinesis_client;
use crate::{db::UserId, executor::Executor, Cents, Config, Error, Result};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::routing::get;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{FutureExt, Stream, StreamExt as _};
use reqwest_client::ReqwestClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{
    ListModelsResponse, PredictEditsParams, PredictEditsResponse,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use serde_json::json;
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use strum::IntoEnumIterator;
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

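/// How long a cached active user count stays fresh before it is recomputed
/// from the database.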
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

/// Output token limit. A copy of this constant is also in `crates/zeta/src/zeta.rs`.
const MAX_OUTPUT_TOKENS: u32 = 2048;

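/// Shared state for the LLM service: configuration, the LLM database, the
/// outbound HTTP client, the optional Kinesis client for telemetry, and a
/// per-model cache of active user counts.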
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: ReqwestClient,
    pub kinesis_client: Option<aws_sdk_kinesis::Client>,
    active_user_count_by_model:
        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}

impl LlmState {
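    /// Builds the service state from the given config: connects to the LLM
    /// database (requiring `LLM_DATABASE_URL` and `LLM_DATABASE_MAX_CONNECTIONS`),
    /// constructs the HTTP client, and creates the Kinesis client when an
    /// access key is configured.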
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client =
            ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;

        let this = Self {
            executor,
            db,
            http_client,
            kinesis_client: if config.kinesis_access_key.is_some() {
                build_kinesis_client(&config).await.log_err()
            } else {
                None
            },
            active_user_count_by_model: RwLock::new(HashMap::default()),
            config,
        };

        Ok(Arc::new(this))
    }

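    /// Returns the active user count for the given model, serving a cached
    /// value while it is younger than [`ACTIVE_USER_COUNT_CACHE_DURATION`] and
    /// querying the database otherwise.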
    pub async fn get_active_user_count(
        &self,
        provider: LanguageModelProvider,
        model: &str,
    ) -> Result<ActiveUserCount> {
        let now = Utc::now();

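        // Check the cache under a read lock first; the lock is released at the
        // end of this scope so the write lock below can be acquired.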
        {
            let active_user_count_by_model = self.active_user_count_by_model.read().await;
            if let Some((last_updated, count)) =
                active_user_count_by_model.get(&(provider, model.to_string()))
            {
                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                    return Ok(*count);
                }
            }
        }

        let mut cache = self.active_user_count_by_model.write().await;
        let new_count = self.db.get_active_user_count(provider, model, now).await?;
        cache.insert((provider, model.to_string()), (now, new_count));
        Ok(new_count)
    }
}

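/// Builds the LLM API router. Every route requires a valid LLM token, enforced
/// by the [`validate_api_token`] middleware.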
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/models", get(list_models))
        .route("/completion", post(perform_completion))
        .route("/predict_edits", post(predict_edits))
        .layer(middleware::from_fn(validate_api_token))
}

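/// Middleware that validates the `Authorization: Bearer <token>` header,
/// rejects revoked or expired tokens, and stores the decoded [`LlmTokenClaims`]
/// as a request extension for downstream handlers. Token expiry is signalled to
/// the client via the `EXPIRED_LLM_TOKEN_HEADER_NAME` header so it can be
/// distinguished from other authentication failures.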
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

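    // The `LlmState` extension is installed when the router is mounted, so it
    // is expected to always be present here.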
    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti)
                .record("is_staff", claims.is_staff);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

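/// Lists the language models the requesting user may access, as determined per
/// model by [`authorize_access_to_language_model`] from the user's claims and
/// the request's country of origin.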
async fn list_models(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
) -> Result<Json<ListModelsResponse>> {
    let country_code = country_code_header.map(|header| header.to_string());

    let mut accessible_models = Vec::new();

    for (provider, model) in state.db.all_models() {
        let authorize_result = authorize_access_to_language_model(
            &state.config,
            &claims,
            country_code.as_deref(),
            provider,
            &model.name,
        );

        if authorize_result.is_ok() {
            accessible_models.push(rpc::LanguageModel {
                provider,
                name: model.name,
            });
        }
    }

    Ok(Json(ListModelsResponse {
        models: accessible_models,
    }))
}

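/// Proxies a completion request to the upstream provider (Anthropic, OpenAI,
/// or Google), after authorization and usage-limit checks, and streams the
/// response back wrapped in a [`TokenCountingStream`] so token usage is
/// recorded for billing and rate limiting.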
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header
            .map(|header| header.to_string())
            .as_deref(),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version, without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => {
                        tracing::info!(
                            target: "upstream rate limit exceeded",
                            user_id = claims.user_id,
                            login = claims.github_user_login,
                            authn.jti = claims.jti,
                            is_staff = claims.is_staff,
                            provider = params.provider.to_string(),
                            model = model
                        );

                        Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        )
                    }
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    is_staff = claims.is_staff,
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    ) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0, 0, 0),
                    };

                    anyhow::Ok(CompletionChunk {
                        bytes: serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    })
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens: 0,
                            output_tokens: 0,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        tokens: TokenUsage::default(),
        inner_stream: stream,
    })))
}

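/// Maps a requested model name onto the longest known model name it is
/// prefixed by, so versioned variants of a known model resolve to the server's
/// canonical entry; names with no known prefix are passed through unchanged.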
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    if let Some(known_model_name) = known_models
        .iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
    {
        known_model_name.to_string()
    } else {
        name
    }
}

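/// Handles edit-prediction requests: builds a prompt from the file outline,
/// recent input events, and the current excerpt, sends it to the configured
/// Fireworks model with a two-second timeout, and reports both timeouts and
/// completions to Kinesis. Full prompt data is only sampled for staff or for
/// users who opted in to data collection.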
async fn predict_edits(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    _country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PredictEditsParams>,
) -> Result<impl IntoResponse> {
    if !claims.is_staff && !claims.has_predict_edits_feature_flag {
        return Err(Error::http(
            StatusCode::FORBIDDEN,
            "no access to Zed's edit prediction feature".to_string(),
        ));
    }

    let should_sample = claims.is_staff || params.can_collect_data;

    let api_url = state
        .config
        .prediction_api_url
        .as_ref()
        .context("no PREDICTION_API_URL configured on the server")?;
    let api_key = state
        .config
        .prediction_api_key
        .as_ref()
        .context("no PREDICTION_API_KEY configured on the server")?;
    let model = state
        .config
        .prediction_model
        .as_ref()
        .context("no PREDICTION_MODEL configured on the server")?;

    let outline_prefix = params
        .outline
        .as_ref()
        .map(|outline| format!("### Outline for current file:\n{}\n", outline))
        .unwrap_or_default();

    let prompt = include_str!("./llm/prediction_prompt.md")
        .replace("<outline>", &outline_prefix)
        .replace("<events>", &params.input_events)
        .replace("<excerpt>", &params.input_excerpt);

    let request_start = std::time::Instant::now();
    let timeout = state
        .executor
        .sleep(std::time::Duration::from_secs(2))
        .fuse();
    let response = fireworks::complete(
        &state.http_client,
        api_url,
        api_key,
        fireworks::CompletionRequest {
            model: model.to_string(),
            prompt: prompt.clone(),
            max_tokens: MAX_OUTPUT_TOKENS,
            temperature: 0.,
            prediction: Some(fireworks::Prediction::Content {
                content: params.input_excerpt.clone(),
            }),
            rewrite_speculation: Some(true),
        },
    )
    .fuse();
    futures::pin_mut!(timeout);
    futures::pin_mut!(response);

    futures::select! {
        _ = timeout => {
            state.executor.spawn_detached({
                let kinesis_client = state.kinesis_client.clone();
                let kinesis_stream = state.config.kinesis_stream.clone();
                let model = model.clone();
                async move {
                    SnowflakeRow::new(
                        "Fireworks Completion Timeout",
                        claims.metrics_id,
                        claims.is_staff,
                        claims.system_id.clone(),
                        json!({
                            "model": model.to_string(),
                            "prompt": prompt,
                        }),
                    )
                    .write(&kinesis_client, &kinesis_stream)
                    .await
                    .log_err();
                }
            });
            Err(anyhow!("request timed out"))?
        },
        response = response => {
            let duration = request_start.elapsed();

            let mut response = response?;
            let choice = response
                .completion
                .choices
                .pop()
                .context("no output from completion response")?;

            state.executor.spawn_detached({
                let kinesis_client = state.kinesis_client.clone();
                let kinesis_stream = state.config.kinesis_stream.clone();
                let model = model.clone();
                let output = choice.text.clone();

                async move {
                    let properties = if should_sample {
                        json!({
                            "model": model.to_string(),
                            "headers": response.headers,
                            "usage": response.completion.usage,
                            "duration": duration.as_secs_f64(),
                            "prompt": prompt,
                            "input_excerpt": params.input_excerpt,
                            "input_events": params.input_events,
                            "outline": params.outline,
                            "output": output,
                            "is_sampled": true,
                        })
                    } else {
                        json!({
                            "model": model.to_string(),
                            "headers": response.headers,
                            "usage": response.completion.usage,
                            "duration": duration.as_secs_f64(),
                            "is_sampled": false,
                        })
                    };

                    SnowflakeRow::new(
                        "Fireworks Completion Requested",
                        claims.metrics_id,
                        claims.is_staff,
                        claims.system_id.clone(),
                        properties,
                    )
                    .write(&kinesis_client, &kinesis_stream)
                    .await
                    .log_err();
                }
            });

            Ok(Json(PredictEditsResponse {
                output_excerpt: choice.text,
            }))
        },
    }
}

/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);

/// The default value to use for maximum spend per month if the user did not
/// explicitly set a maximum spend.
///
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);

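/// Enforces spending and throughput limits for the requesting user (staff are
/// exempt). Spending is capped at the free-tier allowance plus, for
/// subscribers, the user's configured monthly maximum; request and token
/// throughput limits are the model's global limits divided evenly among
/// currently active users.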
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    if claims.is_staff {
        return Ok(());
    }

    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;
    let free_tier = claims.free_tier_monthly_spending_limit();

    if usage.spending_this_month >= free_tier {
        if !claims.has_llm_subscription {
            return Err(Error::http(
                StatusCode::PAYMENT_REQUIRED,
                "Maximum spending limit reached for this month.".to_string(),
            ));
        }

        if (usage.spending_this_month - free_tier) >= Cents(claims.max_monthly_spend_in_cents) {
            return Err(Error::Http(
                StatusCode::FORBIDDEN,
                "Maximum spending limit reached for this month.".to_string(),
                [(
                    HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
                    HeaderValue::from_static("true"),
                )]
                .into_iter()
                .collect(),
            ));
        }
    }

    let active_users = state.get_active_user_count(provider, model_name).await?;

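    // Share each model-wide limit evenly among active users, treating at least
    // one user as active so the division below can never be by zero.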
    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
            };

            tracing::info!(
                target: "user rate limit",
                user_id = claims.user_id,
                login = claims.github_user_login,
                authn.jti = claims.jti,
                is_staff = claims.is_staff,
                provider = provider.to_string(),
                model = model.name,
                requests_this_minute = usage.requests_this_minute,
                tokens_this_minute = usage.tokens_this_minute,
                tokens_this_day = usage.tokens_this_day,
                users_in_recent_minutes = users_in_recent_minutes,
                users_in_recent_days = users_in_recent_days,
                max_requests_per_minute = per_user_max_requests_per_minute,
                max_tokens_per_minute = per_user_max_tokens_per_minute,
                max_tokens_per_day = per_user_max_tokens_per_day,
            );

            SnowflakeRow::new(
                "Language Model Rate Limited",
                claims.metrics_id,
                claims.is_staff,
                claims.system_id.clone(),
                json!({
                    "usage": usage,
                    "users_in_recent_minutes": users_in_recent_minutes,
                    "users_in_recent_days": users_in_recent_days,
                    "max_requests_per_minute": per_user_max_requests_per_minute,
                    "max_tokens_per_minute": per_user_max_tokens_per_minute,
                    "max_tokens_per_day": per_user_max_tokens_per_day,
                    "plan": match claims.plan {
                        Plan::Free => "free".to_string(),
                        Plan::ZedPro => "zed_pro".to_string(),
                    },
                    "model": model.name.clone(),
                    "provider": provider.to_string(),
                    "usage_measure": resource.to_string(),
                }),
            )
            .write(&state.kinesis_client, &state.config.kinesis_stream)
            .await
            .log_err();

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

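/// A single chunk of a streamed completion: the raw bytes forwarded to the
/// client plus the token counts extracted from the provider's usage data.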
struct CompletionChunk {
    bytes: Vec<u8>,
    input_tokens: usize,
    output_tokens: usize,
    cache_creation_input_tokens: usize,
    cache_read_input_tokens: usize,
}

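/// Wraps a provider response stream, forwarding each chunk to the client as a
/// newline-delimited payload while accumulating its token usage. The
/// accumulated usage is recorded when the stream is dropped.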
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    tokens: TokenUsage,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok(mut chunk))) => {
                chunk.bytes.push(b'\n');
                self.tokens.input += chunk.input_tokens;
                self.tokens.output += chunk.output_tokens;
                self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
                self.tokens.input_cache_read += chunk.cache_read_input_tokens;
                Poll::Ready(Some(Ok(chunk.bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

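/// Records the accumulated token usage and emits a "Language Model Used" event
/// on drop, so usage is captured however the response ends, including when the
/// client disconnects mid-stream.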
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let tokens = self.tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    tokens,
                    claims.has_llm_subscription,
                    Cents(claims.max_monthly_spend_in_cents),
                    claims.free_tier_monthly_spending_limit(),
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                let properties = json!({
                    "has_llm_subscription": claims.has_llm_subscription,
                    "max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents,
                    "plan": match claims.plan {
                        Plan::Free => "free".to_string(),
                        Plan::ZedPro => "zed_pro".to_string(),
                    },
                    "model": model,
                    "provider": provider,
                    "usage": usage,
                    "tokens": tokens
                });
                SnowflakeRow::new(
                    "Language Model Used",
                    claims.metrics_id,
                    claims.is_staff,
                    claims.system_id.clone(),
                    properties,
                )
                .write(&state.kinesis_client, &state.config.kinesis_stream)
                .await
                .log_err();
            }
        })
    }
}

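/// Spawns a background task that, every 30 seconds, logs the active user count
/// for each known model along with application-wide usage by model.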
pub fn log_usage_periodically(state: Arc<LlmState>) {
    state.executor.clone().spawn_detached(async move {
        loop {
            state
                .executor
                .sleep(std::time::Duration::from_secs(30))
                .await;

            for provider in LanguageModelProvider::iter() {
                for model in state.db.model_names_for_provider(provider) {
                    if let Some(active_user_count) = state
                        .get_active_user_count(provider, &model)
                        .await
                        .log_err()
                    {
                        tracing::info!(
                            target: "active user counts",
                            provider = provider.to_string(),
                            model = model,
                            users_in_recent_minutes = active_user_count.users_in_recent_minutes,
                            users_in_recent_days = active_user_count.users_in_recent_days,
                        );
                    }
                }
            }

            if let Some(usages) = state
                .db
                .get_application_wide_usages_by_model(Utc::now())
                .await
                .log_err()
            {
                for usage in usages {
                    tracing::info!(
                        target: "computed usage",
                        provider = usage.provider.to_string(),
                        model = usage.model,
                        requests_this_minute = usage.requests_this_minute,
                        tokens_this_minute = usage.tokens_this_minute,
                    );
                }
            }
        }
    })
}