llm.rs

  1mod authorization;
  2pub mod db;
  3mod token;
  4
  5use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
  6use anyhow::{anyhow, Context as _};
  7use authorization::authorize_access_to_language_model;
  8use axum::{
  9    body::Body,
 10    http::{self, HeaderName, HeaderValue, Request, StatusCode},
 11    middleware::{self, Next},
 12    response::{IntoResponse, Response},
 13    routing::post,
 14    Extension, Json, Router, TypedHeader,
 15};
 16use chrono::{DateTime, Duration, Utc};
 17use db::{ActiveUserCount, LlmDatabase};
 18use futures::{Stream, StreamExt as _};
 19use http_client::IsahcHttpClient;
 20use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 21use std::{
 22    pin::Pin,
 23    sync::Arc,
 24    task::{Context, Poll},
 25};
 26use tokio::sync::RwLock;
 27use util::ResultExt;
 28
 29pub use token::*;
 30
/// Shared state for the LLM API endpoints.
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    // Cached active-user count paired with the time it was computed.
    // Refreshed lazily by `get_active_user_count` once it is older than
    // `ACTIVE_USER_COUNT_CACHE_DURATION`.
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
 38
/// How long a cached active-user count stays fresh before
/// `LlmState::get_active_user_count` queries the database again.
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
 40
 41impl LlmState {
 42    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
 43        let database_url = config
 44            .llm_database_url
 45            .as_ref()
 46            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
 47        let max_connections = config
 48            .llm_database_max_connections
 49            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
 50
 51        let mut db_options = db::ConnectOptions::new(database_url);
 52        db_options.max_connections(max_connections);
 53        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
 54        db.initialize().await?;
 55
 56        let db = Arc::new(db);
 57
 58        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 59        let http_client = IsahcHttpClient::builder()
 60            .default_header("User-Agent", user_agent)
 61            .build()
 62            .context("failed to construct http client")?;
 63
 64        let initial_active_user_count =
 65            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
 66
 67        let this = Self {
 68            config,
 69            executor,
 70            db,
 71            http_client,
 72            active_user_count: RwLock::new(initial_active_user_count),
 73        };
 74
 75        Ok(Arc::new(this))
 76    }
 77
 78    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
 79        let now = Utc::now();
 80
 81        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
 82            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
 83                return Ok(*count);
 84            }
 85        }
 86
 87        let mut cache = self.active_user_count.write().await;
 88        let new_count = self.db.get_active_user_count(now).await?;
 89        *cache = Some((now, new_count));
 90        Ok(new_count)
 91    }
 92}
 93
 94pub fn routes() -> Router<(), Body> {
 95    Router::new()
 96        .route("/completion", post(perform_completion))
 97        .layer(middleware::from_fn(validate_api_token))
 98}
 99
