// llm.rs

mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
    Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::{get, post},
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, ListModelsResponse, PerformCompletionParams,
    EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use strum::IntoEnumIterator;
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

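/// Shared state for the LLM routes.
///
/// Holds the server configuration, the LLM database handle, the outbound HTTP
/// client used to reach model providers, an optional ClickHouse client for
/// telemetry, and a per-(provider, model) cache of active user counts.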
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count_by_model:
        RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}

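/// How long a cached active user count may be served before it is considered
/// stale and recomputed from the database.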
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count_by_model: RwLock::new(HashMap::default()),
            config,
        };

        Ok(Arc::new(this))
    }

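    /// Returns the number of users who have recently used the given model,
    /// serving a cached value when it is younger than
    /// `ACTIVE_USER_COUNT_CACHE_DURATION`.
    ///
    /// The read lock is dropped before the write lock is taken, so two
    /// concurrent callers may both miss the cache and query the database; the
    /// second write simply overwrites the first, which is harmless here.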
    pub async fn get_active_user_count(
        &self,
        provider: LanguageModelProvider,
        model: &str,
    ) -> Result<ActiveUserCount> {
        let now = Utc::now();

        {
            let active_user_count_by_model = self.active_user_count_by_model.read().await;
            if let Some((last_updated, count)) =
                active_user_count_by_model.get(&(provider, model.to_string()))
            {
                if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                    return Ok(*count);
                }
            }
        }

        let mut cache = self.active_user_count_by_model.write().await;
        let new_count = self.db.get_active_user_count(provider, model, now).await?;
        cache.insert((provider, model.to_string()), (now, new_count));
        Ok(new_count)
    }
}

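/// Returns the router for the LLM endpoints: `GET /models` and
/// `POST /completion`, both guarded by the `validate_api_token` middleware.
///
/// A minimal sketch of wiring it up (the surrounding server is assumed to
/// install an `Extension<Arc<LlmState>>`, which the middleware requires):
///
/// ```ignore
/// let app = routes().layer(Extension(state.clone()));
/// ```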
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/models", get(list_models))
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

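/// Middleware that authenticates every LLM request.
///
/// Extracts the bearer token from the `Authorization` header, validates its
/// claims, and rejects revoked tokens. Expired tokens are signaled with a
/// dedicated response header so clients know to refresh the token rather than
/// treat the failure as fatal.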
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti)
                .record("is_staff", claims.is_staff);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

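/// Returns the list of language models the authenticated user is allowed to
/// use, filtering the full model catalog through the same authorization check
/// that `perform_completion` applies.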
async fn list_models(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
) -> Result<Json<ListModelsResponse>> {
    let country_code = country_code_header.map(|header| header.to_string());

    let mut accessible_models = Vec::new();

    for (provider, model) in state.db.all_models() {
        let authorize_result = authorize_access_to_language_model(
            &state.config,
            &claims,
            country_code.as_deref(),
            provider,
            &model.name,
        );

        if authorize_result.is_ok() {
            accessible_models.push(rpc::LanguageModel {
                provider,
                name: model.name,
            });
        }
    }

    Ok(Json(ListModelsResponse {
        models: accessible_models,
    }))
}

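/// Proxies a completion request to the upstream provider.
///
/// The request flows through four stages: the model name is normalized
/// against the server's catalog, access is authorized, the user's rate limits
/// are checked, and the provider-specific response is streamed back to the
/// client wrapped in a `TokenCountingStream` that records usage when the
/// response finishes.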
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header
            .map(|header| header.to_string())
            .as_deref(),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version, without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => {
                        tracing::info!(
                            target: "upstream rate limit exceeded",
                            user_id = claims.user_id,
                            login = claims.github_user_login,
                            authn.jti = claims.jti,
                            is_staff = claims.is_staff,
                            provider = params.provider.to_string(),
                            model = model
                        );

                        Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        )
                    }
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    is_staff = claims.is_staff,
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

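/// Normalizes a requested model name against the names known to the server.
///
/// The longest known model name that is a prefix of `name` wins; if no known
/// name matches, the requested name is returned unchanged.
///
/// A sketch of the behavior (the model names here are illustrative):
///
/// ```ignore
/// let known = vec!["claude-3-5-sonnet".to_string(), "claude-3".to_string()];
/// assert_eq!(
///     normalize_model_name(known.clone(), "claude-3-5-sonnet-20240620".into()),
///     "claude-3-5-sonnet"
/// );
/// assert_eq!(normalize_model_name(known, "gpt-4o".into()), "gpt-4o");
/// ```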
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    if let Some(known_model_name) = known_models
        .iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
    {
        known_model_name.to_string()
    } else {
        name
    }
}

/// The maximum lifetime spending an individual user can reach before being cut off.
///
/// Represented in cents.
const LIFETIME_SPENDING_LIMIT_IN_CENTS: usize = 1_000 * 100;

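/// Enforces the lifetime spending cap and the per-user rate limits for the
/// given model.
///
/// The per-user limits are derived by dividing the model-wide limits by the
/// number of recently active users, so they tighten as concurrency grows: for
/// example, with a model-wide `max_tokens_per_minute` of 1_000_000 and 20
/// users active in recent minutes, each user may consume 50_000 tokens per
/// minute. (These figures are illustrative; the real limits live in the
/// model's database row.)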
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;

    if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
        return Err(Error::http(
            StatusCode::FORBIDDEN,
            "Maximum spending limit reached.".to_string(),
        ));
    }

    let active_users = state.get_active_user_count(provider, model_name).await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
                _ => "",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                tracing::info!(
                    target: "user rate limit",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    provider = provider.to_string(),
                    model = model.name,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                    tokens_this_day = usage.tokens_this_day,
                    users_in_recent_minutes = users_in_recent_minutes,
                    users_in_recent_days = users_in_recent_days,
                    max_requests_per_minute = per_user_max_requests_per_minute,
                    max_tokens_per_minute = per_user_max_tokens_per_minute,
                    max_tokens_per_day = per_user_max_tokens_per_day,
                );

                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

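/// Wraps a provider's response stream, forwarding each serialized chunk to the
/// client (newline-delimited) while accumulating the input and output token
/// counts reported along the way. The totals are recorded when the stream is
/// dropped.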
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

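// Dropping the stream means the response has finished (or the client went
// away), so this is the point where the accumulated token counts are
// persisted and reported to ClickHouse.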
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                    report_llm_usage(
                        clickhouse_client,
                        LlmUsageEventRow {
                            time: Utc::now().timestamp_millis(),
                            user_id: claims.user_id as i32,
                            is_staff: claims.is_staff,
                            plan: match claims.plan {
                                Plan::Free => "free".to_string(),
                                Plan::ZedPro => "zed_pro".to_string(),
                            },
                            model,
                            provider: provider.to_string(),
                            input_token_count: input_token_count as u64,
                            output_token_count: output_token_count as u64,
                            requests_this_minute: usage.requests_this_minute as u64,
                            tokens_this_minute: usage.tokens_this_minute as u64,
                            tokens_this_day: usage.tokens_this_day as u64,
                            input_tokens_this_month: usage.input_tokens_this_month as u64,
                            output_tokens_this_month: usage.output_tokens_this_month as u64,
                            spending_this_month: usage.spending_this_month as u64,
                            lifetime_spending: usage.lifetime_spending as u64,
                        },
                    )
                    .await
                    .log_err();
                }
            }
        })
    }
}

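/// Spawns a detached background task that, every 30 seconds, logs the active
/// user counts per model and the application-wide usage so they can be
/// scraped from the server's logs.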
pub fn log_usage_periodically(state: Arc<LlmState>) {
    state.executor.clone().spawn_detached(async move {
        loop {
            state
                .executor
                .sleep(std::time::Duration::from_secs(30))
                .await;

            for provider in LanguageModelProvider::iter() {
                for model in state.db.model_names_for_provider(provider) {
                    if let Some(active_user_count) = state
                        .get_active_user_count(provider, &model)
                        .await
                        .log_err()
                    {
                        tracing::info!(
                            target: "active user counts",
                            provider = provider.to_string(),
                            model = model,
                            users_in_recent_minutes = active_user_count.users_in_recent_minutes,
                            users_in_recent_days = active_user_count.users_in_recent_days,
                        );
                    }
                }
            }

            if let Some(usages) = state
                .db
                .get_application_wide_usages_by_model(Utc::now())
                .await
                .log_err()
            {
                for usage in usages {
                    tracing::info!(
                        target: "computed usage",
                        provider = usage.provider.to_string(),
                        model = usage.model,
                        requests_this_minute = usage.requests_this_minute,
                        tokens_this_minute = usage.tokens_this_minute,
                    );
                }
            }
        }
    })
}