llm.rs

mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
    Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::routing::get;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use isahc_http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use strum::IntoEnumIterator;
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

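/// Shared state for the LLM service: server configuration, the LLM database,
/// an HTTP client for upstream providers, an optional ClickHouse client for
/// telemetry, and a cache of active-user counts keyed by provider and model.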
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count_by_model:
        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}

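/// How long a cached active-user count stays fresh before it is re-read from
/// the database.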
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .map(IsahcHttpClient::from)
            .context("failed to construct http client")?;

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count_by_model: RwLock::new(HashMap::default()),
            config,
        };

        Ok(Arc::new(this))
    }

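    /// Returns the active-user count for the given provider/model pair,
    /// serving a cached value when it is younger than
    /// `ACTIVE_USER_COUNT_CACHE_DURATION` and otherwise refreshing it from the
    /// database.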
    pub async fn get_active_user_count(
        &self,
        provider: LanguageModelProvider,
        model: &str,
    ) -> Result<ActiveUserCount> {
        let now = Utc::now();

        {
            let active_user_count_by_model = self.active_user_count_by_model.read().await;
            if let Some((last_updated, count)) =
                active_user_count_by_model.get(&(provider, model.to_string()))
            {
                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                    return Ok(*count);
                }
            }
        }

        let mut cache = self.active_user_count_by_model.write().await;
        let new_count = self.db.get_active_user_count(provider, model, now).await?;
        cache.insert((provider, model.to_string()), (now, new_count));
        Ok(new_count)
    }
}

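/// Builds the router for the LLM service: `GET /models` and `POST /completion`,
/// both guarded by the `validate_api_token` middleware.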
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/models", get(list_models))
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

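/// Middleware that authenticates LLM API requests: it extracts the `Bearer`
/// token from the `Authorization` header, validates the token's claims,
/// rejects revoked or expired tokens, records identifying fields on the
/// current tracing span, and attaches the claims to the request for downstream
/// handlers. Expired tokens are signalled to the client via the
/// `EXPIRED_LLM_TOKEN_HEADER_NAME` header.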
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti)
                .record("is_staff", claims.is_staff);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

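/// Lists the models the authenticated user may use, keeping only those for
/// which `authorize_access_to_language_model` succeeds given the user's claims
/// and request country.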
async fn list_models(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
) -> Result<Json<ListModelsResponse>> {
    let country_code = country_code_header.map(|header| header.to_string());

    let mut accessible_models = Vec::new();

    for (provider, model) in state.db.all_models() {
        let authorize_result = authorize_access_to_language_model(
            &state.config,
            &claims,
            country_code.as_deref(),
            provider,
            &model.name,
        );

        if authorize_result.is_ok() {
            accessible_models.push(rpc::LanguageModel {
                provider,
                name: model.name,
            });
        }
    }

    Ok(Json(ListModelsResponse {
        models: accessible_models,
    }))
}

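/// Proxies a completion request to the upstream provider (Anthropic, OpenAI,
/// or Google AI) after authorizing access to the requested model and checking
/// the user's usage limits. The upstream chunks are streamed back to the
/// client through a `TokenCountingStream`, which accumulates token usage and
/// records it once the stream is dropped.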
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header
            .map(|header| header.to_string())
            .as_deref(),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version, without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => {
                        tracing::info!(
                            target: "upstream rate limit exceeded",
                            user_id = claims.user_id,
                            login = claims.github_user_login,
                            authn.jti = claims.jti,
                            is_staff = claims.is_staff,
                            provider = params.provider.to_string(),
                            model = model
                        );

                        Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        )
                    }
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    is_staff = claims.is_staff,
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    ) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                            usage.cache_creation_input_tokens.unwrap_or(0) as usize,
                            usage.cache_read_input_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0, 0, 0),
                    };

                    anyhow::Ok(CompletionChunk {
                        bytes: serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                        cache_creation_input_tokens,
                        cache_read_input_tokens,
                    })
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        CompletionChunk {
                            bytes: serde_json::to_vec(&chunk).unwrap(),
                            input_tokens: 0,
                            output_tokens: 0,
                            cache_creation_input_tokens: 0,
                            cache_read_input_tokens: 0,
                        }
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        tokens: TokenUsage::default(),
        inner_stream: stream,
    })))
}

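/// Maps a requested model name onto the longest known model name that prefixes
/// it, so that suffixed or versioned requests (for example, a hypothetical
/// "claude-3-5-sonnet-20240620") resolve to a model the server knows about
/// ("claude-3-5-sonnet"); names with no known prefix are passed through
/// unchanged.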
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    if let Some(known_model_name) = known_models
        .iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
    {
        known_model_name.to_string()
    } else {
        name
    }
}

/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(5);

/// The default value to use for maximum spend per month if the user did not
/// explicitly set a maximum spend.
///
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);

/// The maximum lifetime spending an individual user can reach before being cut off.
const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000);

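/// Enforces spending and rate limits before a completion is attempted: the
/// free-tier and user-configured monthly spending caps (when billing is
/// enabled), the lifetime spending cap, and per-user request/token limits
/// obtained by dividing each model's limits by the number of recently active
/// users (for instance, a 1,000 requests-per-minute model limit shared by 50
/// active users allows 20 requests per minute per user). Rate-limit events are
/// logged and reported to ClickHouse when a client is configured.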
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;

    if state.config.is_llm_billing_enabled() {
        if usage.spending_this_month >= FREE_TIER_MONTHLY_SPENDING_LIMIT {
            if !claims.has_llm_subscription {
                return Err(Error::http(
                    StatusCode::PAYMENT_REQUIRED,
                    "Maximum spending limit reached for this month.".to_string(),
                ));
            }

            if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
                return Err(Error::Http(
                    StatusCode::FORBIDDEN,
                    "Maximum spending limit reached for this month.".to_string(),
                    [(
                        HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
                        HeaderValue::from_static("true"),
                    )]
                    .into_iter()
                    .collect(),
                ));
            }
        }
    }

    // TODO: Remove this once we've rolled out monthly spending limits.
    if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT {
        return Err(Error::http(
            StatusCode::FORBIDDEN,
            "Maximum spending limit reached.".to_string(),
        ));
    }

    let active_users = state.get_active_user_count(provider, model_name).await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                tracing::info!(
                    target: "user rate limit",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    provider = provider.to_string(),
                    model = model.name,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                    tokens_this_day = usage.tokens_this_day,
                    users_in_recent_minutes = users_in_recent_minutes,
                    users_in_recent_days = users_in_recent_days,
                    max_requests_per_minute = per_user_max_requests_per_minute,
                    max_tokens_per_minute = per_user_max_tokens_per_minute,
                    max_tokens_per_day = per_user_max_tokens_per_day,
                );

                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

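/// A single chunk of an upstream completion stream: the serialized bytes to
/// forward to the client plus the token counts reported with that chunk.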
struct CompletionChunk {
    bytes: Vec<u8>,
    input_tokens: usize,
    output_tokens: usize,
    cache_creation_input_tokens: usize,
    cache_read_input_tokens: usize,
}

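/// Wraps the upstream chunk stream, forwarding each chunk's bytes
/// (newline-delimited) to the client while accumulating token usage; the
/// totals are recorded and reported when the stream is dropped.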
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    tokens: TokenUsage,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok(mut chunk))) => {
                chunk.bytes.push(b'\n');
                self.tokens.input += chunk.input_tokens;
                self.tokens.output += chunk.output_tokens;
                self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
                self.tokens.input_cache_read += chunk.cache_read_input_tokens;
                Poll::Ready(Some(Ok(chunk.bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

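// When the response stream is dropped (the request has finished or the client
// disconnected), persist the accumulated token usage and, when a ClickHouse
// client is configured, report a usage event.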
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let tokens = self.tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    tokens,
                    claims.has_llm_subscription,
                    Cents(claims.max_monthly_spend_in_cents),
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                    report_llm_usage(
                        clickhouse_client,
                        LlmUsageEventRow {
                            time: Utc::now().timestamp_millis(),
                            user_id: claims.user_id as i32,
                            is_staff: claims.is_staff,
                            plan: match claims.plan {
                                Plan::Free => "free".to_string(),
                                Plan::ZedPro => "zed_pro".to_string(),
                            },
                            model,
                            provider: provider.to_string(),
                            input_token_count: tokens.input as u64,
                            cache_creation_input_token_count: tokens.input_cache_creation as u64,
                            cache_read_input_token_count: tokens.input_cache_read as u64,
                            output_token_count: tokens.output as u64,
                            requests_this_minute: usage.requests_this_minute as u64,
                            tokens_this_minute: usage.tokens_this_minute as u64,
                            tokens_this_day: usage.tokens_this_day as u64,
                            input_tokens_this_month: usage.tokens_this_month.input as u64,
                            cache_creation_input_tokens_this_month: usage
                                .tokens_this_month
                                .input_cache_creation
                                as u64,
                            cache_read_input_tokens_this_month: usage
                                .tokens_this_month
                                .input_cache_read
                                as u64,
                            output_tokens_this_month: usage.tokens_this_month.output as u64,
                            spending_this_month: usage.spending_this_month.0 as u64,
                            lifetime_spending: usage.lifetime_spending.0 as u64,
                        },
                    )
                    .await
                    .log_err();
                }
            }
        })
    }
}

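/// Spawns a background task that, every 30 seconds, logs the active-user count
/// for each known model and the application-wide usage per model.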
pub fn log_usage_periodically(state: Arc<LlmState>) {
    state.executor.clone().spawn_detached(async move {
        loop {
            state
                .executor
                .sleep(std::time::Duration::from_secs(30))
                .await;

            for provider in LanguageModelProvider::iter() {
                for model in state.db.model_names_for_provider(provider) {
                    if let Some(active_user_count) = state
                        .get_active_user_count(provider, &model)
                        .await
                        .log_err()
                    {
                        tracing::info!(
                            target: "active user counts",
                            provider = provider.to_string(),
                            model = model,
                            users_in_recent_minutes = active_user_count.users_in_recent_minutes,
                            users_in_recent_days = active_user_count.users_in_recent_days,
                        );
                    }
                }
            }

            if let Some(usages) = state
                .db
                .get_application_wide_usages_by_model(Utc::now())
                .await
                .log_err()
            {
                for usage in usages {
                    tracing::info!(
                        target: "computed usage",
                        provider = usage.provider.to_string(),
                        model = usage.model,
                        requests_this_minute = usage.requests_this_minute,
                        tokens_this_minute = usage.tokens_this_minute,
                    );
                }
            }
        }
    })
}