1mod authorization;
2pub mod db;
3mod token;
4
5use crate::api::CloudflareIpCountryHeader;
6use crate::llm::authorization::authorize_access_to_language_model;
7use crate::llm::db::LlmDatabase;
8use crate::{executor::Executor, Config, Error, Result};
9use anyhow::{anyhow, Context as _};
10use axum::TypedHeader;
11use axum::{
12 body::Body,
13 http::{self, HeaderName, HeaderValue, Request, StatusCode},
14 middleware::{self, Next},
15 response::{IntoResponse, Response},
16 routing::post,
17 Extension, Json, Router,
18};
19use futures::StreamExt as _;
20use http_client::IsahcHttpClient;
21use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
22use std::sync::Arc;
23
24pub use token::*;
25
/// Shared, immutable state for the LLM service routes, injected into
/// handlers via an axum `Extension<Arc<LlmState>>`.
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    // Present only when running in development; production does not connect
    // to the LLM database yet (see the TODO in `LlmState::new`).
    pub db: Option<Arc<LlmDatabase>>,
    // Outbound HTTP client used to reach the upstream model providers.
    pub http_client: IsahcHttpClient,
}
32
33impl LlmState {
34 pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
35 // TODO: This is temporary until we have the LLM database stood up.
36 let db = if config.is_development() {
37 let database_url = config
38 .llm_database_url
39 .as_ref()
40 .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
41 let max_connections = config
42 .llm_database_max_connections
43 .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
44
45 let mut db_options = db::ConnectOptions::new(database_url);
46 db_options.max_connections(max_connections);
47 let db = LlmDatabase::new(db_options, executor.clone()).await?;
48
49 Some(Arc::new(db))
50 } else {
51 None
52 };
53
54 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
55 let http_client = IsahcHttpClient::builder()
56 .default_header("User-Agent", user_agent)
57 .build()
58 .context("failed to construct http client")?;
59
60 let this = Self {
61 config,
62 executor,
63 db,
64 http_client,
65 };
66
67 Ok(Arc::new(this))
68 }
69}
70
71pub fn routes() -> Router<(), Body> {
72 Router::new()
73 .route("/completion", post(perform_completion))
74 .layer(middleware::from_fn(validate_api_token))
75}
76
77async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
78 let token = req
79 .headers()
80 .get(http::header::AUTHORIZATION)
81 .and_then(|header| header.to_str().ok())
82 .ok_or_else(|| {
83 Error::http(
84 StatusCode::BAD_REQUEST,
85 "missing authorization header".to_string(),
86 )
87 })?
88 .strip_prefix("Bearer ")
89 .ok_or_else(|| {
90 Error::http(
91 StatusCode::BAD_REQUEST,
92 "invalid authorization header".to_string(),
93 )
94 })?;
95
96 let state = req.extensions().get::<Arc<LlmState>>().unwrap();
97 match LlmTokenClaims::validate(&token, &state.config) {
98 Ok(claims) => {
99 req.extensions_mut().insert(claims);
100 Ok::<_, Error>(next.run(req).await.into_response())
101 }
102 Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
103 StatusCode::UNAUTHORIZED,
104 "unauthorized".to_string(),
105 [(
106 HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
107 HeaderValue::from_static("true"),
108 )]
109 .into_iter()
110 .collect(),
111 )),
112 Err(_err) => Err(Error::http(
113 StatusCode::UNAUTHORIZED,
114 "unauthorized".to_string(),
115 )),
116 }
117}
118
119async fn perform_completion(
120 Extension(state): Extension<Arc<LlmState>>,
121 Extension(claims): Extension<LlmTokenClaims>,
122 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
123 Json(params): Json<PerformCompletionParams>,
124) -> Result<impl IntoResponse> {
125 authorize_access_to_language_model(
126 &state.config,
127 &claims,
128 country_code_header.map(|header| header.to_string()),
129 params.provider,
130 ¶ms.model,
131 )?;
132
133 match params.provider {
134 LanguageModelProvider::Anthropic => {
135 let api_key = state
136 .config
137 .anthropic_api_key
138 .as_ref()
139 .context("no Anthropic AI API key configured on the server")?;
140 let chunks = anthropic::stream_completion(
141 &state.http_client,
142 anthropic::ANTHROPIC_API_URL,
143 api_key,
144 serde_json::from_str(¶ms.provider_request.get())?,
145 None,
146 )
147 .await?;
148
149 let stream = chunks.map(|event| {
150 let mut buffer = Vec::new();
151 event.map(|chunk| {
152 buffer.clear();
153 serde_json::to_writer(&mut buffer, &chunk).unwrap();
154 buffer.push(b'\n');
155 buffer
156 })
157 });
158
159 Ok(Response::new(Body::wrap_stream(stream)))
160 }
161 LanguageModelProvider::OpenAi => {
162 let api_key = state
163 .config
164 .openai_api_key
165 .as_ref()
166 .context("no OpenAI API key configured on the server")?;
167 let chunks = open_ai::stream_completion(
168 &state.http_client,
169 open_ai::OPEN_AI_API_URL,
170 api_key,
171 serde_json::from_str(¶ms.provider_request.get())?,
172 None,
173 )
174 .await?;
175
176 let stream = chunks.map(|event| {
177 let mut buffer = Vec::new();
178 event.map(|chunk| {
179 buffer.clear();
180 serde_json::to_writer(&mut buffer, &chunk).unwrap();
181 buffer.push(b'\n');
182 buffer
183 })
184 });
185
186 Ok(Response::new(Body::wrap_stream(stream)))
187 }
188 LanguageModelProvider::Google => {
189 let api_key = state
190 .config
191 .google_ai_api_key
192 .as_ref()
193 .context("no Google AI API key configured on the server")?;
194 let chunks = google_ai::stream_generate_content(
195 &state.http_client,
196 google_ai::API_URL,
197 api_key,
198 serde_json::from_str(¶ms.provider_request.get())?,
199 )
200 .await?;
201
202 let stream = chunks.map(|event| {
203 let mut buffer = Vec::new();
204 event.map(|chunk| {
205 buffer.clear();
206 serde_json::to_writer(&mut buffer, &chunk).unwrap();
207 buffer.push(b'\n');
208 buffer
209 })
210 });
211
212 Ok(Response::new(Body::wrap_stream(stream)))
213 }
214 LanguageModelProvider::Zed => {
215 let api_key = state
216 .config
217 .qwen2_7b_api_key
218 .as_ref()
219 .context("no Qwen2-7B API key configured on the server")?;
220 let api_url = state
221 .config
222 .qwen2_7b_api_url
223 .as_ref()
224 .context("no Qwen2-7B URL configured on the server")?;
225 let chunks = open_ai::stream_completion(
226 &state.http_client,
227 &api_url,
228 api_key,
229 serde_json::from_str(¶ms.provider_request.get())?,
230 None,
231 )
232 .await?;
233
234 let stream = chunks.map(|event| {
235 let mut buffer = Vec::new();
236 event.map(|chunk| {
237 buffer.clear();
238 serde_json::to_writer(&mut buffer, &chunk).unwrap();
239 buffer.push(b'\n');
240 buffer
241 })
242 });
243
244 Ok(Response::new(Body::wrap_stream(stream)))
245 }
246 }
247}