1mod authorization;
2pub mod db;
3mod token;
4
5use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
6use anyhow::{anyhow, Context as _};
7use authorization::authorize_access_to_language_model;
8use axum::{
9 body::Body,
10 http::{self, HeaderName, HeaderValue, Request, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::post,
14 Extension, Json, Router, TypedHeader,
15};
16use chrono::{DateTime, Duration, Utc};
17use db::{ActiveUserCount, LlmDatabase};
18use futures::StreamExt as _;
19use http_client::IsahcHttpClient;
20use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use util::ResultExt;
24
25pub use token::*;
26
/// Shared state for the LLM service's HTTP handlers, passed to each request
/// as an axum `Extension<Arc<LlmState>>`.
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    // Only populated when `config.is_development()`; `None` otherwise until
    // the LLM database is stood up (see the TODO in `LlmState::new`).
    pub db: Option<Arc<LlmDatabase>>,
    pub http_client: IsahcHttpClient,
    // Cached active-user count, tagged with the instant it was computed so
    // readers can decide whether it is still fresh.
    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

/// How long a cached active-user count stays valid before
/// `get_active_user_count` refreshes it from the database.
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
36
37impl LlmState {
38 pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
39 // TODO: This is temporary until we have the LLM database stood up.
40 let db = if config.is_development() {
41 let database_url = config
42 .llm_database_url
43 .as_ref()
44 .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
45 let max_connections = config
46 .llm_database_max_connections
47 .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
48
49 let mut db_options = db::ConnectOptions::new(database_url);
50 db_options.max_connections(max_connections);
51 let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
52 db.initialize().await?;
53
54 Some(Arc::new(db))
55 } else {
56 None
57 };
58
59 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
60 let http_client = IsahcHttpClient::builder()
61 .default_header("User-Agent", user_agent)
62 .build()
63 .context("failed to construct http client")?;
64
65 let initial_active_user_count = if let Some(db) = &db {
66 Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
67 } else {
68 None
69 };
70
71 let this = Self {
72 config,
73 executor,
74 db,
75 http_client,
76 active_user_count: RwLock::new(initial_active_user_count),
77 };
78
79 Ok(Arc::new(this))
80 }
81
82 pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
83 let now = Utc::now();
84
85 if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
86 if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
87 return Ok(*count);
88 }
89 }
90
91 if let Some(db) = &self.db {
92 let mut cache = self.active_user_count.write().await;
93 let new_count = db.get_active_user_count(now).await?;
94 *cache = Some((now, new_count));
95 Ok(new_count)
96 } else {
97 Ok(ActiveUserCount::default())
98 }
99 }
100}
101
102pub fn routes() -> Router<(), Body> {
103 Router::new()
104 .route("/completion", post(perform_completion))
105 .layer(middleware::from_fn(validate_api_token))
106}
107
108async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
109 let token = req
110 .headers()
111 .get(http::header::AUTHORIZATION)
112 .and_then(|header| header.to_str().ok())
113 .ok_or_else(|| {
114 Error::http(
115 StatusCode::BAD_REQUEST,
116 "missing authorization header".to_string(),
117 )
118 })?
119 .strip_prefix("Bearer ")
120 .ok_or_else(|| {
121 Error::http(
122 StatusCode::BAD_REQUEST,
123 "invalid authorization header".to_string(),
124 )
125 })?;
126
127 let state = req.extensions().get::<Arc<LlmState>>().unwrap();
128 match LlmTokenClaims::validate(&token, &state.config) {
129 Ok(claims) => {
130 req.extensions_mut().insert(claims);
131 Ok::<_, Error>(next.run(req).await.into_response())
132 }
133 Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
134 StatusCode::UNAUTHORIZED,
135 "unauthorized".to_string(),
136 [(
137 HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
138 HeaderValue::from_static("true"),
139 )]
140 .into_iter()
141 .collect(),
142 )),
143 Err(_err) => Err(Error::http(
144 StatusCode::UNAUTHORIZED,
145 "unauthorized".to_string(),
146 )),
147 }
148}
149
/// Handler for `POST /completion`: proxies a streaming completion request to
/// the configured upstream provider and relays the response as a stream of
/// newline-delimited JSON chunks.
///
/// Authorization happens in two steps: `authorize_access_to_language_model`
/// checks plan/country/model access, then (when the LLM database is
/// available) `check_usage_limit` enforces per-user rate limits. For the
/// Anthropic provider only, token usage is accumulated and recorded via
/// `UsageRecorder` when the stream is dropped.
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    // Collapse versioned model names to a canonical family name so
    // authorization and usage tracking are version-agnostic.
    let model = normalize_model_name(params.provider, params.model);

    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &model,
    )?;

    let user_id = claims.user_id as i32;

    // Usage limits can only be enforced when the LLM database is configured
    // (currently development-only; see `LlmState::new`).
    if state.db.is_some() {
        check_usage_limit(&state, params.provider, &model, &claims).await?;
    }

    match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = state
                .config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;

            // The client sends the provider-specific request body as opaque
            // JSON; deserialize it into Anthropic's request type.
            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Parse the model, throw away the version that was included, and then set a specific
            // version that we control on the server.
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version, without having to update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await?;

            // The recorder is moved into the stream closure below and flushes
            // its accumulated token count to the database on drop, i.e. when
            // the response stream is finished or abandoned. `None` when the
            // database isn't configured.
            let mut recorder = state.db.clone().map(|db| UsageRecorder {
                db,
                executor: state.executor.clone(),
                user_id,
                provider: params.provider,
                model,
                token_count: 0,
            });

            let stream = chunks.map(move |event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    // Anthropic reports input tokens on MessageStart and
                    // output tokens on MessageDelta; tally both.
                    match &chunk {
                        anthropic::Event::MessageStart {
                            message: anthropic::Response { usage, .. },
                        }
                        | anthropic::Event::MessageDelta { usage, .. } => {
                            if let Some(recorder) = &mut recorder {
                                recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
                                recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
                            }
                        }
                        _ => {}
                    }

                    // Re-serialize each event as one JSON object per line.
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            // NOTE(review): unlike the Anthropic arm, no usage is recorded
            // for this provider — confirm whether that is intentional.
            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Zed => {
            // The Zed provider is backed by a Qwen2-7B deployment that speaks
            // the OpenAI-compatible API, hence the reuse of `open_ai` here.
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                &api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
    }
}
323
324fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
325 match provider {
326 LanguageModelProvider::Anthropic => {
327 for prefix in &[
328 "claude-3-5-sonnet",
329 "claude-3-haiku",
330 "claude-3-opus",
331 "claude-3-sonnet",
332 ] {
333 if name.starts_with(prefix) {
334 return prefix.to_string();
335 }
336 }
337 }
338 LanguageModelProvider::OpenAi => {}
339 LanguageModelProvider::Google => {}
340 LanguageModelProvider::Zed => {}
341 }
342
343 name
344}
345
346async fn check_usage_limit(
347 state: &Arc<LlmState>,
348 provider: LanguageModelProvider,
349 model_name: &str,
350 claims: &LlmTokenClaims,
351) -> Result<()> {
352 let db = state
353 .db
354 .as_ref()
355 .ok_or_else(|| anyhow!("LLM database not configured"))?;
356 let model = db.model(provider, model_name)?;
357 let usage = db
358 .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
359 .await?;
360
361 let active_users = state.get_active_user_count().await?;
362
363 let per_user_max_requests_per_minute =
364 model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
365 let per_user_max_tokens_per_minute =
366 model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
367 let per_user_max_tokens_per_day =
368 model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
369
370 let checks = [
371 (
372 usage.requests_this_minute,
373 per_user_max_requests_per_minute,
374 "requests per minute",
375 ),
376 (
377 usage.tokens_this_minute,
378 per_user_max_tokens_per_minute,
379 "tokens per minute",
380 ),
381 (
382 usage.tokens_this_day,
383 per_user_max_tokens_per_day,
384 "tokens per day",
385 ),
386 ];
387
388 for (usage, limit, resource) in checks {
389 if usage > limit {
390 return Err(Error::http(
391 StatusCode::TOO_MANY_REQUESTS,
392 format!("Rate limit exceeded. Maximum {} reached.", resource),
393 ));
394 }
395 }
396
397 Ok(())
398}
/// Accumulates token usage for one completion request and writes it to the
/// database when dropped (see the `Drop` impl), i.e. when the response
/// stream that owns it is finished or abandoned.
struct UsageRecorder {
    db: Arc<LlmDatabase>,
    executor: Executor,
    user_id: i32,
    provider: LanguageModelProvider,
    model: String,
    // Running total of input + output tokens observed in the stream.
    token_count: usize,
}
407
408impl Drop for UsageRecorder {
409 fn drop(&mut self) {
410 let db = self.db.clone();
411 let user_id = self.user_id;
412 let provider = self.provider;
413 let model = std::mem::take(&mut self.model);
414 let token_count = self.token_count;
415 self.executor.spawn_detached(async move {
416 db.record_usage(user_id, provider, &model, token_count, Utc::now())
417 .await
418 .log_err();
419 })
420 }
421}