1mod authorization;
2pub mod db;
3mod token;
4
5use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
6use anyhow::{anyhow, Context as _};
7use authorization::authorize_access_to_language_model;
8use axum::{
9 body::Body,
10 http::{self, HeaderName, HeaderValue, Request, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::post,
14 Extension, Json, Router, TypedHeader,
15};
16use chrono::{DateTime, Duration, Utc};
17use db::{ActiveUserCount, LlmDatabase};
18use futures::{Stream, StreamExt as _};
19use http_client::IsahcHttpClient;
20use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
21use std::{
22 pin::Pin,
23 sync::Arc,
24 task::{Context, Poll},
25};
26use tokio::sync::RwLock;
27use util::ResultExt;
28
29pub use token::*;
30
/// Shared state for the LLM service: configuration, background executor,
/// LLM database handle, outbound HTTP client, and a small cache of the
/// active-user count used when computing per-user rate limits.
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Arc<LlmDatabase>,
    pub http_client: IsahcHttpClient,
    // Cached `(fetched_at, count)` pair; considered fresh while younger than
    // `ACTIVE_USER_COUNT_CACHE_DURATION`, refreshed on demand otherwise.
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
38
/// How long a cached active-user count stays valid before it is re-fetched
/// from the database.
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
40
41impl LlmState {
42 pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
43 let database_url = config
44 .llm_database_url
45 .as_ref()
46 .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
47 let max_connections = config
48 .llm_database_max_connections
49 .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
50
51 let mut db_options = db::ConnectOptions::new(database_url);
52 db_options.max_connections(max_connections);
53 let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
54 db.initialize().await?;
55
56 let db = Arc::new(db);
57
58 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
59 let http_client = IsahcHttpClient::builder()
60 .default_header("User-Agent", user_agent)
61 .build()
62 .context("failed to construct http client")?;
63
64 let initial_active_user_count =
65 Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
66
67 let this = Self {
68 config,
69 executor,
70 db,
71 http_client,
72 active_user_count: RwLock::new(initial_active_user_count),
73 };
74
75 Ok(Arc::new(this))
76 }
77
78 pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
79 let now = Utc::now();
80
81 if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
82 if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
83 return Ok(*count);
84 }
85 }
86
87 let mut cache = self.active_user_count.write().await;
88 let new_count = self.db.get_active_user_count(now).await?;
89 *cache = Some((now, new_count));
90 Ok(new_count)
91 }
92}
93
94pub fn routes() -> Router<(), Body> {
95 Router::new()
96 .route("/completion", post(perform_completion))
97 .layer(middleware::from_fn(validate_api_token))
98}
99
100async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
101 let token = req
102 .headers()
103 .get(http::header::AUTHORIZATION)
104 .and_then(|header| header.to_str().ok())
105 .ok_or_else(|| {
106 Error::http(
107 StatusCode::BAD_REQUEST,
108 "missing authorization header".to_string(),
109 )
110 })?
111 .strip_prefix("Bearer ")
112 .ok_or_else(|| {
113 Error::http(
114 StatusCode::BAD_REQUEST,
115 "invalid authorization header".to_string(),
116 )
117 })?;
118
119 let state = req.extensions().get::<Arc<LlmState>>().unwrap();
120 match LlmTokenClaims::validate(&token, &state.config) {
121 Ok(claims) => {
122 req.extensions_mut().insert(claims);
123 Ok::<_, Error>(next.run(req).await.into_response())
124 }
125 Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
126 StatusCode::UNAUTHORIZED,
127 "unauthorized".to_string(),
128 [(
129 HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
130 HeaderValue::from_static("true"),
131 )]
132 .into_iter()
133 .collect(),
134 )),
135 Err(_err) => Err(Error::http(
136 StatusCode::UNAUTHORIZED,
137 "unauthorized".to_string(),
138 )),
139 }
140}
141
142async fn perform_completion(
143 Extension(state): Extension<Arc<LlmState>>,
144 Extension(claims): Extension<LlmTokenClaims>,
145 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
146 Json(params): Json<PerformCompletionParams>,
147) -> Result<impl IntoResponse> {
148 let model = normalize_model_name(params.provider, params.model);
149
150 authorize_access_to_language_model(
151 &state.config,
152 &claims,
153 country_code_header.map(|header| header.to_string()),
154 params.provider,
155 &model,
156 )?;
157
158 let user_id = claims.user_id as i32;
159
160 check_usage_limit(&state, params.provider, &model, &claims).await?;
161
162 let stream = match params.provider {
163 LanguageModelProvider::Anthropic => {
164 let api_key = state
165 .config
166 .anthropic_api_key
167 .as_ref()
168 .context("no Anthropic AI API key configured on the server")?;
169
170 let mut request: anthropic::Request =
171 serde_json::from_str(¶ms.provider_request.get())?;
172
173 // Parse the model, throw away the version that was included, and then set a specific
174 // version that we control on the server.
175 // Right now, we use the version that's defined in `model.id()`, but we will likely
176 // want to change this code once a new version of an Anthropic model is released,
177 // so that users can use the new version, without having to update Zed.
178 request.model = match anthropic::Model::from_id(&request.model) {
179 Ok(model) => model.id().to_string(),
180 Err(_) => request.model,
181 };
182
183 let chunks = anthropic::stream_completion(
184 &state.http_client,
185 anthropic::ANTHROPIC_API_URL,
186 api_key,
187 request,
188 None,
189 )
190 .await?;
191
192 chunks
193 .map(move |event| {
194 let chunk = event?;
195 let (input_tokens, output_tokens) = match &chunk {
196 anthropic::Event::MessageStart {
197 message: anthropic::Response { usage, .. },
198 }
199 | anthropic::Event::MessageDelta { usage, .. } => (
200 usage.input_tokens.unwrap_or(0) as usize,
201 usage.output_tokens.unwrap_or(0) as usize,
202 ),
203 _ => (0, 0),
204 };
205
206 anyhow::Ok((
207 serde_json::to_vec(&chunk).unwrap(),
208 input_tokens,
209 output_tokens,
210 ))
211 })
212 .boxed()
213 }
214 LanguageModelProvider::OpenAi => {
215 let api_key = state
216 .config
217 .openai_api_key
218 .as_ref()
219 .context("no OpenAI API key configured on the server")?;
220 let chunks = open_ai::stream_completion(
221 &state.http_client,
222 open_ai::OPEN_AI_API_URL,
223 api_key,
224 serde_json::from_str(¶ms.provider_request.get())?,
225 None,
226 )
227 .await?;
228
229 chunks
230 .map(|event| {
231 event.map(|chunk| {
232 let input_tokens =
233 chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
234 let output_tokens =
235 chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
236 (
237 serde_json::to_vec(&chunk).unwrap(),
238 input_tokens,
239 output_tokens,
240 )
241 })
242 })
243 .boxed()
244 }
245 LanguageModelProvider::Google => {
246 let api_key = state
247 .config
248 .google_ai_api_key
249 .as_ref()
250 .context("no Google AI API key configured on the server")?;
251 let chunks = google_ai::stream_generate_content(
252 &state.http_client,
253 google_ai::API_URL,
254 api_key,
255 serde_json::from_str(¶ms.provider_request.get())?,
256 )
257 .await?;
258
259 chunks
260 .map(|event| {
261 event.map(|chunk| {
262 // TODO - implement token counting for Google AI
263 let input_tokens = 0;
264 let output_tokens = 0;
265 (
266 serde_json::to_vec(&chunk).unwrap(),
267 input_tokens,
268 output_tokens,
269 )
270 })
271 })
272 .boxed()
273 }
274 LanguageModelProvider::Zed => {
275 let api_key = state
276 .config
277 .qwen2_7b_api_key
278 .as_ref()
279 .context("no Qwen2-7B API key configured on the server")?;
280 let api_url = state
281 .config
282 .qwen2_7b_api_url
283 .as_ref()
284 .context("no Qwen2-7B URL configured on the server")?;
285 let chunks = open_ai::stream_completion(
286 &state.http_client,
287 &api_url,
288 api_key,
289 serde_json::from_str(¶ms.provider_request.get())?,
290 None,
291 )
292 .await?;
293
294 chunks
295 .map(|event| {
296 event.map(|chunk| {
297 let input_tokens =
298 chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
299 let output_tokens =
300 chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
301 (
302 serde_json::to_vec(&chunk).unwrap(),
303 input_tokens,
304 output_tokens,
305 )
306 })
307 })
308 .boxed()
309 }
310 };
311
312 Ok(Response::new(Body::wrap_stream(TokenCountingStream {
313 db: state.db.clone(),
314 executor: state.executor.clone(),
315 user_id,
316 provider: params.provider,
317 model,
318 input_tokens: 0,
319 output_tokens: 0,
320 inner_stream: stream,
321 })))
322}
323
324fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
325 match provider {
326 LanguageModelProvider::Anthropic => {
327 for prefix in &[
328 "claude-3-5-sonnet",
329 "claude-3-haiku",
330 "claude-3-opus",
331 "claude-3-sonnet",
332 ] {
333 if name.starts_with(prefix) {
334 return prefix.to_string();
335 }
336 }
337 }
338 LanguageModelProvider::OpenAi => {}
339 LanguageModelProvider::Google => {}
340 LanguageModelProvider::Zed => {}
341 }
342
343 name
344}
345
346async fn check_usage_limit(
347 state: &Arc<LlmState>,
348 provider: LanguageModelProvider,
349 model_name: &str,
350 claims: &LlmTokenClaims,
351) -> Result<()> {
352 let model = state.db.model(provider, model_name)?;
353 let usage = state
354 .db
355 .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
356 .await?;
357
358 let active_users = state.get_active_user_count().await?;
359
360 let per_user_max_requests_per_minute =
361 model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
362 let per_user_max_tokens_per_minute =
363 model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
364 let per_user_max_tokens_per_day =
365 model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
366
367 let checks = [
368 (
369 usage.requests_this_minute,
370 per_user_max_requests_per_minute,
371 "requests per minute",
372 ),
373 (
374 usage.tokens_this_minute,
375 per_user_max_tokens_per_minute,
376 "tokens per minute",
377 ),
378 (
379 usage.tokens_this_day,
380 per_user_max_tokens_per_day,
381 "tokens per day",
382 ),
383 ];
384
385 for (usage, limit, resource) in checks {
386 if usage > limit {
387 return Err(Error::http(
388 StatusCode::TOO_MANY_REQUESTS,
389 format!("Rate limit exceeded. Maximum {} reached.", resource),
390 ));
391 }
392 }
393
394 Ok(())
395}
396
/// Wraps a provider's completion stream, newline-delimiting each chunk as
/// it passes through and tallying token counts so that usage can be
/// recorded when the stream is dropped.
struct TokenCountingStream<S> {
    db: Arc<LlmDatabase>,
    executor: Executor,
    user_id: i32,
    provider: LanguageModelProvider,
    model: String,
    // Running totals, accumulated as chunks pass through `poll_next`.
    input_tokens: usize,
    output_tokens: usize,
    inner_stream: S,
}
407
408impl<S> Stream for TokenCountingStream<S>
409where
410 S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
411{
412 type Item = Result<Vec<u8>, anyhow::Error>;
413
414 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
415 match Pin::new(&mut self.inner_stream).poll_next(cx) {
416 Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
417 bytes.push(b'\n');
418 self.input_tokens += input_tokens;
419 self.output_tokens += output_tokens;
420 Poll::Ready(Some(Ok(bytes)))
421 }
422 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
423 Poll::Ready(None) => Poll::Ready(None),
424 Poll::Pending => Poll::Pending,
425 }
426 }
427}
428
impl<S> Drop for TokenCountingStream<S> {
    // Record the accumulated token usage when the stream is dropped —
    // whether the response completed normally or the client disconnected
    // mid-stream. `drop` cannot await, so the database write is spawned on
    // the executor as a detached task; failures are logged and otherwise
    // ignored.
    fn drop(&mut self) {
        let db = self.db.clone();
        let user_id = self.user_id;
        let provider = self.provider;
        // Take the model name rather than cloning it; the stream is being
        // torn down, so the field is no longer needed.
        let model = std::mem::take(&mut self.model);
        let token_count = self.input_tokens + self.output_tokens;
        self.executor.spawn_detached(async move {
            db.record_usage(user_id, provider, &model, token_count, Utc::now())
                .await
                .log_err();
        })
    }
}