mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{
    api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error,
    Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{
    proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

pub use token::*;
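
/// Shared state for the LLM service: configuration, database handle, outbound
/// HTTP client, and a cached active-user count used to derive per-user rate
/// limits.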
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    pub clickhouse_client: Option<clickhouse::Client>,
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
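
/// How long a cached active-user count stays fresh before it is re-read from
/// the database.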
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
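    /// Connects to the LLM database, builds the outbound HTTP client, and
    /// seeds the active-user-count cache.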
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        let database_url = config
            .llm_database_url
            .as_ref()
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
        let max_connections = config
            .llm_database_max_connections
            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

        let mut db_options = db::ConnectOptions::new(database_url);
        db_options.max_connections(max_connections);
        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
        db.initialize().await?;

        let db = Arc::new(db);

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let initial_active_user_count =
            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

        let this = Self {
            executor,
            db,
            http_client,
            clickhouse_client: config
                .clickhouse_url
                .as_ref()
                .and_then(|_| build_clickhouse_client(&config).log_err()),
            active_user_count: RwLock::new(initial_active_user_count),
            config,
        };

        Ok(Arc::new(this))
    }
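
    /// Returns the number of recently active users, served from a cache that
    /// is refreshed from the database once it is older than
    /// `ACTIVE_USER_COUNT_CACHE_DURATION`.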
    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
        let now = Utc::now();

        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
                return Ok(*count);
            }
        }

        let mut cache = self.active_user_count.write().await;
        let new_count = self.db.get_active_user_count(now).await?;
        *cache = Some((now, new_count));
        Ok(new_count)
    }
}
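
/// Builds the router for the LLM service. Every route is guarded by
/// `validate_api_token`.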
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}
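
/// Middleware that validates the `Authorization: Bearer` token and stores the
/// decoded claims in the request extensions. An expired token is rejected
/// with a response header that tells the client to refresh it.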
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    // The `Arc<LlmState>` extension is installed where this router is
    // mounted, so its absence here is a programmer error.
    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}
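
/// Handles `POST /completion`: authorizes the requested model, enforces usage
/// limits, forwards the request to the upstream provider, and streams the
/// response back while counting tokens.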
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    check_usage_limit(&state, params.provider, &model, &claims).await?;
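
    // Each provider arm yields a stream of
    // (serialized chunk bytes, input token count, output token count).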
    let stream = match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = if claims.is_staff {
                state
                    .config
                    .anthropic_staff_api_key
                    .as_ref()
                    .context("no Anthropic AI staff API key configured on the server")?
            } else {
                state
                    .config
                    .anthropic_api_key
                    .as_ref()
                    .context("no Anthropic AI API key configured on the server")?
            };

            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Parse the model, discard the version that was included, and set
            // a specific version that we control on the server. Right now we
            // use the version defined in `model.id()`, but we will likely want
            // to change this code once a new version of an Anthropic model is
            // released, so that users can use the new version without having
            // to update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await
            .map_err(|err| match err {
                anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
                    Some(anthropic::ApiErrorCode::RateLimitError) => Error::http(
                        StatusCode::TOO_MANY_REQUESTS,
                        "Upstream Anthropic rate limit exceeded.".to_string(),
                    ),
                    Some(anthropic::ApiErrorCode::InvalidRequestError) => {
                        Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
                    }
                    Some(anthropic::ApiErrorCode::OverloadedError) => {
                        Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
                    }
                    Some(_) => {
                        Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
                    }
                    None => Error::Internal(anyhow!(err)),
                },
                anthropic::AnthropicError::Other(err) => Error::Internal(err),
            })?;

            chunks
                .map(move |event| {
                    let chunk = event?;
                    let (input_tokens, output_tokens) = match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => (
                            usage.input_tokens.unwrap_or(0) as usize,
                            usage.output_tokens.unwrap_or(0) as usize,
                        ),
                        _ => (0, 0),
                    };

                    anyhow::Ok((
                        serde_json::to_vec(&chunk).unwrap(),
                        input_tokens,
                        output_tokens,
                    ))
                })
                .boxed()
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        // TODO: implement token counting for Google AI.
                        let input_tokens = 0;
                        let output_tokens = 0;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            chunks
                .map(|event| {
                    event.map(|chunk| {
                        let input_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
                        let output_tokens =
                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
                        (
                            serde_json::to_vec(&chunk).unwrap(),
                            input_tokens,
                            output_tokens,
                        )
                    })
                })
                .boxed()
        }
    };

    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
        state,
        claims,
        provider: params.provider,
        model,
        input_tokens: 0,
        output_tokens: 0,
        inner_stream: stream,
    })))
}
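
/// Collapses a versioned model name to the longest matching known prefix for
/// its provider, so usage is metered under a stable name; unrecognized names
/// pass through unchanged.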
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
    let prefixes: &[_] = match provider {
        LanguageModelProvider::Anthropic => &[
            "claude-3-5-sonnet",
            "claude-3-haiku",
            "claude-3-opus",
            "claude-3-sonnet",
        ],
        LanguageModelProvider::OpenAi => &[
            "gpt-3.5-turbo",
            "gpt-4-turbo-preview",
            "gpt-4o-mini",
            "gpt-4o",
            "gpt-4",
        ],
        LanguageModelProvider::Google => &[],
        LanguageModelProvider::Zed => &[],
    };

    if let Some(prefix) = prefixes
        .iter()
        .filter(|&&prefix| name.starts_with(prefix))
        .max_by_key(|&&prefix| prefix.len())
    {
        prefix.to_string()
    } else {
        name
    }
}
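
/// Enforces per-user rate limits computed by dividing each model's global
/// limits by the number of recently active users. Staff accounts currently
/// bypass the limits; violations are reported to ClickHouse and rejected with
/// `429 Too Many Requests`.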
async fn check_usage_limit(
    state: &Arc<LlmState>,
    provider: LanguageModelProvider,
    model_name: &str,
    claims: &LlmTokenClaims,
) -> Result<()> {
    let model = state.db.model(provider, model_name)?;
    let usage = state
        .db
        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
        .await?;

    let active_users = state.get_active_user_count().await?;

    let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
    let users_in_recent_days = active_users.users_in_recent_days.max(1);

    // Split each model-wide limit evenly across the users who are currently
    // active.
    let per_user_max_requests_per_minute =
        model.max_requests_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_minute =
        model.max_tokens_per_minute as usize / users_in_recent_minutes;
    let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;

    let checks = [
        (
            usage.requests_this_minute,
            per_user_max_requests_per_minute,
            UsageMeasure::RequestsPerMinute,
        ),
        (
            usage.tokens_this_minute,
            per_user_max_tokens_per_minute,
            UsageMeasure::TokensPerMinute,
        ),
        (
            usage.tokens_this_day,
            per_user_max_tokens_per_day,
            UsageMeasure::TokensPerDay,
        ),
    ];

    for (used, limit, usage_measure) in checks {
        // Temporarily bypass rate-limiting for staff members.
        if claims.is_staff {
            continue;
        }

        if used > limit {
            let resource = match usage_measure {
                UsageMeasure::RequestsPerMinute => "requests_per_minute",
                UsageMeasure::TokensPerMinute => "tokens_per_minute",
                UsageMeasure::TokensPerDay => "tokens_per_day",
                _ => "",
            };

            if let Some(client) = state.clickhouse_client.as_ref() {
                report_llm_rate_limit(
                    client,
                    LlmRateLimitEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model: model.name.clone(),
                        provider: provider.to_string(),
                        usage_measure: resource.to_string(),
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        users_in_recent_minutes: users_in_recent_minutes as u64,
                        users_in_recent_days: users_in_recent_days as u64,
                        max_requests_per_minute: per_user_max_requests_per_minute as u64,
                        max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
                        max_tokens_per_day: per_user_max_tokens_per_day as u64,
                    },
                )
                .await
                .log_err();
            }

            return Err(Error::http(
                StatusCode::TOO_MANY_REQUESTS,
                format!("Rate limit exceeded. Maximum {} reached.", resource),
            ));
        }
    }

    Ok(())
}
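
/// Wraps a provider response stream: newline-delimits each chunk for the
/// client and accumulates token counts, which are recorded when the stream is
/// dropped.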
struct TokenCountingStream<S> {
    state: Arc<LlmState>,
    claims: LlmTokenClaims,
    provider: LanguageModelProvider,
    model: String,
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}

impl<S> Stream for TokenCountingStream<S>
where
    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
{
    type Item = Result<Vec<u8>, anyhow::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner_stream).poll_next(cx) {
            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
                bytes.push(b'\n');
                self.input_tokens += input_tokens;
                self.output_tokens += output_tokens;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
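
// Usage is recorded in `Drop` so token counts are persisted even when the
// client disconnects before the stream is fully consumed.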
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let state = self.state.clone();
        let claims = self.claims.clone();
        let provider = self.provider;
        let model = std::mem::take(&mut self.model);
        let input_token_count = self.input_tokens;
        let output_token_count = self.output_tokens;
        self.state.executor.spawn_detached(async move {
            let usage = state
                .db
                .record_usage(
                    claims.user_id as i32,
                    claims.is_staff,
                    provider,
                    &model,
                    input_token_count,
                    output_token_count,
                    Utc::now(),
                )
                .await
                .log_err();

            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
                report_llm_usage(
                    clickhouse_client,
                    LlmUsageEventRow {
                        time: Utc::now().timestamp_millis(),
                        user_id: claims.user_id as i32,
                        is_staff: claims.is_staff,
                        plan: match claims.plan {
                            Plan::Free => "free".to_string(),
                            Plan::ZedPro => "zed_pro".to_string(),
                        },
                        model,
                        provider: provider.to_string(),
                        input_token_count: input_token_count as u64,
                        output_token_count: output_token_count as u64,
                        requests_this_minute: usage.requests_this_minute as u64,
                        tokens_this_minute: usage.tokens_this_minute as u64,
                        tokens_this_day: usage.tokens_this_day as u64,
                        input_tokens_this_month: usage.input_tokens_this_month as u64,
                        output_tokens_this_month: usage.output_tokens_this_month as u64,
                        spending_this_month: usage.spending_this_month as u64,
                    },
                )
                .await
                .log_err();
            }
        })
    }
}