1mod authorization;
2mod token;
3
4use crate::api::CloudflareIpCountryHeader;
5use crate::llm::authorization::authorize_access_to_language_model;
6use crate::{executor::Executor, Config, Error, Result};
7use anyhow::Context as _;
8use axum::TypedHeader;
9use axum::{
10 body::Body,
11 http::{self, HeaderName, HeaderValue, Request, StatusCode},
12 middleware::{self, Next},
13 response::{IntoResponse, Response},
14 routing::post,
15 Extension, Json, Router,
16};
17use futures::StreamExt as _;
18use http_client::IsahcHttpClient;
19use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
20use std::sync::Arc;
21
22pub use token::*;
23
/// Shared state for the LLM proxy routes, installed as a request extension
/// when the router is built and read by the middleware and handlers.
pub struct LlmState {
    /// Server configuration; holds the per-provider API keys/URLs consulted
    /// by `perform_completion` and the secrets used to validate LLM tokens.
    pub config: Config,
    /// Background task executor.
    pub executor: Executor,
    /// HTTP client used for all outbound calls to upstream LLM providers.
    pub http_client: IsahcHttpClient,
}
29
30impl LlmState {
31 pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
32 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
33 let http_client = IsahcHttpClient::builder()
34 .default_header("User-Agent", user_agent)
35 .build()
36 .context("failed to construct http client")?;
37
38 let this = Self {
39 config,
40 executor,
41 http_client,
42 };
43
44 Ok(Arc::new(this))
45 }
46}
47
48pub fn routes() -> Router<(), Body> {
49 Router::new()
50 .route("/completion", post(perform_completion))
51 .layer(middleware::from_fn(validate_api_token))
52}
53
54async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
55 let token = req
56 .headers()
57 .get(http::header::AUTHORIZATION)
58 .and_then(|header| header.to_str().ok())
59 .ok_or_else(|| {
60 Error::http(
61 StatusCode::BAD_REQUEST,
62 "missing authorization header".to_string(),
63 )
64 })?
65 .strip_prefix("Bearer ")
66 .ok_or_else(|| {
67 Error::http(
68 StatusCode::BAD_REQUEST,
69 "invalid authorization header".to_string(),
70 )
71 })?;
72
73 let state = req.extensions().get::<Arc<LlmState>>().unwrap();
74 match LlmTokenClaims::validate(&token, &state.config) {
75 Ok(claims) => {
76 req.extensions_mut().insert(claims);
77 Ok::<_, Error>(next.run(req).await.into_response())
78 }
79 Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
80 StatusCode::UNAUTHORIZED,
81 "unauthorized".to_string(),
82 [(
83 HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
84 HeaderValue::from_static("true"),
85 )]
86 .into_iter()
87 .collect(),
88 )),
89 Err(_err) => Err(Error::http(
90 StatusCode::UNAUTHORIZED,
91 "unauthorized".to_string(),
92 )),
93 }
94}
95
96async fn perform_completion(
97 Extension(state): Extension<Arc<LlmState>>,
98 Extension(claims): Extension<LlmTokenClaims>,
99 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
100 Json(params): Json<PerformCompletionParams>,
101) -> Result<impl IntoResponse> {
102 authorize_access_to_language_model(
103 &state.config,
104 &claims,
105 country_code_header.map(|header| header.to_string()),
106 params.provider,
107 ¶ms.model,
108 )?;
109
110 match params.provider {
111 LanguageModelProvider::Anthropic => {
112 let api_key = state
113 .config
114 .anthropic_api_key
115 .as_ref()
116 .context("no Anthropic AI API key configured on the server")?;
117 let chunks = anthropic::stream_completion(
118 &state.http_client,
119 anthropic::ANTHROPIC_API_URL,
120 api_key,
121 serde_json::from_str(¶ms.provider_request.get())?,
122 None,
123 )
124 .await?;
125
126 let stream = chunks.map(|event| {
127 let mut buffer = Vec::new();
128 event.map(|chunk| {
129 buffer.clear();
130 serde_json::to_writer(&mut buffer, &chunk).unwrap();
131 buffer.push(b'\n');
132 buffer
133 })
134 });
135
136 Ok(Response::new(Body::wrap_stream(stream)))
137 }
138 LanguageModelProvider::OpenAi => {
139 let api_key = state
140 .config
141 .openai_api_key
142 .as_ref()
143 .context("no OpenAI API key configured on the server")?;
144 let chunks = open_ai::stream_completion(
145 &state.http_client,
146 open_ai::OPEN_AI_API_URL,
147 api_key,
148 serde_json::from_str(¶ms.provider_request.get())?,
149 None,
150 )
151 .await?;
152
153 let stream = chunks.map(|event| {
154 let mut buffer = Vec::new();
155 event.map(|chunk| {
156 buffer.clear();
157 serde_json::to_writer(&mut buffer, &chunk).unwrap();
158 buffer.push(b'\n');
159 buffer
160 })
161 });
162
163 Ok(Response::new(Body::wrap_stream(stream)))
164 }
165 LanguageModelProvider::Google => {
166 let api_key = state
167 .config
168 .google_ai_api_key
169 .as_ref()
170 .context("no Google AI API key configured on the server")?;
171 let chunks = google_ai::stream_generate_content(
172 &state.http_client,
173 google_ai::API_URL,
174 api_key,
175 serde_json::from_str(¶ms.provider_request.get())?,
176 )
177 .await?;
178
179 let stream = chunks.map(|event| {
180 let mut buffer = Vec::new();
181 event.map(|chunk| {
182 buffer.clear();
183 serde_json::to_writer(&mut buffer, &chunk).unwrap();
184 buffer.push(b'\n');
185 buffer
186 })
187 });
188
189 Ok(Response::new(Body::wrap_stream(stream)))
190 }
191 LanguageModelProvider::Zed => {
192 let api_key = state
193 .config
194 .qwen2_7b_api_key
195 .as_ref()
196 .context("no Qwen2-7B API key configured on the server")?;
197 let api_url = state
198 .config
199 .qwen2_7b_api_url
200 .as_ref()
201 .context("no Qwen2-7B URL configured on the server")?;
202 let chunks = open_ai::stream_completion(
203 &state.http_client,
204 &api_url,
205 api_key,
206 serde_json::from_str(¶ms.provider_request.get())?,
207 None,
208 )
209 .await?;
210
211 let stream = chunks.map(|event| {
212 let mut buffer = Vec::new();
213 event.map(|chunk| {
214 buffer.clear();
215 serde_json::to_writer(&mut buffer, &chunk).unwrap();
216 buffer.push(b'\n');
217 buffer
218 })
219 });
220
221 Ok(Response::new(Body::wrap_stream(stream)))
222 }
223 }
224}