llm.rs

  1mod authorization;
  2pub mod db;
  3mod telemetry;
  4mod token;
  5
  6use crate::api::events::SnowflakeRow;
  7use crate::api::CloudflareIpCountryHeader;
  8use crate::build_kinesis_client;
  9use crate::{
 10    build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result,
 11};
 12use anyhow::{anyhow, Context as _};
 13use authorization::authorize_access_to_language_model;
 14use axum::routing::get;
 15use axum::{
 16    body::Body,
 17    http::{self, HeaderName, HeaderValue, Request, StatusCode},
 18    middleware::{self, Next},
 19    response::{IntoResponse, Response},
 20    routing::post,
 21    Extension, Json, Router, TypedHeader,
 22};
 23use chrono::{DateTime, Duration, Utc};
 24use collections::HashMap;
 25use db::TokenUsage;
 26use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
 27use futures::{Stream, StreamExt as _};
 28use reqwest_client::ReqwestClient;
 29use rpc::{
 30    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
 31};
 32use rpc::{
 33    ListModelsResponse, PredictEditsParams, PredictEditsResponse,
 34    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
 35};
 36use serde_json::json;
 37use std::{
 38    pin::Pin,
 39    sync::Arc,
 40    task::{Context, Poll},
 41};
 42use strum::IntoEnumIterator;
 43use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
 44use tokio::sync::RwLock;
 45use util::ResultExt;
 46
 47pub use token::*;
 48
/// Shared state for the LLM service routes (handlers receive it as `Extension<Arc<LlmState>>`).
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    /// Client used for all outbound requests to model providers.
    pub http_client: ReqwestClient,
    /// `None` unless `kinesis_access_key` is configured and the client could be built.
    pub kinesis_client: Option<aws_sdk_kinesis::Client>,
    /// `None` unless `clickhouse_url` is configured and the client could be built.
    pub clickhouse_client: Option<clickhouse::Client>,
    /// Cache of active-user counts keyed by (provider, model name); each entry stores the
    /// time the count was fetched alongside the count itself.
    active_user_count_by_model:
        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}

