mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
    Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

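/// Shared state for the LLM service: the server configuration, the LLM
/// database, an HTTP client for upstream providers, an optional ClickHouse
/// client for telemetry, and a cached active-user count.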
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

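/// How long a cached active-user count remains valid before it is refreshed
/// from the database.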
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
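    /// Connects to the LLM database, initializes it, and builds the shared
    /// state, including the HTTP client and, when a ClickHouse URL is
    /// configured, the ClickHouse client.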
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count =
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count: RwLock::new(initial_active_user_count),
            config,
        };

        Ok(Arc::new(this))
    }

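    /// Returns the active-user count, refreshing the cached value once it is
    /// older than `ACTIVE_USER_COUNT_CACHE_DURATION`.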
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

        let mut cache = self.active_user_count.write().await;
        let new_count = self.db.get_active_user_count(now).await?;
        *cache = Some((now, new_count));
        Ok(new_count)
    }
}

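/// Returns the routes for the LLM service. Every route is guarded by the
/// `validate_api_token` middleware.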
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

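/// Middleware that validates the `Authorization: Bearer <token>` header on
/// every request. Expired tokens are rejected with a response header telling
/// the client to refresh its token; revoked tokens are rejected outright.
/// Valid claims are stored as a request extension for downstream handlers.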
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

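/// Handles `POST /completion`: authorizes access to the requested model,
/// enforces per-user usage limits, forwards the request to the upstream
/// provider, and streams the response back while counting tokens.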
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

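    // Each provider arm yields a stream of (serialized chunk, input token
    // count, output token count) triples so usage can be tallied uniformly by
    // `TokenCountingStream` below.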
    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Parse the model, throw away the version that was included, and then set a
            // specific version that we control on the server.
            // Right now, we use the version that's defined in `model.id()`, but we will
            // likely want to change this code once a new version of an Anthropic model
            // is released, so that users can use the new version without having to
            // update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => Error::http(
                        StatusCode::TOO_MANY_REQUESTS,
                        "Upstream Anthropic rate limit exceeded.".to_string(),
                    ),
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
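        // The Zed-hosted provider exposes an OpenAI-compatible API, so the
        // OpenAI streaming client is reused against the configured Qwen2-7B
        // endpoint.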
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

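/// Collapses dated or versioned model names onto the known prefixes for the
/// provider, so that all releases of a model share usage records and rate
/// limits. For example, "claude-3-5-sonnet-20240620" normalizes to
/// "claude-3-5-sonnet"; names without a known prefix pass through unchanged.
/// The longest matching prefix wins, so "gpt-4o-mini" is not misclassified
/// as "gpt-4o".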
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
    let prefixes: &[_] = match provider {
        LanguageModelProvider::Anthropic => &[
            "claude-3-5-sonnet",
            "claude-3-haiku",
            "claude-3-opus",
            "claude-3-sonnet",
        ],
        LanguageModelProvider::OpenAi => &[
            "gpt-3.5-turbo",
            "gpt-4-turbo-preview",
            "gpt-4o-mini",
            "gpt-4o",
            "gpt-4",
        ],
        LanguageModelProvider::Google => &[],
        LanguageModelProvider::Zed => &[],
    };

    if let Some(prefix) = prefixes
        .iter()
        .filter(|&&prefix| name.starts_with(prefix))
        .max_by_key(|&&prefix| prefix.len())
    {
        prefix.to_string()
    } else {
        name
    }
}

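/// Enforces per-user rate limits for the given model. A model's total limits
/// are divided evenly among recently active users: for example, a model
/// allowing 600 requests per minute with 30 users active in recent minutes
/// permits 20 requests per minute per user. Staff members currently bypass
/// these checks. An exceeded limit is reported to ClickHouse (when
/// configured) and surfaces to the client as HTTP 429.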
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;

    let active_users = state.get_active_user_count().await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
                _ => "",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

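/// Wraps a provider's chunk stream, emitting each chunk as newline-delimited
/// JSON while accumulating the token counts attached to it. The accumulated
/// usage is recorded when the stream is dropped.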
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

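// Forward chunks from the inner stream, appending a newline delimiter and
// accumulating the per-chunk token counts.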
impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

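// Record and report usage when the stream is dropped. Because this runs on
// drop, usage is still captured when a client disconnects before the stream
// is fully consumed.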
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
                report_llm_usage(
                    clickhouse_client,
                    LlmUsageEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model,
                        provider: provider.to_string(),
                        input_token_count: input_token_count as u64,
                        output_token_count: output_token_count as u64,
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        input_tokens_this_month: usage.input_tokens_this_month as u64,
                        output_tokens_this_month: usage.output_tokens_this_month as u64,
                        spending_this_month: usage.spending_this_month as u64,
                        lifetime_spending: usage.lifetime_spending as u64,
                    },
                )
                .await
                .log_err();
            }
        })
    }
}