llm.rs

mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error,
    Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use telemetry::{report_llm_usage, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

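/// Shared state for the LLM service: configuration, the LLM database handle,
/// an HTTP client for talking to upstream providers, an optional ClickHouse
/// client for usage telemetry, and a short-lived cache of the active-user
/// count used to apportion rate limits.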
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

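/// How long a cached active-user count stays valid before it is refreshed
/// from the database.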
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
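    /// Connects to the LLM database, builds the HTTP client used to reach
    /// upstream providers, optionally constructs a ClickHouse client, and
    /// seeds the active-user-count cache.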
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count =
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count: RwLock::new(initial_active_user_count),
            config,
        };

        Ok(Arc::new(this))
    }

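    /// Returns the number of recently active users, using the cached value
    /// when it is younger than `ACTIVE_USER_COUNT_CACHE_DURATION` and
    /// refreshing it from the database otherwise.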
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

        let mut cache = self.active_user_count.write().await;
        let new_count = self.db.get_active_user_count(now).await?;
        *cache = Some((now, new_count));
        Ok(new_count)
    }
}

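/// Builds the router for the LLM service.
///
/// Every route is wrapped in [`validate_api_token`], so requests must carry a
/// valid `Authorization: Bearer <token>` header. Handlers also expect an
/// `Extension<Arc<LlmState>>` layer to be installed by the caller; a minimal
/// sketch of mounting these routes (the surrounding server setup is assumed,
/// not part of this module):
///
/// ```ignore
/// let app = axum::Router::new()
///     .merge(routes())
///     .layer(axum::Extension(state.clone()));
/// ```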
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

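/// Middleware that authenticates a request by validating the bearer token in
/// its `Authorization` header. On success the decoded [`LlmTokenClaims`] are
/// attached to the request's extensions; an expired token is rejected with
/// `401 Unauthorized` plus the `EXPIRED_LLM_TOKEN_HEADER_NAME` header so the
/// client knows to refresh it.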
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

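/// Handles `POST /completion`: authorizes the requested model for the caller,
/// enforces per-user usage limits, forwards the request body to the upstream
/// provider, and streams the response back while counting tokens.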
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

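    // Dispatch to the requested upstream provider. Each arm yields a stream
    // of (serialized chunk bytes, input tokens, output tokens) so that usage
    // can be counted uniformly below.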
    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = state
                .config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Parse the model, discard the version the client included, and
            // substitute a specific version that we control on the server.
            // For now this is the version defined in `model.id()`, but we
            // will likely want to revisit this once a new version of an
            // Anthropic model is released, so that users can use the new
            // version without having to update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => {
                    if api_error.code() == Some(anthropic::ApiErrorCode::RateLimitError) {
                        return Error::http(
                            StatusCode::TOO_MANY_REQUESTS,
                            "Upstream Anthropic rate limit exceeded.".to_string(),
                        );
                    }

                    Error::Internal(anyhow!(err))
                }
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

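    // Wrap the provider stream so token counts accumulate as chunks flow
    // through; the totals are recorded when the stream is dropped.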
    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

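/// Normalizes a model name to a known base name by trimming any version
/// suffix: the longest matching prefix wins, and an unrecognized name is
/// passed through unchanged. For example, an Anthropic model id of the form
/// `claude-3-5-sonnet-<version>` normalizes to `claude-3-5-sonnet`.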
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
    let prefixes: &[_] = match provider {
        LanguageModelProvider::Anthropic => &[
            "claude-3-5-sonnet",
            "claude-3-haiku",
            "claude-3-opus",
            "claude-3-sonnet",
        ],
        LanguageModelProvider::OpenAi => &[
            "gpt-3.5-turbo",
            "gpt-4-turbo-preview",
            "gpt-4o-mini",
            "gpt-4o",
            "gpt-4",
        ],
        LanguageModelProvider::Google => &[],
        LanguageModelProvider::Zed => &[],
    };

    if let Some(prefix) = prefixes
        .iter()
        .filter(|&&prefix| name.starts_with(prefix))
        .max_by_key(|&&prefix| prefix.len())
    {
        prefix.to_string()
    } else {
        name
    }
}

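/// Enforces per-user rate limits for the given model. Each limit is the
/// model's global limit divided evenly among recently active users; staff
/// members currently bypass the checks entirely.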
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
        .await?;

    let active_users = state.get_active_user_count().await?;

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
    let per_user_max_tokens_per_day =
        model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            "requests per minute",
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            "tokens per minute",
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            "tokens per day",
        ),
    ];

    for (usage, limit, resource) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if usage > limit {
            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

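/// Wraps a provider response stream, appending a newline after each chunk so
/// the client receives newline-delimited JSON, and accumulating input/output
/// token counts; the totals are recorded when the stream is dropped.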
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

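// Recording usage in `Drop` ensures token counts are persisted even when the
// client disconnects mid-stream: the totals are written to the database and,
// when a ClickHouse client is configured, reported as a usage event.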
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    claims.user_id as i32,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
                report_llm_usage(
                    clickhouse_client,
                    LlmUsageEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model,
                        provider: provider.to_string(),
                        input_token_count: input_token_count as u64,
                        output_token_count: output_token_count as u64,
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        input_tokens_this_month: usage.input_tokens_this_month as u64,
                        output_tokens_this_month: usage.output_tokens_this_month as u64,
                        spending_this_month: usage.spending_this_month as u64,
                    },
                )
                .await
                .log_err();
            }
        })
    }
}