/// How long an entry in `LlmState::active_user_count_by_model` is served before it is refreshed.
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
 61
 62impl LlmState {
 63    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
 64        let database_url = config
 65            .llm_database_url
 66            .as_ref()
 67            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
 68        let max_connections = config
 69            .llm_database_max_connections
 70            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
 71
 72        let mut db_options = db::ConnectOptions::new(database_url);
 73        db_options.max_connections(max_connections);
 74        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
 75        db.initialize().await?;
 76
 77        let db = Arc::new(db);
 78
 79        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 80        let http_client =
 81            ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
 82
 83        let this = Self {
 84            executor,
 85            db,
 86            http_client,
 87            kinesis_client: if config.kinesis_access_key.is_some() {
 88                build_kinesis_client(&config).await.log_err()
 89            } else {
 90                None
 91            },
 92            clickhouse_client: config
 93                .clickhouse_url
 94                .as_ref()
 95                .and_then(|_| build_clickhouse_client(&config).log_err()),
 96            active_user_count_by_model: RwLock::new(HashMap::default()),
 97            config,
 98        };
 99
100        Ok(Arc::new(this))
101    }
102
103    pub async fn get_active_user_count(
104        &self,
105        provider: LanguageModelProvider,
106        model: &str,
107    ) -> Result<ActiveUserCount> {
108        let now = Utc::now();
109
110        {
111            let active_user_count_by_model = self.active_user_count_by_model.read().await;
112            if let Some((last_updated, count)) =
113                active_user_count_by_model.get(&(provider, model.to_string()))
114            {
115                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
116                    return Ok(*count);
117                }
118            }
119        }
120
121        let mut cache = self.active_user_count_by_model.write().await;
122        let new_count = self.db.get_active_user_count(provider, model, now).await?;
123        cache.insert((provider, model.to_string()), (now, new_count));
124        Ok(new_count)
125    }
126}
127
128pub fn routes() -> Router<(), Body> {
129    Router::new()
130        .route("/models", get(list_models))
131        .route("/completion", post(perform_completion))
132        .route("/predict_edits", post(predict_edits))
133        .layer(middleware::from_fn(validate_api_token))
134}
135
136async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
137    let token = req
138        .headers()
139        .get(http::header::AUTHORIZATION)
140        .and_then(|header| header.to_str().ok())
141        .ok_or_else(|| {
142            Error::http(
143                StatusCode::BAD_REQUEST,
144                "missing authorization header".to_string(),
145            )
146        })?
147        .strip_prefix("Bearer ")
148        .ok_or_else(|| {
149            Error::http(
150                StatusCode::BAD_REQUEST,
151                "invalid authorization header".to_string(),
152            )
153        })?;
154
155    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
156    match LlmTokenClaims::validate(token, &state.config) {
157        Ok(claims) => {
158            if state.db.is_access_token_revoked(&claims.jti).await? {
159                return Err(Error::http(
160                    StatusCode::UNAUTHORIZED,
161                    "unauthorized".to_string(),
162                ));
163            }
164
165            tracing::Span::current()
166                .record("user_id", claims.user_id)
167                .record("login", claims.github_user_login.clone())
168                .record("authn.jti", &claims.jti)
169                .record("is_staff", claims.is_staff);
170
171            req.extensions_mut().insert(claims);
172            Ok::<_, Error>(next.run(req).await.into_response())
173        }
174        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
175            StatusCode::UNAUTHORIZED,
176            "unauthorized".to_string(),
177            [(
178                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
179                HeaderValue::from_static("true"),
180            )]
181            .into_iter()
182            .collect(),
183        )),
184        Err(_err) => Err(Error::http(
185            StatusCode::UNAUTHORIZED,
186            "unauthorized".to_string(),
187        )),
188    }
189}
190
191async fn list_models(
192    Extension(state): Extension<Arc<LlmState>>,
193    Extension(claims): Extension<LlmTokenClaims>,
194    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
195) -> Result<Json<ListModelsResponse>> {
196    let country_code = country_code_header.map(|header| header.to_string());
197
198    let mut accessible_models = Vec::new();
199
200    for (provider, model) in state.db.all_models() {
201        let authorize_result = authorize_access_to_language_model(
202            &state.config,
203            &claims,
204            country_code.as_deref(),
205            provider,
206            &model.name,
207        );
208
209        if authorize_result.is_ok() {
210            accessible_models.push(rpc::LanguageModel {
211                provider,
212                name: model.name,
213            });
214        }
215    }
216
217    Ok(Json(ListModelsResponse {
218        models: accessible_models,
219    }))
220}
221
/// Proxies a completion request to the provider selected by `params.provider` and streams
/// the provider's events back to the client as newline-delimited JSON.
///
/// Before contacting the provider, this function:
/// - normalizes the model name against the models known to the database;
/// - checks that the caller may access the model, taking the Cloudflare country header
///   into account;
/// - enforces the caller's spending and rate limits.
///
/// The response body is wrapped in a `TokenCountingStream`. It tallies the token counts
/// reported by the provider and records them when the stream is dropped.
///
/// # Errors
/// Fails if access is denied, a usage limit is exceeded, or the provider's API key is not
/// configured. Also fails if `params.provider_request` cannot be deserialized for that
/// provider, or if the upstream request fails.
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    // Map the requested name onto the longest known model name it starts with, so that
    // e.g. suffixed/versioned names are treated as their base model.
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header
            .map(|header| header.to_string())
            .as_deref(),
        params.provider,
        &model,
    )?;

    // Enforce spending and rate limits before any upstream request is made.
    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            // Staff requests use a separate Anthropic API key.
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version, without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            // Translate Anthropic API errors into HTTP errors for the client:
            // - rate-limit errors become 429;
            // - invalid requests become 400;
            // - overload becomes 503;
            // - any other error code becomes 500 carrying the API message;
            // - errors without a code are reported as internal errors.
            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => {
                        tracing::info!(
                            target: "upstream rate limit exceeded",
                            user_id = claims.user_id,
                            login = claims.github_user_login,
                            authn.jti = claims.jti,
                            is_staff = claims.is_staff,
                            provider = params.provider.to_string(),
                            model = model
                        );

                        Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        )
                    }
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            // Log the upstream rate-limit headroom returned with the response, when present.
            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    is_staff = claims.is_staff,
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    // Anthropic reports usage on `MessageStart` and `MessageDelta` events;
                    // every other event contributes zero tokens.
                    let (
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    ) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0, 0, 0),
                    };

                    anyhow::Ok(CompletionChunk {
                        bytes: serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    })
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // Only chunks carrying a `usage` payload contribute tokens.
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens: 0,
                            output_tokens: 0,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
    };

    // Tally reported tokens as chunks flow to the client; usage is recorded on drop.
    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        tokens: TokenUsage::default(),
        inner_stream: stream,
    })))
}
433
/// Maps `name` onto the longest entry of `known_models` that is a prefix of it.
///
/// This lets requests for suffixed/versioned model names resolve to the base model the
/// server knows about. If no known model is a prefix of `name`, `name` is returned
/// unchanged. `known_models` is consumed, so the winning entry is returned without
/// cloning it.
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    known_models
        .into_iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
        .unwrap_or(name)
}
445
446async fn predict_edits(
447    Extension(state): Extension<Arc<LlmState>>,
448    Extension(claims): Extension<LlmTokenClaims>,
449    _country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
450    Json(params): Json<PredictEditsParams>,
451) -> Result<impl IntoResponse> {
452    if !claims.is_staff {
453        return Err(anyhow!("not found"))?;
454    }
455
456    let api_url = state
457        .config
458        .prediction_api_url
459        .as_ref()
460        .context("no PREDICTION_API_URL configured on the server")?;
461    let api_key = state
462        .config
463        .prediction_api_key
464        .as_ref()
465        .context("no PREDICTION_API_KEY configured on the server")?;
466    let model = state
467        .config
468        .prediction_model
469        .as_ref()
470        .context("no PREDICTION_MODEL configured on the server")?;
471    let prompt = include_str!("./llm/prediction_prompt.md")
472        .replace("<events>", &params.input_events)
473        .replace("<excerpt>", &params.input_excerpt);
474    let mut response = open_ai::complete_text(
475        &state.http_client,
476        api_url,
477        api_key,
478        open_ai::CompletionRequest {
479            model: model.to_string(),
480            prompt: prompt.clone(),
481            max_tokens: 1024,
482            temperature: 0.,
483            prediction: Some(open_ai::Prediction::Content {
484                content: params.input_excerpt,
485            }),
486            rewrite_speculation: Some(true),
487        },
488    )
489    .await?;
490    let choice = response
491        .choices
492        .pop()
493        .context("no output from completion response")?;
494    Ok(Json(PredictEditsResponse {
495        output_excerpt: choice.text,
496    }))
497}
498
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
///
/// Note: `check_usage_limit` in this file reads the per-user free-tier limit from
/// `claims.free_tier_monthly_spending_limit()` rather than from this constant directly.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);

