llm.rs

mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
    Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

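/// Shared state for the LLM service, constructed once at startup and made
/// available to request handlers via an axum `Extension`.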
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

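/// How long a cached active-user count stays fresh before
/// `get_active_user_count` queries the database again.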
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count =
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count: RwLock::new(initial_active_user_count),
            config,
        };

        Ok(Arc::new(this))
    }

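    /// Returns the number of recently active users, serving a cached value
    /// when it is newer than `ACTIVE_USER_COUNT_CACHE_DURATION`. Under
    /// contention, concurrent callers that miss the cache may each refresh it;
    /// the last write wins, which is harmless since every refresh stores a
    /// freshly queried count.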
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

        let mut cache = self.active_user_count.write().await;
        let new_count = self.db.get_active_user_count(now).await?;
        *cache = Some((now, new_count));
        Ok(new_count)
    }
}

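/// Builds the router for the LLM service. The `/completion` route is wrapped
/// in the `validate_api_token` middleware, so the handler can assume a
/// validated `LlmTokenClaims` extension is present on every request.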
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

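/// Middleware that validates the `Authorization: Bearer <token>` header,
/// rejects revoked or invalid tokens, and attaches the decoded claims to the
/// request. Expired tokens are flagged with a dedicated response header
/// (`EXPIRED_LLM_TOKEN_HEADER_NAME`) so clients can distinguish expiry from
/// other authentication failures.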
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(&token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

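/// Handles `POST /completion`: authorizes the requested model for the caller,
/// enforces usage limits, proxies the request to the upstream provider, and
/// streams the response back as newline-delimited JSON while counting tokens.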
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;
            // Parse the model, discard the version the client included, and substitute a
            // version we control on the server. For now we use the version defined in
            // `model.id()`, but we will likely want to change this once a new version of
            // an Anthropic model is released, so that users can pick up the new version
            // without having to update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => Error::http(
                        StatusCode::TOO_MANY_REQUESTS,
                        "Upstream Anthropic rate limit exceeded.".to_string(),
                    ),
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                &api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

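    // Wrap the provider stream so token counts accumulate as chunks flow
    // through, and usage is recorded when the stream is dropped (see
    // `TokenCountingStream` below).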
    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

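/// Normalizes versioned model names to a known prefix so that usage is tracked
/// and limited per model family rather than per dated release. The longest
/// matching prefix wins; unrecognized names pass through unchanged. For
/// example, "claude-3-5-sonnet-20240620" normalizes to "claude-3-5-sonnet",
/// and "gpt-4o-mini" matches "gpt-4o-mini" rather than the shorter "gpt-4o".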
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
    let prefixes: &[_] = match provider {
        LanguageModelProvider::Anthropic => &[
            "claude-3-5-sonnet",
            "claude-3-haiku",
            "claude-3-opus",
            "claude-3-sonnet",
        ],
        LanguageModelProvider::OpenAi => &[
            "gpt-3.5-turbo",
            "gpt-4-turbo-preview",
            "gpt-4o-mini",
            "gpt-4o",
            "gpt-4",
        ],
        LanguageModelProvider::Google => &[],
        LanguageModelProvider::Zed => &[],
    };

    if let Some(prefix) = prefixes
        .iter()
        .filter(|&&prefix| name.starts_with(prefix))
        .max_by_key(|&&prefix| prefix.len())
    {
        prefix.to_string()
    } else {
        name
    }
}

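/// Rejects the request with `429 Too Many Requests` if the user has exceeded
/// their share of the model's rate limits. Each model-wide limit is divided
/// evenly among recently active users, so the per-user budget shrinks as load
/// grows: for example, a model allowing 600 requests per minute shared by 20
/// recently active users yields a per-user limit of 30 requests per minute.
/// The divisor is clamped to at least 1 to avoid division by zero.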
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;

    let active_users = state.get_active_user_count().await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
                _ => "",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

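/// A stream adapter that forwards chunks from the upstream provider while
/// accumulating the token counts reported in each chunk. Usage is persisted
/// and reported from `Drop`, so it is recorded whether the stream completes
/// normally or the client disconnects mid-response.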
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

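// Each forwarded chunk gets a trailing newline, so the response body is
// newline-delimited JSON that clients can parse incrementally.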
impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

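// Record usage on drop rather than on stream completion, so token counts are
// persisted (and reported to ClickHouse) even when the client disconnects
// before the response finishes. The work is spawned on the executor because
// `drop` cannot be async.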
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
                report_llm_usage(
                    clickhouse_client,
                    LlmUsageEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model,
                        provider: provider.to_string(),
                        input_token_count: input_token_count as u64,
                        output_token_count: output_token_count as u64,
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        input_tokens_this_month: usage.input_tokens_this_month as u64,
                        output_tokens_this_month: usage.output_tokens_this_month as u64,
                        spending_this_month: usage.spending_this_month as u64,
                        lifetime_spending: usage.lifetime_spending as u64,
                    },
                )
                .await
                .log_err();
            }
        })
    }
}