llm.rs

mod authorization;
pub mod db;
mod token;

use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{ActiveUserCount, LlmDatabase};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

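/// Shared state for the LLM service: server configuration, an executor for
/// background work, an optional handle to the LLM database, an HTTP client
/// for the upstream model providers, and a cached active-user count.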
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Option<Arc<LlmDatabase>>,
    pub http_client: IsahcHttpClient,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

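/// How long a cached active-user count stays fresh before it is recomputed
/// from the database.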
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        // TODO: This is temporary until we have the LLM database stood up.
        let db = if config.is_development() {
            let database_url = config
                .llm_database_url
                .as_ref()
                .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
            let max_connections = config
                .llm_database_max_connections
                .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

            let mut db_options = db::ConnectOptions::new(database_url);
            db_options.max_connections(max_connections);
            let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
            db.initialize().await?;

            Some(Arc::new(db))
        } else {
            None
        };

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count = if let Some(db) = &db {
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
        } else {
            None
        };

        let this = Self {
            config,
            executor,
            db,
            http_client,
            active_user_count: RwLock::new(initial_active_user_count),
        };

        Ok(Arc::new(this))
    }

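    /// Returns the number of recently active users, serving a cached value
    /// while it is younger than `ACTIVE_USER_COUNT_CACHE_DURATION` and
    /// falling back to a default count when no database is configured.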
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

        if let Some(db) = &self.db {
            let mut cache = self.active_user_count.write().await;
            let new_count = db.get_active_user_count(now).await?;
            *cache = Some((now, new_count));
            Ok(new_count)
        } else {
            Ok(ActiveUserCount::default())
        }
    }
}

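/// Returns the router for the LLM service, with every request authenticated
/// by `validate_api_token`.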
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

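/// Middleware that validates the bearer token in the `Authorization` header
/// and stores the decoded `LlmTokenClaims` in the request extensions. An
/// expired token is rejected with a response header telling the client to
/// refresh its token.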
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(&token, &state.config) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

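/// Handles `POST /completion`: authorizes access to the requested model,
/// enforces usage limits when the LLM database is available, and proxies the
/// request to the selected provider, streaming the response back as
/// newline-delimited JSON.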
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    let user_id = claims.user_id as i32;

    if state.db.is_some() {
        check_usage_limit(&state, params.provider, &model, &claims).await?;
    }

    match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = state
                .config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic API key configured on the server")?;

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Parse the requested model, discard whatever version was included, and
            // substitute a version we control on the server. For now this is the
            // version defined in `model.id()`, but we will likely want to change this
            // once a new Anthropic model version is released, so that users can adopt
            // it without having to update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await?;

            let mut recorder = state.db.clone().map(|db| UsageRecorder {
                db,
                executor: state.executor.clone(),
                user_id,
                provider: params.provider,
                model,
                token_count: 0,
            });

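            // Forward each upstream event as one JSON line, tallying input and
            // output tokens from `MessageStart` and `MessageDelta` events so the
            // recorder can report usage once the stream is dropped.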
            let stream = chunks.map(move |event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => {
                            if let Some(recorder) = &mut recorder {
                                recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
                                recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
                            }
                        }
                        _ => {}
                    }

                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                &api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
    }
}

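/// Collapses versioned Anthropic model names down to a known family prefix
/// (e.g. `claude-3-5-sonnet`) so usage is tracked per family; names from
/// other providers are returned unchanged.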
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
    match provider {
        LanguageModelProvider::Anthropic => {
            for prefix in &[
                "claude-3-5-sonnet",
                "claude-3-haiku",
                "claude-3-opus",
                "claude-3-sonnet",
            ] {
                if name.starts_with(prefix) {
                    return prefix.to_string();
                }
            }
        }
        LanguageModelProvider::OpenAi => {}
        LanguageModelProvider::Google => {}
        LanguageModelProvider::Zed => {}
    }

    name
}

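/// Rejects the request with `429 Too Many Requests` when the user has
/// exceeded their per-user share of the model's rate limits.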
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let db = state
        .db
        .as_ref()
        .ok_or_else(|| anyhow!("LLM database not configured"))?;
    let model = db.model(provider, model_name)?;
    let usage = db
        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
        .await?;

    let active_users = state.get_active_user_count().await?;

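    // Split each model-wide limit evenly across the users active in the
    // corresponding window.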
    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
    let per_user_max_tokens_per_day =
        model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            "requests per minute",
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            "tokens per minute",
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            "tokens per day",
        ),
    ];

    for (usage, limit, resource) in checks {
        if usage > limit {
            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

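/// Accumulates the token usage for a single completion and persists it when
/// dropped.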
struct UsageRecorder {
    db: Arc<LlmDatabase>,
    executor: Executor,
    user_id: i32,
    provider: LanguageModelProvider,
    model: String,
    token_count: usize,
}

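// The recorder is owned by the response stream, so it is dropped once the
// stream completes or the client disconnects; in either case the tallied
// usage is written to the database in a detached task.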
impl Drop for UsageRecorder {
    fn drop(&mut self) {
        let db = self.db.clone();
        let user_id = self.user_id;
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let token_count = self.token_count;
        self.executor.spawn_detached(async move {
            db.record_usage(user_id, provider, &model, token_count, Utc::now())
                .await
                .log_err();
        })
    }
}