llm.rs

mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
    Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::{get, post},
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, ListModelsResponse, PerformCompletionParams,
    EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use strum::IntoEnumIterator;
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count_by_model:
        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}

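/// How long a cached active user count remains valid before it is refreshed
/// from the database.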
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count_by_model: RwLock::new(HashMap::default()),
            config,
        };

        Ok(Arc::new(this))
    }

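    /// Returns the number of users who have recently used the given model,
    /// consulting a per-model cache that is refreshed at most once every
    /// `ACTIVE_USER_COUNT_CACHE_DURATION`.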
    pub async fn get_active_user_count(
        &self,
        provider: LanguageModelProvider,
        model: &str,
    ) -> Result<ActiveUserCount> {
        let now = Utc::now();

        {
            let active_user_count_by_model = self.active_user_count_by_model.read().await;
            if let Some((last_updated, count)) =
                active_user_count_by_model.get(&(provider, model.to_string()))
            {
                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                    return Ok(*count);
                }
            }
        }

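        // Cache miss (or stale entry): refresh under the write lock. Note that
        // concurrent misses for the same model may each query the database,
        // since the cache is not re-checked after the write lock is acquired.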
        let mut cache = self.active_user_count_by_model.write().await;
        let new_count = self.db.get_active_user_count(provider, model, now).await?;
        cache.insert((provider, model.to_string()), (now, new_count));
        Ok(new_count)
    }
}

pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/models", get(list_models))
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

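/// Middleware that authenticates requests with a `Bearer` LLM token.
///
/// On success, the validated [`LlmTokenClaims`] are attached to the request as
/// an extension; expired tokens are rejected with a response header that tells
/// the client to refresh its token.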
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti)
                .record("is_staff", &claims.is_staff);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

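/// Returns the models the authenticated user may access, filtered through the
/// per-provider authorization rules (including country-based restrictions).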
async fn list_models(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
) -> Result<Json<ListModelsResponse>> {
    let country_code = country_code_header.map(|header| header.to_string());

    let mut accessible_models = Vec::new();

    for (provider, model) in state.db.all_models() {
        let authorize_result = authorize_access_to_language_model(
            &state.config,
            &claims,
            country_code.as_deref(),
            provider,
            &model.name,
        );

        if authorize_result.is_ok() {
            accessible_models.push(rpc::LanguageModel {
                provider,
                name: model.name,
            });
        }
    }

    Ok(Json(ListModelsResponse {
        models: accessible_models,
    }))
}

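/// Proxies a completion request to the upstream provider, enforcing
/// authorization and per-user usage limits, and returns the provider's
/// response as a stream of newline-delimited JSON chunks.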
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header
            .map(|header| header.to_string())
            .as_deref(),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => {
                        tracing::info!(
                            target: "upstream rate limit exceeded",
                            user_id = claims.user_id,
                            login = claims.github_user_login,
                            authn.jti = claims.jti,
                            is_staff = claims.is_staff,
                            provider = params.provider.to_string(),
                            model = model
                        );

                        Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        )
                    }
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    is_staff = claims.is_staff,
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

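    // Wrap the provider stream so token counts accumulate as chunks flow
    // through; the totals are recorded when the stream is dropped (see the
    // `Drop` impl on `TokenCountingStream`).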
    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

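/// Normalizes a client-supplied model name to the longest known model name
/// that is a prefix of it, so versioned names map onto their base model. For
/// example, a request for `claude-3-5-sonnet-20240620` would normalize to
/// `claude-3-5-sonnet` if the latter is a known model; names with no known
/// prefix pass through unchanged.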
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    if let Some(known_model_name) = known_models
        .iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
    {
        known_model_name.to_string()
    } else {
        name
    }
}

/// The maximum lifetime spending an individual user can reach before being cut off.
///
/// Represented in cents.
const LIFETIME_SPENDING_LIMIT_IN_CENTS: usize = 1_000 * 100;

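/// Rejects the request if the user has hit their lifetime spending cap or any
/// per-user rate limit. Per-user limits are derived by dividing each
/// model-wide limit by the number of recently active users, so capacity is
/// shared evenly among whoever is currently using the model.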
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;

    if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
        return Err(Error::http(
            StatusCode::FORBIDDEN,
            "Maximum spending limit reached.".to_string(),
        ));
    }

    let active_users = state.get_active_user_count(provider, model_name).await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
                _ => "",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                tracing::info!(
                    target: "user rate limit",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    provider = provider.to_string(),
                    model = model.name,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                    tokens_this_day = usage.tokens_this_day,
                    users_in_recent_minutes = users_in_recent_minutes,
                    users_in_recent_days = users_in_recent_days,
                    max_requests_per_minute = per_user_max_requests_per_minute,
                    max_tokens_per_minute = per_user_max_tokens_per_minute,
                    max_tokens_per_day = per_user_max_tokens_per_day,
                );

                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

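/// A stream adapter that forwards the provider's serialized chunks as
/// newline-delimited frames while tallying the input and output tokens that
/// pass through it.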
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

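// Usage is recorded on drop rather than on stream completion, so tokens are
// still counted if the client disconnects mid-response.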
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                    report_llm_usage(
                        clickhouse_client,
                        LlmUsageEventRow {
                            time: Utc::now().timestamp_millis(),
                            user_id: claims.user_id as i32,
                            is_staff: claims.is_staff,
                            plan: match claims.plan {
                                Plan::Free => "free".to_string(),
                                Plan::ZedPro => "zed_pro".to_string(),
                            },
                            model,
                            provider: provider.to_string(),
                            input_token_count: input_token_count as u64,
                            output_token_count: output_token_count as u64,
                            requests_this_minute: usage.requests_this_minute as u64,
                            tokens_this_minute: usage.tokens_this_minute as u64,
                            tokens_this_day: usage.tokens_this_day as u64,
                            input_tokens_this_month: usage.input_tokens_this_month as u64,
                            output_tokens_this_month: usage.output_tokens_this_month as u64,
                            spending_this_month: usage.spending_this_month as u64,
                            lifetime_spending: usage.lifetime_spending as u64,
                        },
                    )
                    .await
                    .log_err();
                }
            }
        })
    }
}

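/// Spawns a background task that periodically logs active user counts and
/// application-wide usage for every known model, for observability.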
pub fn log_usage_periodically(state: Arc<LlmState>) {
    state.executor.clone().spawn_detached(async move {
        loop {
            state
                .executor
                .sleep(std::time::Duration::from_secs(30))
                .await;

            for provider in LanguageModelProvider::iter() {
                for model in state.db.model_names_for_provider(provider) {
                    if let Some(active_user_count) = state
                        .get_active_user_count(provider, &model)
                        .await
                        .log_err()
                    {
                        tracing::info!(
                            target: "active user counts",
                            provider = provider.to_string(),
                            model = model,
                            users_in_recent_minutes = active_user_count.users_in_recent_minutes,
                            users_in_recent_days = active_user_count.users_in_recent_days,
                        );
                    }
                }
            }

            if let Some(usages) = state
                .db
                .get_application_wide_usages_by_model(Utc::now())
                .await
                .log_err()
            {
                for usage in usages {
                    tracing::info!(
                        target: "computed usage",
                        provider = usage.provider.to_string(),
                        model = usage.model,
                        requests_this_minute = usage.requests_this_minute,
                        tokens_this_minute = usage.tokens_this_minute,
                    );
                }
            }
        }
    })
}