1mod authorization;
2pub mod db;
3mod token;
4
5use crate::api::CloudflareIpCountryHeader;
6use crate::llm::authorization::authorize_access_to_language_model;
7use crate::llm::db::LlmDatabase;
8use crate::{executor::Executor, Config, Error, Result};
9use anyhow::{anyhow, Context as _};
10use axum::TypedHeader;
11use axum::{
12 body::Body,
13 http::{self, HeaderName, HeaderValue, Request, StatusCode},
14 middleware::{self, Next},
15 response::{IntoResponse, Response},
16 routing::post,
17 Extension, Json, Router,
18};
19use futures::StreamExt as _;
20use http_client::IsahcHttpClient;
21use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
22use std::sync::Arc;
23
24pub use token::*;
25
/// Shared state for the LLM API endpoints, constructed once at startup
/// and passed to handlers via an axum `Extension`.
pub struct LlmState {
    /// Server configuration (provider API keys, database settings, etc.).
    pub config: Config,
    /// Executor handed to background work such as database connections.
    pub executor: Executor,
    /// Handle to the LLM database. Currently only populated in development
    /// builds; see `LlmState::new`.
    pub db: Option<Arc<LlmDatabase>>,
    /// HTTP client used to reach the upstream model providers.
    pub http_client: IsahcHttpClient,
}
32
33impl LlmState {
34 pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
35 // TODO: This is temporary until we have the LLM database stood up.
36 let db = if config.is_development() {
37 let database_url = config
38 .llm_database_url
39 .as_ref()
40 .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
41 let max_connections = config
42 .llm_database_max_connections
43 .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
44
45 let mut db_options = db::ConnectOptions::new(database_url);
46 db_options.max_connections(max_connections);
47 let db = LlmDatabase::new(db_options, executor.clone()).await?;
48
49 Some(Arc::new(db))
50 } else {
51 None
52 };
53
54 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
55 let http_client = IsahcHttpClient::builder()
56 .default_header("User-Agent", user_agent)
57 .build()
58 .context("failed to construct http client")?;
59
60 let this = Self {
61 config,
62 executor,
63 db,
64 http_client,
65 };
66
67 Ok(Arc::new(this))
68 }
69}
70
71pub fn routes() -> Router<(), Body> {
72 Router::new()
73 .route("/completion", post(perform_completion))
74 .layer(middleware::from_fn(validate_api_token))
75}
76
77async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
78 let token = req
79 .headers()
80 .get(http::header::AUTHORIZATION)
81 .and_then(|header| header.to_str().ok())
82 .ok_or_else(|| {
83 Error::http(
84 StatusCode::BAD_REQUEST,
85 "missing authorization header".to_string(),
86 )
87 })?
88 .strip_prefix("Bearer ")
89 .ok_or_else(|| {
90 Error::http(
91 StatusCode::BAD_REQUEST,
92 "invalid authorization header".to_string(),
93 )
94 })?;
95
96 let state = req.extensions().get::<Arc<LlmState>>().unwrap();
97 match LlmTokenClaims::validate(&token, &state.config) {
98 Ok(claims) => {
99 req.extensions_mut().insert(claims);
100 Ok::<_, Error>(next.run(req).await.into_response())
101 }
102 Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
103 StatusCode::UNAUTHORIZED,
104 "unauthorized".to_string(),
105 [(
106 HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
107 HeaderValue::from_static("true"),
108 )]
109 .into_iter()
110 .collect(),
111 )),
112 Err(_err) => Err(Error::http(
113 StatusCode::UNAUTHORIZED,
114 "unauthorized".to_string(),
115 )),
116 }
117}
118
119async fn perform_completion(
120 Extension(state): Extension<Arc<LlmState>>,
121 Extension(claims): Extension<LlmTokenClaims>,
122 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
123 Json(params): Json<PerformCompletionParams>,
124) -> Result<impl IntoResponse> {
125 authorize_access_to_language_model(
126 &state.config,
127 &claims,
128 country_code_header.map(|header| header.to_string()),
129 params.provider,
130 ¶ms.model,
131 )?;
132
133 match params.provider {
134 LanguageModelProvider::Anthropic => {
135 let api_key = state
136 .config
137 .anthropic_api_key
138 .as_ref()
139 .context("no Anthropic AI API key configured on the server")?;
140
141 let mut request: anthropic::Request =
142 serde_json::from_str(¶ms.provider_request.get())?;
143
144 // Parse the model, throw away the version that was included, and then set a specific
145 // version that we control on the server.
146 // Right now, we use the version that's defined in `model.id()`, but we will likely
147 // want to change this code once a new version of an Anthropic model is released,
148 // so that users can use the new version, without having to update Zed.
149 request.model = match anthropic::Model::from_id(&request.model) {
150 Ok(model) => model.id().to_string(),
151 Err(_) => request.model,
152 };
153
154 let chunks = anthropic::stream_completion(
155 &state.http_client,
156 anthropic::ANTHROPIC_API_URL,
157 api_key,
158 request,
159 None,
160 )
161 .await?;
162
163 let stream = chunks.map(|event| {
164 let mut buffer = Vec::new();
165 event.map(|chunk| {
166 buffer.clear();
167 serde_json::to_writer(&mut buffer, &chunk).unwrap();
168 buffer.push(b'\n');
169 buffer
170 })
171 });
172
173 Ok(Response::new(Body::wrap_stream(stream)))
174 }
175 LanguageModelProvider::OpenAi => {
176 let api_key = state
177 .config
178 .openai_api_key
179 .as_ref()
180 .context("no OpenAI API key configured on the server")?;
181 let chunks = open_ai::stream_completion(
182 &state.http_client,
183 open_ai::OPEN_AI_API_URL,
184 api_key,
185 serde_json::from_str(¶ms.provider_request.get())?,
186 None,
187 )
188 .await?;
189
190 let stream = chunks.map(|event| {
191 let mut buffer = Vec::new();
192 event.map(|chunk| {
193 buffer.clear();
194 serde_json::to_writer(&mut buffer, &chunk).unwrap();
195 buffer.push(b'\n');
196 buffer
197 })
198 });
199
200 Ok(Response::new(Body::wrap_stream(stream)))
201 }
202 LanguageModelProvider::Google => {
203 let api_key = state
204 .config
205 .google_ai_api_key
206 .as_ref()
207 .context("no Google AI API key configured on the server")?;
208 let chunks = google_ai::stream_generate_content(
209 &state.http_client,
210 google_ai::API_URL,
211 api_key,
212 serde_json::from_str(¶ms.provider_request.get())?,
213 )
214 .await?;
215
216 let stream = chunks.map(|event| {
217 let mut buffer = Vec::new();
218 event.map(|chunk| {
219 buffer.clear();
220 serde_json::to_writer(&mut buffer, &chunk).unwrap();
221 buffer.push(b'\n');
222 buffer
223 })
224 });
225
226 Ok(Response::new(Body::wrap_stream(stream)))
227 }
228 LanguageModelProvider::Zed => {
229 let api_key = state
230 .config
231 .qwen2_7b_api_key
232 .as_ref()
233 .context("no Qwen2-7B API key configured on the server")?;
234 let api_url = state
235 .config
236 .qwen2_7b_api_url
237 .as_ref()
238 .context("no Qwen2-7B URL configured on the server")?;
239 let chunks = open_ai::stream_completion(
240 &state.http_client,
241 &api_url,
242 api_key,
243 serde_json::from_str(¶ms.provider_request.get())?,
244 None,
245 )
246 .await?;
247
248 let stream = chunks.map(|event| {
249 let mut buffer = Vec::new();
250 event.map(|chunk| {
251 buffer.clear();
252 serde_json::to_writer(&mut buffer, &chunk).unwrap();
253 buffer.push(b'\n');
254 buffer
255 })
256 });
257
258 Ok(Response::new(Body::wrap_stream(stream)))
259 }
260 }
261}