100async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
101    let token = req
102        .headers()
103        .get(http::header::AUTHORIZATION)
104        .and_then(|header| header.to_str().ok())
105        .ok_or_else(|| {
106            Error::http(
107                StatusCode::BAD_REQUEST,
108                "missing authorization header".to_string(),
109            )
110        })?
111        .strip_prefix("Bearer ")
112        .ok_or_else(|| {
113            Error::http(
114                StatusCode::BAD_REQUEST,
115                "invalid authorization header".to_string(),
116            )
117        })?;
118
119    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
120    match LlmTokenClaims::validate(&token, &state.config) {
121        Ok(claims) => {
122            req.extensions_mut().insert(claims);
123            Ok::<_, Error>(next.run(req).await.into_response())
124        }
125        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
126            StatusCode::UNAUTHORIZED,
127            "unauthorized".to_string(),
128            [(
129                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
130                HeaderValue::from_static("true"),
131            )]
132            .into_iter()
133            .collect(),
134        )),
135        Err(_err) => Err(Error::http(
136            StatusCode::UNAUTHORIZED,
137            "unauthorized".to_string(),
138        )),
139    }
140}
141
/// Handles `POST /completion`: proxies a completion request to the
/// configured upstream provider and streams the response back to the
/// client as newline-delimited JSON chunks, counting tokens as they pass
/// through so usage can be recorded when the stream ends.
///
/// Authorization has two layers: the middleware has already validated the
/// LLM token (`claims`), and `authorize_access_to_language_model` checks
/// that this user/country may use the requested provider and model.
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    // Collapse version-suffixed model names to the canonical name used for
    // authorization and usage tracking.
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    let user_id = claims.user_id as i32;

    // Reject the request up front if the user is over their per-user rate
    // limits for this model.
    check_usage_limit(&state, params.provider, &model, &claims).await?;

    // Each arm produces a boxed stream of
    // `(serialized chunk bytes, input tokens, output tokens)` triples so the
    // wrapper below can count tokens uniformly across providers.
    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = state
                .config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Parse the model, throw away the version that was included, and then set a specific
            // version that we control on the server.
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version, without having to update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await?;

            chunks
                .map(move |event| {
                    let chunk = event?;
                    // Anthropic reports usage on the MessageStart and
                    // MessageDelta events; other event kinds carry none.
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            // The Zed provider is an OpenAI-compatible endpoint serving
            // Qwen2-7B, so it reuses the `open_ai` client with a custom URL.
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                &api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

    // Wrap the provider stream so usage is tallied per chunk and recorded
    // to the database when the stream is dropped.
    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        db: state.db.clone(),
        executor: state.executor.clone(),
        user_id,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}
323
324fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
325    match provider {
326        LanguageModelProvider::Anthropic => {
327            for prefix in &[
328                "claude-3-5-sonnet",
329                "claude-3-haiku",
330                "claude-3-opus",
331                "claude-3-sonnet",
332            ] {
333                if name.starts_with(prefix) {
334                    return prefix.to_string();
335                }
336            }
337        }
338        LanguageModelProvider::OpenAi => {}
339        LanguageModelProvider::Google => {}
340        LanguageModelProvider::Zed => {}
341    }
342
343    name
344}
345
346async fn check_usage_limit(
347    state: &Arc<LlmState>,
348    provider: LanguageModelProvider,
349    model_name: &str,
350    claims: &LlmTokenClaims,
351) -> Result<()> {
352    let model = state.db.model(provider, model_name)?;
353    let usage = state
354        .db
355        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
356        .await?;
357
358    let active_users = state.get_active_user_count().await?;
359
360    let per_user_max_requests_per_minute =
361        model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
362    let per_user_max_tokens_per_minute =
363        model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
364    let per_user_max_tokens_per_day =
365        model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
366
367    let checks = [
368        (
369            usage.requests_this_minute,
370            per_user_max_requests_per_minute,
371            "requests per minute",
372        ),
373        (
374            usage.tokens_this_minute,
375            per_user_max_tokens_per_minute,
376            "tokens per minute",
377        ),
378        (
379            usage.tokens_this_day,
380            per_user_max_tokens_per_day,
381            "tokens per day",
382        ),
383    ];
384
385    for (usage, limit, resource) in checks {
386        if usage > limit {
387            return Err(Error::http(
388                StatusCode::TOO_MANY_REQUESTS,
389                format!("Rate limit exceeded. Maximum {} reached.", resource),
390            ));
391        }
392    }
393
394    Ok(())
395}
396
/// Wraps a provider's completion stream, newline-delimiting each chunk for
/// the client and tallying token usage so it can be recorded in the
/// database when the stream is dropped.
struct TokenCountingStream<S> {
    db: Arc<LlmDatabase>,
    executor: Executor,
    user_id: i32,
    provider: LanguageModelProvider,
    model: String,
    // Running totals accumulated while the stream is polled.
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}
407
408impl<S> Stream for TokenCountingStream<S>
409where
410    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
411{
412    type Item = Result<Vec<u8>, anyhow::Error>;
413
414    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
415        match Pin::new(&mut self.inner_stream).poll_next(cx) {
416            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
417                bytes.push(b'\n');
418                self.input_tokens += input_tokens;
419                self.output_tokens += output_tokens;
420                Poll::Ready(Some(Ok(bytes)))
421            }
422            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
423            Poll::Ready(None) => Poll::Ready(None),
424            Poll::Pending => Poll::Pending,
425        }
426    }
427}
428
impl<S> Drop for TokenCountingStream<S> {
    // When the response stream ends (or the client disconnects mid-stream),
    // record the accumulated token usage in a detached background task.
    fn drop(&mut self) {
        let db = self.db.clone();
        let user_id = self.user_id;
        let provider = self.provider;
        // `mem::take` moves the model name out without cloning, leaving an
        // empty String behind in the dying value.
        let model = std::mem::take(&mut self.model);
        let token_count = self.input_tokens + self.output_tokens;
        self.executor.spawn_detached(async move {
            // Best-effort write: a failure here is only logged, never
            // surfaced to the client (the response is already gone).
            db.record_usage(user_id, provider, &model, token_count, Utc::now())
                .await
                .log_err();
        })
    }
}