llm.rs

  1mod authorization;
  2pub mod db;
  3mod telemetry;
  4mod token;
  5
  6use crate::{
  7    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
  8    Config, Error, Result,
  9};
 10use anyhow::{anyhow, Context as _};
 11use authorization::authorize_access_to_language_model;
 12use axum::{
 13    body::Body,
 14    http::{self, HeaderName, HeaderValue, Request, StatusCode},
 15    middleware::{self, Next},
 16    response::{IntoResponse, Response},
 17    routing::post,
 18    Extension, Json, Router, TypedHeader,
 19};
 20use chrono::{DateTime, Duration, Utc};
 21use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
 22use futures::{Stream, StreamExt as _};
 23use http_client::IsahcHttpClient;
 24use rpc::{
 25    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
 26};
 27use std::{
 28    pin::Pin,
 29    sync::Arc,
 30    task::{Context, Poll},
 31};
 32use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
 33use tokio::sync::RwLock;
 34use util::ResultExt;
 35
 36pub use token::*;
 37
 38pub struct LlmState {
 39    pub config: Config,
 40    pub executor: Executor,
 41    pub db: Arc<LlmDatabase>,
 42    pub http_client: IsahcHttpClient,
 43    pub clickhouse_client: Option<clickhouse::Client>,
 44    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
 45}
 46
 47const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
 48
 49impl LlmState {
 50    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
 51        let database_url = config
 52            .llm_database_url
 53            .as_ref()
 54            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
 55        let max_connections = config
 56            .llm_database_max_connections
 57            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
 58
 59        let mut db_options = db::ConnectOptions::new(database_url);
 60        db_options.max_connections(max_connections);
 61        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
 62        db.initialize().await?;
 63
 64        let db = Arc::new(db);
 65
 66        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 67        let http_client = IsahcHttpClient::builder()
 68            .default_header("User-Agent", user_agent)
 69            .build()
 70            .context("failed to construct http client")?;
 71
 72        let initial_active_user_count =
 73            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
 74
 75        let this = Self {
 76            executor,
 77            db,
 78            http_client,
 79            clickhouse_client: config
 80                .clickhouse_url
 81                .as_ref()
 82                .and_then(|_| build_clickhouse_client(&config).log_err()),
 83            active_user_count: RwLock::new(initial_active_user_count),
 84            config,
 85        };
 86
 87        Ok(Arc::new(this))
 88    }
 89
 90    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
 91        let now = Utc::now();
 92
 93        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
 94            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
 95                return Ok(*count);
 96            }
 97        }
 98
 99        let mut cache = self.active_user_count.write().await;
