mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
    Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;

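/// Shared state for the LLM service: the service configuration, a handle to
/// the LLM database, the HTTP client used to reach upstream providers, an
/// optional ClickHouse client for telemetry, and a cached active user count
/// used to derive per-user rate limits.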
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

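/// How long a cached active user count is served before it is refreshed from
/// the database.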
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
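    /// Creates the shared LLM state: connects to the LLM database (the
    /// connection settings are required in `Config`), builds the HTTP client
    /// used for upstream provider requests, and seeds the active user count
    /// cache.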
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count =
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count: RwLock::new(initial_active_user_count),
            config,
        };

        Ok(Arc::new(this))
    }

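    /// Returns the active user count, using the cached value when it is
    /// younger than `ACTIVE_USER_COUNT_CACHE_DURATION` and otherwise querying
    /// the database and refreshing the cache.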
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

        let mut cache = self.active_user_count.write().await;
        let new_count = self.db.get_active_user_count(now).await?;
        *cache = Some((now, new_count));
        Ok(new_count)
    }
}

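/// Returns the router for the LLM service. All routes are wrapped in the
/// `validate_api_token` middleware, so handlers can rely on a validated
/// `LlmTokenClaims` extension being present on the request.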
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

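/// Middleware that authenticates requests to the LLM service: it extracts the
/// bearer token from the `Authorization` header, validates it and checks it
/// for revocation, and stores the decoded `LlmTokenClaims` as a request
/// extension. Expired tokens are rejected with the
/// `EXPIRED_LLM_TOKEN_HEADER_NAME` response header set, signaling expiry to
/// the client.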
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            if state.db.is_access_token_revoked(&claims.jti).await? {
                return Err(Error::http(
                    StatusCode::UNAUTHORIZED,
                    "unauthorized".to_string(),
                ));
            }

            tracing::Span::current()
                .record("user_id", claims.user_id)
                .record("login", claims.github_user_login.clone())
                .record("authn.jti", &claims.jti)
                .record("is_staff", claims.is_staff);

            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

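/// Handles `POST /completion`: normalizes the requested model name, authorizes
/// access to the provider and model, enforces per-user usage limits, forwards
/// the request to the upstream provider, and streams the response back while
/// counting tokens via `TokenCountingStream`.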
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(
        state.db.model_names_for_provider(params.provider),
        params.model,
    );

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;

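    // Each provider arm below yields a stream of
    // `(serialized_chunk, input_tokens, output_tokens)` tuples so that
    // `TokenCountingStream` can accumulate usage uniformly across providers.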
    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Override the model on the request with the latest version of the model that is
            // known to the server.
            //
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version without having to update Zed.
            request.model = match model.as_str() {
                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
                _ => request.model,
            };

            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => Error::http(
                        StatusCode::TOO_MANY_REQUESTS,
                        "Upstream Anthropic rate limit exceeded.".to_string(),
                    ),
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            if let Some(rate_limit_info) = rate_limit_info {
                tracing::info!(
                    target: "upstream rate limit",
                    is_staff = claims.is_staff,
                    provider = params.provider.to_string(),
                    model = model,
                    tokens_remaining = rate_limit_info.tokens_remaining,
                    requests_remaining = rate_limit_info.requests_remaining,
                    requests_reset = ?rate_limit_info.requests_reset,
                    tokens_reset = ?rate_limit_info.tokens_reset,
                );
            }

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO - implement token counting for Google AI
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}

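/// Normalizes a requested model name to the longest known model name that is
/// a prefix of it, falling back to the requested name when nothing matches.
/// For example, if `known_models` contains `"claude-3-5-sonnet"`, a request
/// for `"claude-3-5-sonnet-20240620"` normalizes to `"claude-3-5-sonnet"`.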
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
    if let Some(known_model_name) = known_models
        .iter()
        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
        .max_by_key(|known_model_name| known_model_name.len())
    {
        known_model_name.to_string()
    } else {
        name
    }
}

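/// Enforces per-user usage limits for the given model. Each model-wide limit
/// (requests per minute, tokens per minute, tokens per day) is divided evenly
/// among recently active users; exceeding the resulting per-user share yields
/// a `429 Too Many Requests` error and, when ClickHouse is configured, a rate
/// limit telemetry event. Staff members currently bypass these limits.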
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(
            UserId::from_proto(claims.user_id),
            provider,
            model_name,
            Utc::now(),
        )
        .await?;

    let active_users = state.get_active_user_count().await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
                _ => "",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}

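/// A stream adapter that forwards newline-delimited chunks to the client
/// while accumulating the input and output token counts reported alongside
/// each chunk. The accumulated totals are recorded as usage when the stream
/// is dropped.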
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

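// Usage is recorded on drop so that it is captured both when the completion
// finishes normally and when the client disconnects mid-stream.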
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    UserId::from_proto(claims.user_id),
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some(usage) = usage {
                tracing::info!(
                    target: "user usage",
                    user_id = claims.user_id,
                    login = claims.github_user_login,
                    authn.jti = claims.jti,
                    is_staff = claims.is_staff,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );

                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
                    report_llm_usage(
                        clickhouse_client,
                        LlmUsageEventRow {
                            time: Utc::now().timestamp_millis(),
                            user_id: claims.user_id as i32,
                            is_staff: claims.is_staff,
                            plan: match claims.plan {
                                Plan::Free => "free".to_string(),
                                Plan::ZedPro => "zed_pro".to_string(),
                            },
                            model,
                            provider: provider.to_string(),
                            input_token_count: input_token_count as u64,
                            output_token_count: output_token_count as u64,
                            requests_this_minute: usage.requests_this_minute as u64,
                            tokens_this_minute: usage.tokens_this_minute as u64,
                            tokens_this_day: usage.tokens_this_day as u64,
                            input_tokens_this_month: usage.input_tokens_this_month as u64,
                            output_tokens_this_month: usage.output_tokens_this_month as u64,
                            spending_this_month: usage.spending_this_month as u64,
                            lifetime_spending: usage.lifetime_spending as u64,
                        },
                    )
                    .await
                    .log_err();
                }
            }
        })
    }
}

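/// Spawns a detached background task that logs application-wide, per-model
/// usage every 30 seconds.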
pub fn log_usage_periodically(state: Arc<LlmState>) {
    state.executor.clone().spawn_detached(async move {
        loop {
            state
                .executor
                .sleep(std::time::Duration::from_secs(30))
                .await;

            let Some(usages) = state
                .db
                .get_application_wide_usages_by_model(Utc::now())
                .await
                .log_err()
            else {
                continue;
            };

            for usage in usages {
                tracing::info!(
                    target: "computed usage",
                    provider = usage.provider.to_string(),
                    model = usage.model,
                    requests_this_minute = usage.requests_this_minute,
                    tokens_this_minute = usage.tokens_this_minute,
                );
            }
        }
    })
}