llm.rs

mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
    Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;
/// Shared state for the LLM API handlers: configuration, the LLM database, an
/// HTTP client for upstream providers, an optional ClickHouse client for
/// telemetry, and a cached active-user count used to derive per-user rate
/// limits.
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

/// How long a cached active-user count stays fresh before it is re-read from
/// the database.
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count =
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count: RwLock::new(initial_active_user_count),
            config,
        };

        Ok(Arc::new(this))
    }

    /// Returns the active-user count, refreshing the cached value from the
    /// database once it is older than `ACTIVE_USER_COUNT_CACHE_DURATION`.
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

        let mut cache = self.active_user_count.write().await;
        let new_count = self.db.get_active_user_count(now).await?;
        *cache = Some((now, new_count));
        Ok(new_count)
    }
}

pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}
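
// A minimal sketch of how a caller might wire these routes up; this is
// illustrative only, as the real wiring lives elsewhere in the server.
// `validate_api_token` reads an `Extension(Arc<LlmState>)` from the request,
// so the caller is expected to install that layer, e.g.:
//
//     let app = Router::new()
//         .merge(routes())
//         .layer(Extension(state.clone()));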

async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti)
                .record("is_staff", &claims.is_staff);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        // An expired token gets a dedicated response header so the client
        // knows to refresh its LLM token and retry.
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}
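
// For reference, a request accepted by this middleware carries the token in
// standard bearer form (token value illustrative):
//
//     Authorization: Bearer eyJhbGciOiJIUzI1NiIs...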

async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

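    // Each provider arm below yields a stream of
    // `(serialized chunk bytes, input token count, output token count)`
    // triples; the `TokenCountingStream` constructed at the end of this
    // function sums the counts and records usage when the stream is dropped.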
    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => Error::http(
                        StatusCode::TOO_MANY_REQUESTS,
                        "Upstream Anthropic rate limit exceeded.".to_string(),
                    ),
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

/// Maps a client-supplied model name onto the longest known model name that
/// prefixes it, so that dated or versioned variants resolve to the model the
/// server tracks; names matching no known model pass through unchanged.
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    if let Some(known_model_name) = known_models
        .iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
    {
        known_model_name.to_string()
    } else {
        name
    }
}
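
// A worked example with a hypothetical known-model list: given
// `["claude-3", "claude-3-5-sonnet"]`, the name "claude-3-5-sonnet-20240620"
// has both entries as prefixes, so the longer one wins and the name
// normalizes to "claude-3-5-sonnet". A name with no known prefix is returned
// unchanged.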

async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;

    let active_users = state.get_active_user_count().await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    // Per-user limits are derived by splitting each model's application-wide
    // limit evenly across the currently active users.
    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

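    // An illustrative calculation (numbers hypothetical): if a model allows
    // 1_000_000 tokens per minute overall and 200 users were active in recent
    // minutes, each user gets 1_000_000 / 200 = 5_000 tokens per minute
    // before the corresponding check below rejects the request.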
    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
                _ => "",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

/// Wraps a provider stream of `(bytes, input tokens, output tokens)` triples,
/// forwarding the bytes to the client while accumulating token counts so that
/// usage can be recorded when the stream is dropped.
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                // Terminate each serialized chunk with a newline so the client
                // can frame the response body as newline-delimited JSON.
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
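
// On the wire the response body is thus a sequence of newline-delimited JSON
// values, one per upstream event. The exact shape depends on the provider;
// for Anthropic it looks roughly like this (fields illustrative):
//
//     {"type":"message_start",...}
//     {"type":"content_block_delta",...}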

impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        // Record usage when the stream is dropped, covering both streams that
        // ran to completion and streams the client abandoned partway through.
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                    report_llm_usage(
                        clickhouse_client,
                        LlmUsageEventRow {
                            time: Utc::now().timestamp_millis(),
                            user_id: claims.user_id as i32,
                            is_staff: claims.is_staff,
                            plan: match claims.plan {
                                Plan::Free => "free".to_string(),
                                Plan::ZedPro => "zed_pro".to_string(),
                            },
                            model,
                            provider: provider.to_string(),
                            input_token_count: input_token_count as u64,
                            output_token_count: output_token_count as u64,
                            requests_this_minute: usage.requests_this_minute as u64,
                            tokens_this_minute: usage.tokens_this_minute as u64,
                            tokens_this_day: usage.tokens_this_day as u64,
                            input_tokens_this_month: usage.input_tokens_this_month as u64,
                            output_tokens_this_month: usage.output_tokens_this_month as u64,
                            spending_this_month: usage.spending_this_month as u64,
                            lifetime_spending: usage.lifetime_spending as u64,
                        },
                    )
                    .await
                    .log_err();
                }
            }
        })
    }
}

pub fn log_usage_periodically(state: Arc<LlmState>) {
    state.executor.clone().spawn_detached(async move {
        loop {
            state
                .executor
                .sleep(std::time::Duration::from_secs(30))
                .await;

            let Some(usages) = state
                .db
                .get_application_wide_usages_by_model(Utc::now())
                .await
                .log_err()
            else {
                continue;
            };

            for usage in usages {
                tracing::info!(
                    target: "computed usage",
                    provider = usage.provider.to_string(),
                    model = usage.model,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );
            }
        }
    })
}