llm.rs

mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error,
    Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use telemetry::{report_llm_usage, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

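/// Shared state for the LLM service: configuration, database handle, HTTP
/// client for upstream providers, an optional ClickHouse client for telemetry,
/// and a small in-memory cache of the active-user count.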
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

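/// How long a cached active-user count stays fresh before it is re-read from
/// the database.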
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
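    /// Constructs the shared state: connects to the LLM database, builds the
    /// HTTP client used to reach upstream providers, and primes the
    /// active-user-count cache.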
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count =
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count: RwLock::new(initial_active_user_count),
            config,
        };

        Ok(Arc::new(this))
    }

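    /// Returns the active-user count, served from the in-memory cache when it
    /// is younger than `ACTIVE_USER_COUNT_CACHE_DURATION` and re-read from the
    /// database otherwise.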
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

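        // The cached value is stale (or several tasks raced past the read
        // check); take the write lock and refresh from the database.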
        let mut cache = self.active_user_count.write().await;
        let new_count = self.db.get_active_user_count(now).await?;
        *cache = Some((now, new_count));
        Ok(new_count)
    }
}

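/// Builds the LLM router. Every route is wrapped in [`validate_api_token`],
/// so requests must carry an `Authorization: Bearer <LLM token>` header.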
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

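/// Middleware that validates the bearer token on every request. Valid claims
/// are stashed in the request extensions for downstream handlers; expired
/// tokens get a 401 with [`EXPIRED_LLM_TOKEN_HEADER_NAME`] set so clients can
/// refresh their token.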
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

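/// Handles `POST /completion`: normalizes the requested model name, authorizes
/// access, enforces per-user rate limits, proxies the request to the upstream
/// provider, and streams back newline-delimited JSON chunks while counting
/// tokens.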
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Parse the requested model, discard whatever version the client
            // included, and substitute the version the server controls (the
            // one defined in `model.id()`). We'll likely revisit this when
            // Anthropic releases a new model version, so users can adopt it
            // without updating Zed. Unrecognized model ids pass through
            // unchanged.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => {
                    if api_error.code() == Some(anthropic::ApiErrorCode::RateLimitError) {
                        return Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        );
                    }

                    Error::Internal(anyhow!(err))
                }
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

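            // Tally token usage from Anthropic's MessageStart/MessageDelta
            // events so the wrapping stream below can record it.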
            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

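    // Wrap the provider stream so token counts are accumulated as chunks flow
    // through and recorded when the response stream is dropped.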
    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

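/// Normalizes a client-supplied model name to a known model prefix, e.g. a
/// dated id like "claude-3-5-sonnet-20240620" maps to "claude-3-5-sonnet".
/// When several prefixes match, the longest wins ("gpt-4o-mini" beats
/// "gpt-4o" and "gpt-4"); unrecognized names pass through unchanged.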
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
    let prefixes: &[_] = match provider {
        LanguageModelProvider::Anthropic => &[
            "claude-3-5-sonnet",
            "claude-3-haiku",
            "claude-3-opus",
            "claude-3-sonnet",
        ],
        LanguageModelProvider::OpenAi => &[
            "gpt-3.5-turbo",
            "gpt-4-turbo-preview",
            "gpt-4o-mini",
            "gpt-4o",
            "gpt-4",
        ],
        LanguageModelProvider::Google => &[],
        LanguageModelProvider::Zed => &[],
    };

    if let Some(prefix) = prefixes
        .iter()
        .filter(|&&prefix| name.starts_with(prefix))
        .max_by_key(|&&prefix| prefix.len())
    {
        prefix.to_string()
    } else {
        name
    }
}

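/// Enforces per-user rate limits for the requested model. Each model-wide
/// limit is divided evenly among currently active users; staff members are
/// temporarily exempt.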
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
        .await?;

    let active_users = state.get_active_user_count().await?;

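    // Split each model-wide limit evenly across the active users, so per-user
    // ceilings shrink as concurrency grows; `.max(1)` guards against division
    // by zero when no users are active.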
    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
    let per_user_max_tokens_per_day =
        model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            "requests per minute",
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            "tokens per minute",
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            "tokens per day",
        ),
    ];

    for (usage, limit, resource) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if usage > limit {
            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

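/// A stream adapter that forwards newline-delimited JSON chunks to the client
/// while accumulating the token counts reported by the provider. The totals
/// are persisted when the stream is dropped.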
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
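                // Frame each serialized chunk with a trailing newline so the
                // client can split the body on line boundaries.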
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

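/// Recording usage on drop covers every way the response can end: normal
/// completion, an upstream error, or the client disconnecting mid-stream.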
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
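        // `drop` can't block or await, so persist usage (and report it to
        // ClickHouse, when configured) from a detached background task.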
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    claims.user_id as i32,
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
                report_llm_usage(
                    clickhouse_client,
                    LlmUsageEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model,
                        provider: provider.to_string(),
                        input_token_count: input_token_count as u64,
                        output_token_count: output_token_count as u64,
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        input_tokens_this_month: usage.input_tokens_this_month as u64,
                        output_tokens_this_month: usage.output_tokens_this_month as u64,
                        spending_this_month: usage.spending_this_month as u64,
                    },
                )
                .await
                .log_err();
            }
        })
    }
}