/// The default value to use for maximum spend per month if the user did not
/// explicitly set a maximum spend.
///
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
508
/// Rejects the request if the caller has exceeded a spending or rate limit for `model_name`.
///
/// Staff users bypass all checks. Otherwise the checks run in this order:
/// - Once this month's spending reaches the caller's free-tier limit, users without an LLM
///   subscription get `402 Payment Required`.
/// - Subscribers get `403 Forbidden`, with the max-monthly-spend-reached header, once their
///   spending beyond the free tier reaches their configured maximum monthly spend.
/// - The model's per-minute request/token limits and per-day token limit are divided
///   evenly among recently active users. If any current usage is strictly greater than its
///   per-user share, the event is logged and reported, then `429 Too Many Requests` is
///   returned. Reporting goes through Kinesis and also to ClickHouse when that client is
///   configured.
///
/// # Errors
/// Also propagates errors from looking up the model, the user's usage, or the active-user
/// count.
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    if claims.is_staff {
        return Ok(());
    }

    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;
    let free_tier = claims.free_tier_monthly_spending_limit();

    // Spending limits only apply once the free-tier allowance is used up.
    if usage.spending_this_month >= free_tier {
        if !claims.has_llm_subscription {
            return Err(Error::http(
                StatusCode::PAYMENT_REQUIRED,
                "Maximum spending limit reached for this month.".to_string(),
            ));
        }

        // Subscribers' max monthly spend is measured on top of the free tier.
        if (usage.spending_this_month - free_tier) >= Cents(claims.max_monthly_spend_in_cents) {
            return Err(Error::Http(
                StatusCode::FORBIDDEN,
                "Maximum spending limit reached for this month.".to_string(),
                [(
                    HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
                    HeaderValue::from_static("true"),
                )]
                .into_iter()
                .collect(),
            ));
        }
    }

    let active_users = state.get_active_user_count(provider, model_name).await?;

    // Clamp to at least one user so the per-user division below never divides by zero.
    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    // Each model-wide limit is shared evenly among the currently active users.
    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    // Only the first exceeded limit is reported and returned.
    for (used, limit, usage_measure) in checks {
        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
            };

            tracing::info!(
                target: "user rate limit",
                user_id = claims.user_id,
                login = claims.github_user_login,
                authn.jti = claims.jti,
                is_staff = claims.is_staff,
                provider = provider.to_string(),
                model = model.name,
                requests_this_minute = usage.requests_this_minute,
                tokens_this_minute = usage.tokens_this_minute,
                tokens_this_day = usage.tokens_this_day,
                users_in_recent_minutes = users_in_recent_minutes,
                users_in_recent_days = users_in_recent_days,
                max_requests_per_minute = per_user_max_requests_per_minute,
                max_tokens_per_minute = per_user_max_tokens_per_minute,
                max_tokens_per_day = per_user_max_tokens_per_day,
            );

            // Telemetry is best-effort: failures are logged and do not change the response.
            SnowflakeRow::new(
                "Language Model Rate Limited",
                claims.metrics_id,
                claims.is_staff,
                claims.system_id.clone(),
                json!({
                    "usage": usage,
                    "users_in_recent_minutes": users_in_recent_minutes,
                    "users_in_recent_days": users_in_recent_days,
                    "max_requests_per_minute": per_user_max_requests_per_minute,
                    "max_tokens_per_minute": per_user_max_tokens_per_minute,
                    "max_tokens_per_day": per_user_max_tokens_per_day,
                    "plan": match claims.plan {
                        Plan::Free => "free".to_string(),
                        Plan::ZedPro => "zed_pro".to_string(),
                    },
                    "model": model.name.clone(),
                    "provider": provider.to_string(),
                    "usage_measure": resource.to_string(),
                }),
            )
            .write(&state.kinesis_client, &state.config.kinesis_stream)
            .await
            .log_err();

            if let Some(client) = state.clickhouse_client.as_ref() {
                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}
