1mod authorization;
2pub mod db;
3mod token;
4
5use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
6use anyhow::{anyhow, Context as _};
7use authorization::authorize_access_to_language_model;
8use axum::{
9 body::Body,
10 http::{self, HeaderName, HeaderValue, Request, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::post,
14 Extension, Json, Router, TypedHeader,
15};
16use chrono::{DateTime, Duration, Utc};
17use db::{ActiveUserCount, LlmDatabase};
18use futures::{Stream, StreamExt as _};
19use http_client::IsahcHttpClient;
20use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
21use std::{
22 pin::Pin,
23 sync::Arc,
24 task::{Context, Poll},
25};
26use tokio::sync::RwLock;
27use util::ResultExt;
28
29pub use token::*;
30
/// Shared state for the LLM service: configuration, database handle,
/// outbound HTTP client, and a short-lived cache of the active user count
/// used to divide global rate limits among users.
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    // Cached active-user count paired with the time it was computed;
    // refreshed once it is older than `ACTIVE_USER_COUNT_CACHE_DURATION`.
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
38
/// How long a cached active-user count stays fresh before it is re-queried
/// from the database in `LlmState::get_active_user_count`.
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
40
41impl LlmState {
42 pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
43 let database_url = config
44 .llm_database_url
45 .as_ref()
46 .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
47 let max_connections = config
48 .llm_database_max_connections
49 .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
50
51 let mut db_options = db::ConnectOptions::new(database_url);
52 db_options.max_connections(max_connections);
53 let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
54 db.initialize().await?;
55
56 let db = Arc::new(db);
57
58 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
59 let http_client = IsahcHttpClient::builder()
60 .default_header("User-Agent", user_agent)
61 .build()
62 .context("failed to construct http client")?;
63
64 let initial_active_user_count =
65 Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
66
67 let this = Self {
68 config,
69 executor,
70 db,
71 http_client,
72 active_user_count: RwLock::new(initial_active_user_count),
73 };
74
75 Ok(Arc::new(this))
76 }
77
78 pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
79 let now = Utc::now();
80
81 if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
82 if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
83 return Ok(*count);
84 }
85 }
86
87 let mut cache = self.active_user_count.write().await;
88 let new_count = self.db.get_active_user_count(now).await?;
89 *cache = Some((now, new_count));
90 Ok(new_count)
91 }
92}
93
94pub fn routes() -> Router<(), Body> {
95 Router::new()
96 .route("/completion", post(perform_completion))
97 .layer(middleware::from_fn(validate_api_token))
98}
99
100async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
101 let token = req
102 .headers()
103 .get(http::header::AUTHORIZATION)
104 .and_then(|header| header.to_str().ok())
105 .ok_or_else(|| {
106 Error::http(
107 StatusCode::BAD_REQUEST,
108 "missing authorization header".to_string(),
109 )
110 })?
111 .strip_prefix("Bearer ")
112 .ok_or_else(|| {
113 Error::http(
114 StatusCode::BAD_REQUEST,
115 "invalid authorization header".to_string(),
116 )
117 })?;
118
119 let state = req.extensions().get::<Arc<LlmState>>().unwrap();
120 match LlmTokenClaims::validate(&token, &state.config) {
121 Ok(claims) => {
122 req.extensions_mut().insert(claims);
123 Ok::<_, Error>(next.run(req).await.into_response())
124 }
125 Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
126 StatusCode::UNAUTHORIZED,
127 "unauthorized".to_string(),
128 [(
129 HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
130 HeaderValue::from_static("true"),
131 )]
132 .into_iter()
133 .collect(),
134 )),
135 Err(_err) => Err(Error::http(
136 StatusCode::UNAUTHORIZED,
137 "unauthorized".to_string(),
138 )),
139 }
140}
141
142async fn perform_completion(
143 Extension(state): Extension<Arc<LlmState>>,
144 Extension(claims): Extension<LlmTokenClaims>,
145 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
146 Json(params): Json<PerformCompletionParams>,
147) -> Result<impl IntoResponse> {
148 let model = normalize_model_name(params.provider, params.model);
149
150 authorize_access_to_language_model(
151 &state.config,
152 &claims,
153 country_code_header.map(|header| header.to_string()),
154 params.provider,
155 &model,
156 )?;
157
158 let user_id = claims.user_id as i32;
159
160 check_usage_limit(&state, params.provider, &model, &claims).await?;
161
162 let stream = match params.provider {
163 LanguageModelProvider::Anthropic => {
164 let api_key = state
165 .config
166 .anthropic_api_key
167 .as_ref()
168 .context("no Anthropic AI API key configured on the server")?;
169
170 let mut request: anthropic::Request =
171 serde_json::from_str(¶ms.provider_request.get())?;
172
173 // Parse the model, throw away the version that was included, and then set a specific
174 // version that we control on the server.
175 // Right now, we use the version that's defined in `model.id()`, but we will likely
176 // want to change this code once a new version of an Anthropic model is released,
177 // so that users can use the new version, without having to update Zed.
178 request.model = match anthropic::Model::from_id(&request.model) {
179 Ok(model) => model.id().to_string(),
180 Err(_) => request.model,
181 };
182
183 let chunks = anthropic::stream_completion(
184 &state.http_client,
185 anthropic::ANTHROPIC_API_URL,
186 api_key,
187 request,
188 None,
189 )
190 .await?;
191
192 chunks
193 .map(move |event| {
194 let chunk = event?;
195 let (input_tokens, output_tokens) = match &chunk {
196 anthropic::Event::MessageStart {
197 message: anthropic::Response { usage, .. },
198 }
199 | anthropic::Event::MessageDelta { usage, .. } => (
200 usage.input_tokens.unwrap_or(0) as usize,
201 usage.output_tokens.unwrap_or(0) as usize,
202 ),
203 _ => (0, 0),
204 };
205
206 anyhow::Ok((
207 serde_json::to_vec(&chunk).unwrap(),
208 input_tokens,
209 output_tokens,
210 ))
211 })
212 .boxed()
213 }
214 LanguageModelProvider::OpenAi => {
215 let api_key = state
216 .config
217 .openai_api_key
218 .as_ref()
219 .context("no OpenAI API key configured on the server")?;
220 let chunks = open_ai::stream_completion(
221 &state.http_client,
222 open_ai::OPEN_AI_API_URL,
223 api_key,
224 serde_json::from_str(¶ms.provider_request.get())?,
225 None,
226 )
227 .await?;
228
229 chunks
230 .map(|event| {
231 event.map(|chunk| {
232 let input_tokens =
233 chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
234 let output_tokens =
235 chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
236 (
237 serde_json::to_vec(&chunk).unwrap(),
238 input_tokens,
239 output_tokens,
240 )
241 })
242 })
243 .boxed()
244 }
245 LanguageModelProvider::Google => {
246 let api_key = state
247 .config
248 .google_ai_api_key
249 .as_ref()
250 .context("no Google AI API key configured on the server")?;
251 let chunks = google_ai::stream_generate_content(
252 &state.http_client,
253 google_ai::API_URL,
254 api_key,
255 serde_json::from_str(¶ms.provider_request.get())?,
256 )
257 .await?;
258
259 chunks
260 .map(|event| {
261 event.map(|chunk| {
262 // TODO - implement token counting for Google AI
263 let input_tokens = 0;
264 let output_tokens = 0;
265 (
266 serde_json::to_vec(&chunk).unwrap(),
267 input_tokens,
268 output_tokens,
269 )
270 })
271 })
272 .boxed()
273 }
274 LanguageModelProvider::Zed => {
275 let api_key = state
276 .config
277 .qwen2_7b_api_key
278 .as_ref()
279 .context("no Qwen2-7B API key configured on the server")?;
280 let api_url = state
281 .config
282 .qwen2_7b_api_url
283 .as_ref()
284 .context("no Qwen2-7B URL configured on the server")?;
285 let chunks = open_ai::stream_completion(
286 &state.http_client,
287 &api_url,
288 api_key,
289 serde_json::from_str(¶ms.provider_request.get())?,
290 None,
291 )
292 .await?;
293
294 chunks
295 .map(|event| {
296 event.map(|chunk| {
297 let input_tokens =
298 chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
299 let output_tokens =
300 chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
301 (
302 serde_json::to_vec(&chunk).unwrap(),
303 input_tokens,
304 output_tokens,
305 )
306 })
307 })
308 .boxed()
309 }
310 };
311
312 Ok(Response::new(Body::wrap_stream(TokenCountingStream {
313 db: state.db.clone(),
314 executor: state.executor.clone(),
315 user_id,
316 provider: params.provider,
317 model,
318 input_tokens: 0,
319 output_tokens: 0,
320 inner_stream: stream,
321 })))
322}
323
324fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
325 let prefixes: &[_] = match provider {
326 LanguageModelProvider::Anthropic => &[
327 "claude-3-5-sonnet",
328 "claude-3-haiku",
329 "claude-3-opus",
330 "claude-3-sonnet",
331 ],
332 LanguageModelProvider::OpenAi => &[
333 "gpt-3.5-turbo",
334 "gpt-4-turbo-preview",
335 "gpt-4o-mini",
336 "gpt-4o",
337 "gpt-4",
338 ],
339 LanguageModelProvider::Google => &[],
340 LanguageModelProvider::Zed => &[],
341 };
342
343 if let Some(prefix) = prefixes
344 .iter()
345 .filter(|&&prefix| name.starts_with(prefix))
346 .max_by_key(|&&prefix| prefix.len())
347 {
348 prefix.to_string()
349 } else {
350 name
351 }
352}
353
354async fn check_usage_limit(
355 state: &Arc<LlmState>,
356 provider: LanguageModelProvider,
357 model_name: &str,
358 claims: &LlmTokenClaims,
359) -> Result<()> {
360 let model = state.db.model(provider, model_name)?;
361 let usage = state
362 .db
363 .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
364 .await?;
365
366 let active_users = state.get_active_user_count().await?;
367
368 let per_user_max_requests_per_minute =
369 model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
370 let per_user_max_tokens_per_minute =
371 model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
372 let per_user_max_tokens_per_day =
373 model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
374
375 let checks = [
376 (
377 usage.requests_this_minute,
378 per_user_max_requests_per_minute,
379 "requests per minute",
380 ),
381 (
382 usage.tokens_this_minute,
383 per_user_max_tokens_per_minute,
384 "tokens per minute",
385 ),
386 (
387 usage.tokens_this_day,
388 per_user_max_tokens_per_day,
389 "tokens per day",
390 ),
391 ];
392
393 for (usage, limit, resource) in checks {
394 if usage > limit {
395 return Err(Error::http(
396 StatusCode::TOO_MANY_REQUESTS,
397 format!("Rate limit exceeded. Maximum {} reached.", resource),
398 ));
399 }
400 }
401
402 Ok(())
403}
404
/// Wraps an upstream provider response stream, newline-delimiting each chunk
/// for the client and accumulating token counts so total usage can be
/// recorded to the database when the stream is dropped.
struct TokenCountingStream<S> {
    db: Arc<LlmDatabase>,
    executor: Executor,
    user_id: i32,
    provider: LanguageModelProvider,
    model: String,
    // Running totals, updated as chunks pass through `poll_next`.
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}
415
416impl<S> Stream for TokenCountingStream<S>
417where
418 S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
419{
420 type Item = Result<Vec<u8>, anyhow::Error>;
421
422 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
423 match Pin::new(&mut self.inner_stream).poll_next(cx) {
424 Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
425 bytes.push(b'\n');
426 self.input_tokens += input_tokens;
427 self.output_tokens += output_tokens;
428 Poll::Ready(Some(Ok(bytes)))
429 }
430 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
431 Poll::Ready(None) => Poll::Ready(None),
432 Poll::Pending => Poll::Pending,
433 }
434 }
435}
436
// Recording usage in `Drop` ensures the accumulated totals are persisted no
// matter how the response ends: normal completion, upstream error, or the
// client disconnecting mid-stream.
impl<S> Drop for TokenCountingStream<S> {
    fn drop(&mut self) {
        let db = self.db.clone();
        let user_id = self.user_id;
        let provider = self.provider;
        // Take the model name so the async task can own it without cloning.
        let model = std::mem::take(&mut self.model);
        let token_count = self.input_tokens + self.output_tokens;
        // `drop` can't await, so persist via a detached task; a failed write
        // is logged rather than propagated.
        self.executor.spawn_detached(async move {
            db.record_usage(user_id, provider, &model, token_count, Utc::now())
                .await
                .log_err();
        })
    }
}