100        let new_count = self.db.get_active_user_count(now).await?;
101        *cache = Some((now, new_count));
102        Ok(new_count)
103    }
104}
105
106pub fn routes() -> Router<(), Body> {
107    Router::new()
108        .route("/completion", post(perform_completion))
109        .layer(middleware::from_fn(validate_api_token))
110}
111
112async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
113    let token = req
114        .headers()
115        .get(http::header::AUTHORIZATION)
116        .and_then(|header| header.to_str().ok())
117        .ok_or_else(|| {
118            Error::http(
119                StatusCode::BAD_REQUEST,
120                "missing authorization header".to_string(),
121            )
122        })?
123        .strip_prefix("Bearer ")
124        .ok_or_else(|| {
125            Error::http(
126                StatusCode::BAD_REQUEST,
127                "invalid authorization header".to_string(),
128            )
129        })?;
130
131    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
132    match LlmTokenClaims::validate(&token, &state.config) {
133        Ok(claims) => {
134            if state.db.is_access_token_revoked(&claims.jti).await? {
135                return Err(Error::http(
136                    StatusCode::UNAUTHORIZED,
137                    "unauthorized".to_string(),
138                ));
139            }
140
141            tracing::Span::current().record("authn.jti", &claims.jti);
142
143            req.extensions_mut().insert(claims);
144            Ok::<_, Error>(next.run(req).await.into_response())
145        }
146        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
147            StatusCode::UNAUTHORIZED,
148            "unauthorized".to_string(),
149            [(
150                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
151                HeaderValue::from_static("true"),
152            )]
153            .into_iter()
154            .collect(),
155        )),
156        Err(_err) => Err(Error::http(
157            StatusCode::UNAUTHORIZED,
158            "unauthorized".to_string(),
159        )),
160    }
161}
162
163async fn perform_completion(
164    Extension(state): Extension<Arc<LlmState>>,
165    Extension(claims): Extension<LlmTokenClaims>,
166    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
167    Json(params): Json<PerformCompletionParams>,
168) -> Result<impl IntoResponse> {
169    let model = normalize_model_name(params.provider, params.model);
170
171    authorize_access_to_language_model(
172        &state.config,
173        &claims,
174        country_code_header.map(|header| header.to_string()),
175        params.provider,
176        &model,
177    )?;
178
179    check_usage_limit(&state, params.provider, &model, &claims).await?;
180
181    let stream = match params.provider {
182        LanguageModelProvider::Anthropic => {
183            let api_key = if claims.is_staff {
184                state
185                    .config
186                    .anthropic_staff_api_key
187                    .as_ref()
188                    .context("no Anthropic AI staff API key configured on the server")?
189            } else {
190                state
191                    .config
192                    .anthropic_api_key
193                    .as_ref()
194                    .context("no Anthropic AI API key configured on the server")?
195            };
196
197            let mut request: anthropic::Request =
198                serde_json::from_str(&params.provider_request.get())?;
199
200            // Parse the model, throw away the version that was included, and then set a specific
201            // version that we control on the server.
202            // Right now, we use the version that's defined in `model.id()`, but we will likely
203            // want to change this code once a new version of an Anthropic model is released,
204            // so that users can use the new version, without having to update Zed.
205            request.model = match anthropic::Model::from_id(&request.model) {
206                Ok(model) => model.id().to_string(),
207                Err(_) => request.model,
208            };
209
210            let chunks = anthropic::stream_completion(
211                &state.http_client,
212                anthropic::ANTHROPIC_API_URL,
213                api_key,
214                request,
215                None,
216            )
217            .await
218            .map_err(|err| match err {
219                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
220                    Some(anthropic::ApiErrorCode::RateLimitError) => Error::http(
221                        StatusCode::TOO_MANY_REQUESTS,
222                        "Upstream Anthropic rate limit exceeded.".to_string(),
223                    ),
224                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
225                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
226                    }
227                    Some(anthropic::ApiErrorCode::OverloadedError) => {
228                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
229                    }
230                    Some(_) => {
231                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
232                    }
233                    None => Error::Internal(anyhow!(err)),
234                },
235                anthropic::AnthropicError::Other(err) => Error::Internal(err),
236            })?;
237
238            chunks
239                .map(move |event| {
240                    let chunk = event?;
241                    let (input_tokens, output_tokens) = match &chunk {
242                        anthropic::Event::MessageStart {
243                            message: anthropic::Response { usage, .. },
244                        }
245                        | anthropic::Event::MessageDelta { usage, .. } => (
246                            usage.input_tokens.unwrap_or(0) as usize,
247                            usage.output_tokens.unwrap_or(0) as usize,
248                        ),
249                        _ => (0, 0),
250                    };
251
252                    anyhow::Ok((
253                        serde_json::to_vec(&chunk).unwrap(),
254                        input_tokens,
255                        output_tokens,
256                    ))
257                })
258                .boxed()
259        }
260        LanguageModelProvider::OpenAi => {
261            let api_key = state
262                .config
263                .openai_api_key
264                .as_ref()
265                .context("no OpenAI API key configured on the server")?;
266            let chunks = open_ai::stream_completion(
267                &state.http_client,
268                open_ai::OPEN_AI_API_URL,
269                api_key,
270                serde_json::from_str(&params.provider_request.get())?,
271                None,
272            )
273            .await?;
274
275            chunks
276                .map(|event| {
277                    event.map(|chunk| {
278                        let input_tokens =
279                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
280                        let output_tokens =
281                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
282                        (
283                            serde_json::to_vec(&chunk).unwrap(),
284                            input_tokens,
285                            output_tokens,
286                        )
287                    })
288                })
289                .boxed()
290        }
291        LanguageModelProvider::Google => {
292            let api_key = state
293                .config
294                .google_ai_api_key
295                .as_ref()
296                .context("no Google AI API key configured on the server")?;
297            let chunks = google_ai::stream_generate_content(
298                &state.http_client,
299                google_ai::API_URL,
300                api_key,
301                serde_json::from_str(&params.provider_request.get())?,
302            )
303            .await?;
304
305            chunks
306                .map(|event| {
307                    event.map(|chunk| {
308                        // TODO - implement token counting for Google AI
309                        let input_tokens = 0;
310                        let output_tokens = 0;
311                        (
312                            serde_json::to_vec(&chunk).unwrap(),
313                            input_tokens,
314                            output_tokens,
315                        )
316                    })
317                })
318                .boxed()
319        }
320        LanguageModelProvider::Zed => {
321            let api_key = state
322                .config
323                .qwen2_7b_api_key
324                .as_ref()
325                .context("no Qwen2-7B API key configured on the server")?;
326            let api_url = state
327                .config
328                .qwen2_7b_api_url
329                .as_ref()
330                .context("no Qwen2-7B URL configured on the server")?;
331            let chunks = open_ai::stream_completion(
332                &state.http_client,
333                &api_url,
334                api_key,
335                serde_json::from_str(&params.provider_request.get())?,
336                None,
337            )
338            .await?;
339
340            chunks
341                .map(|event| {
342                    event.map(|chunk| {
343                        let input_tokens =
344                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
345                        let output_tokens =
346                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
347                        (
348                            serde_json::to_vec(&chunk).unwrap(),
349                            input_tokens,
350                            output_tokens,
351                        )
352                    })
353                })
354                .boxed()
355        }
356    };
357
358    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
359        state,
360        claims,
361        provider: params.provider,
362        model,
363        input_tokens: 0,
364        output_tokens: 0,
365        inner_stream: stream,
366    })))
367}
368
369fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
370    let prefixes: &[_] = match provider {
371        LanguageModelProvider::Anthropic => &[
372            "claude-3-5-sonnet",
373            "claude-3-haiku",
374            "claude-3-opus",
375            "claude-3-sonnet",
376        ],
377        LanguageModelProvider::OpenAi => &[
378            "gpt-3.5-turbo",
379            "gpt-4-turbo-preview",
380            "gpt-4o-mini",
381            "gpt-4o",
382            "gpt-4",
383        ],
384        LanguageModelProvider::Google => &[],
385        LanguageModelProvider::Zed => &[],
386    };
387
388    if let Some(prefix) = prefixes
389        .iter()
390        .filter(|&&prefix| name.starts_with(prefix))
391        .max_by_key(|&&prefix| prefix.len())
392    {
393        prefix.to_string()
394    } else {
395        name
396    }
397}
398
399async fn check_usage_limit(
400    state: &Arc<LlmState>,
401    provider: LanguageModelProvider,
402    model_name: &str,
403    claims: &LlmTokenClaims,
404) -> Result<()> {
405    let model = state.db.model(provider, model_name)?;
406    let usage = state
407        .db
408        .get_usage(
409            UserId::from_proto(claims.user_id),
410            provider,
411            model_name,
412            Utc::now(),
413        )
414        .await?;
415
416    let active_users = state.get_active_user_count().await?;
417
418    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
419    let users_in_recent_days = active_users.users_in_recent_days.max(1);
420
421    let per_user_max_requests_per_minute =
422        model.max_requests_per_minute as usize / users_in_recent_minutes;
423    let per_user_max_tokens_per_minute =
424        model.max_tokens_per_minute as usize / users_in_recent_minutes;
425    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
426
427    let checks = [
428        (
429            usage.requests_this_minute,
430            per_user_max_requests_per_minute,
431            UsageMeasure::RequestsPerMinute,
432        ),
433        (
434            usage.tokens_this_minute,
435            per_user_max_tokens_per_minute,
436            UsageMeasure::TokensPerMinute,
437        ),
438        (
439            usage.tokens_this_day,
440            per_user_max_tokens_per_day,
441            UsageMeasure::TokensPerDay,
442        ),
443    ];
444
445    for (used, limit, usage_measure) in checks {
446        // Temporarily bypass rate-limiting for staff members.
447        if claims.is_staff {
448            continue;
449        }
450
451        if used > limit {
452            let resource = match usage_measure {
453                UsageMeasure::RequestsPerMinute => "requests_per_minute",
454                UsageMeasure::TokensPerMinute => "tokens_per_minute",
455                UsageMeasure::TokensPerDay => "tokens_per_day",
456                _ => "",
457            };
458
459            if let Some(client) = state.clickhouse_client.as_ref() {
460                report_llm_rate_limit(
461                    client,
462                    LlmRateLimitEventRow {
463                        time: Utc::now().timestamp_millis(),
464                        user_id: claims.user_id as i32,
465                        is_staff: claims.is_staff,
466                        plan: match claims.plan {
467                            Plan::Free => "free".to_string(),
468                            Plan::ZedPro => "zed_pro".to_string(),
469                        },
470                        model: model.name.clone(),
471                        provider: provider.to_string(),
472                        usage_measure: resource.to_string(),
473                        requests_this_minute: usage.requests_this_minute as u64,
474                        tokens_this_minute: usage.tokens_this_minute as u64,
475                        tokens_this_day: usage.tokens_this_day as u64,
476                        users_in_recent_minutes: users_in_recent_minutes as u64,
477                        users_in_recent_days: users_in_recent_days as u64,
478                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
479                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
480                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
481                    },
482                )
483                .await
484                .log_err();
485            }
486
487            return Err(Error::http(
488                StatusCode::TOO_MANY_REQUESTS,
489                format!("Rate limit exceeded. Maximum {} reached.", resource),
490            ));
491        }
492    }
493
494    Ok(())
495}
496
497struct TokenCountingStream<S> {
498    state: Arc<LlmState>,
499    claims: LlmTokenClaims,
500    provider: LanguageModelProvider,
501    model: String,
502    input_tokens: usize,
503    output_tokens: usize,
504    inner_stream: S,
505}
506
507impl<S> Stream for TokenCountingStream<S>
508where
509    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
510{
511    type Item = Result<Vec<u8>, anyhow::Error>;
512
513    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
514        match Pin::new(&mut self.inner_stream).poll_next(cx) {
515            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
516                bytes.push(b'\n');
517                self.input_tokens += input_tokens;
518                self.output_tokens += output_tokens;
519                Poll::Ready(Some(Ok(bytes)))
520            }
521            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
522            Poll::Ready(None) => Poll::Ready(None),
523            Poll::Pending => Poll::Pending,
524        }
525    }
526}
527
528impl<S> Drop for TokenCountingStream<S> {
529    fn drop(&mut self) {
530        let state = self.state.clone();
531        let claims = self.claims.clone();
532        let provider = self.provider;
533        let model = std::mem::take(&mut self.model);
534        let input_token_count = self.input_tokens;
535        let output_token_count = self.output_tokens;
536        self.state.executor.spawn_detached(async move {
537            let usage = state
538                .db
539                .record_usage(
540                    UserId::from_proto(claims.user_id),
541                    claims.is_staff,
542                    provider,
543                    &model,
544                    input_token_count,
545                    output_token_count,
546                    Utc::now(),
547                )
548                .await
549                .log_err();
550
551            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
552                report_llm_usage(
553                    clickhouse_client,
554                    LlmUsageEventRow {
555                        time: Utc::now().timestamp_millis(),
556                        user_id: claims.user_id as i32,
557                        is_staff: claims.is_staff,
558                        plan: match claims.plan {
559                            Plan::Free => "free".to_string(),
560                            Plan::ZedPro => "zed_pro".to_string(),
561                        },
562                        model,
563                        provider: provider.to_string(),
564                        input_token_count: input_token_count as u64,
565                        output_token_count: output_token_count as u64,
566                        requests_this_minute: usage.requests_this_minute as u64,
567                        tokens_this_minute: usage.tokens_this_minute as u64,
568                        tokens_this_day: usage.tokens_this_day as u64,
569                        input_tokens_this_month: usage.input_tokens_this_month as u64,
570                        output_tokens_this_month: usage.output_tokens_this_month as u64,
571                        spending_this_month: usage.spending_this_month as u64,
572                        lifetime_spending: usage.lifetime_spending as u64,
573                    },
574                )
575                .await
576                .log_err();
577            }
578        })
579    }
580}