670
/// One provider event, serialized to JSON, together with the token counts it reported.
///
/// The token fields are zero for events that carry no usage information.
struct CompletionChunk {
    /// Prompt tokens reported by this event.
    input_tokens: usize,
    /// Generated tokens reported by this event.
    output_tokens: usize,
    /// Prompt tokens written to the provider's prompt cache.
    cache_creation_input_tokens: usize,
    /// Prompt tokens served from the provider's prompt cache.
    cache_read_input_tokens: usize,
    /// The JSON-serialized event forwarded to the client.
    bytes: Vec<u8>,
}
678
679struct TokenCountingStream<S> {
680    state: Arc<LlmState>,
681    claims: LlmTokenClaims,
682    provider: LanguageModelProvider,
683    model: String,
684    tokens: TokenUsage,
685    inner_stream: S,
686}
687
688impl<S> Stream for TokenCountingStream<S>
689where
690    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
691{
692    type Item = Result<Vec<u8>, anyhow::Error>;
693
694    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
695        match Pin::new(&mut self.inner_stream).poll_next(cx) {
696            Poll::Ready(Some(Ok(mut chunk))) => {
697                chunk.bytes.push(b'\n');
698                self.tokens.input += chunk.input_tokens;
699                self.tokens.output += chunk.output_tokens;
700                self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
701                self.tokens.input_cache_read += chunk.cache_read_input_tokens;
702                Poll::Ready(Some(Ok(chunk.bytes)))
703            }
704            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
705            Poll::Ready(None) => Poll::Ready(None),
706            Poll::Pending => Poll::Pending,
707        }
708    }
709}
710
/// Records the accumulated token usage once the response stream is dropped.
///
/// `drop` runs whether the stream finished, errored, or was abandoned by the client, so
/// tokens from partially delivered responses are still recorded. The database write and
/// telemetry run on a detached task because `drop` cannot await.
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        // Move/copy everything the detached task needs out of `self`.
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let tokens = self.tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    tokens,
                    claims.has_llm_subscription,
                    Cents(claims.max_monthly_spend_in_cents),
                    claims.free_tier_monthly_spending_limit(),
                    Utc::now(),
                )
                .await
                .log_err();

            // If recording failed, the error has been logged and telemetry is skipped.
            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                let properties = json!({
                    "plan": match claims.plan {
                        Plan::Free => "free".to_string(),
                        Plan::ZedPro => "zed_pro".to_string(),
                    },
                    "model": model,
                    "provider": provider,
                    "usage": usage,
                    "tokens": tokens
                });
                SnowflakeRow::new(
                    "Language Model Used",
                    claims.metrics_id,
                    claims.is_staff,
                    claims.system_id.clone(),
                    properties,
                )
                .write(&state.kinesis_client, &state.config.kinesis_stream)
                .await
                .log_err();

                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                    report_llm_usage(
                        clickhouse_client,
                        LlmUsageEventRow {
                            time: Utc::now().timestamp_millis(),
                            user_id: claims.user_id as i32,
                            is_staff: claims.is_staff,
                            plan: match claims.plan {
                                Plan::Free => "free".to_string(),
                                Plan::ZedPro => "zed_pro".to_string(),
                            },
                            model,
                            provider: provider.to_string(),
                            input_token_count: tokens.input as u64,
                            cache_creation_input_token_count: tokens.input_cache_creation as u64,
                            cache_read_input_token_count: tokens.input_cache_read as u64,
                            output_token_count: tokens.output as u64,
                            requests_this_minute: usage.requests_this_minute as u64,
                            tokens_this_minute: usage.tokens_this_minute as u64,
                            tokens_this_day: usage.tokens_this_day as u64,
                            input_tokens_this_month: usage.tokens_this_month.input as u64,
                            cache_creation_input_tokens_this_month: usage
                                .tokens_this_month
                                .input_cache_creation
                                as u64,
                            cache_read_input_tokens_this_month: usage
                                .tokens_this_month
                                .input_cache_read
                                as u64,
                            output_tokens_this_month: usage.tokens_this_month.output as u64,
                            spending_this_month: usage.spending_this_month.0 as u64,
                            lifetime_spending: usage.lifetime_spending.0 as u64,
                        },
                    )
                    .await
                    .log_err();
                }
            }
        })
    }
}
808
809pub fn log_usage_periodically(state: Arc<LlmState>) {
810    state.executor.clone().spawn_detached(async move {
811        loop {
812            state
813                .executor
814                .sleep(std::time::Duration::from_secs(30))
815                .await;
816
817            for provider in LanguageModelProvider::iter() {
818                for model in state.db.model_names_for_provider(provider) {
819                    if let Some(active_user_count) = state
820                        .get_active_user_count(provider, &model)
821                        .await
822                        .log_err()
823                    {
824                        tracing::info!(
825                            target: "active user counts",
826                            provider = provider.to_string(),
827                            model = model,
828                            users_in_recent_minutes = active_user_count.users_in_recent_minutes,
829                            users_in_recent_days = active_user_count.users_in_recent_days,
830                        );
831                    }
832                }
833            }
834
835            if let Some(usages) = state
836                .db
837                .get_application_wide_usages_by_model(Utc::now())
838                .await
839                .log_err()
840            {
841                for usage in usages {
842                    tracing::info!(
843                        target: "computed usage",
844                        provider = usage.provider.to_string(),
845                        model = usage.model,
846                        requests_this_minute = usage.requests_this_minute,
847                        tokens_this_minute = usage.tokens_this_minute,
848                    );
849                }
850            }
851        }
852    })
853}