1mod token;
2
3use crate::{executor::Executor, Config, Error, Result};
4use anyhow::Context as _;
5use axum::{
6 body::Body,
7 http::{self, HeaderName, HeaderValue, Request, StatusCode},
8 middleware::{self, Next},
9 response::{IntoResponse, Response},
10 routing::post,
11 Extension, Json, Router,
12};
13use futures::StreamExt as _;
14use http_client::IsahcHttpClient;
15use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
16use std::sync::Arc;
17
18pub use token::*;
19
/// Shared state for the LLM proxy routes, injected into handlers via an
/// axum request extension (see `validate_api_token` and `perform_completion`).
pub struct LlmState {
    // Server configuration; handlers read provider API keys/URLs from it.
    pub config: Config,
    // NOTE(review): not used anywhere in this file — presumably consumed by
    // callers elsewhere; confirm before removing.
    pub executor: Executor,
    // HTTP client used to forward completion requests to upstream providers.
    pub http_client: IsahcHttpClient,
}
25
26impl LlmState {
27 pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
28 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
29 let http_client = IsahcHttpClient::builder()
30 .default_header("User-Agent", user_agent)
31 .build()
32 .context("failed to construct http client")?;
33
34 let this = Self {
35 config,
36 executor,
37 http_client,
38 };
39
40 Ok(Arc::new(this))
41 }
42}
43
44pub fn routes() -> Router<(), Body> {
45 Router::new()
46 .route("/completion", post(perform_completion))
47 .layer(middleware::from_fn(validate_api_token))
48}
49
50async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
51 let token = req
52 .headers()
53 .get(http::header::AUTHORIZATION)
54 .and_then(|header| header.to_str().ok())
55 .ok_or_else(|| {
56 Error::http(
57 StatusCode::BAD_REQUEST,
58 "missing authorization header".to_string(),
59 )
60 })?
61 .strip_prefix("Bearer ")
62 .ok_or_else(|| {
63 Error::http(
64 StatusCode::BAD_REQUEST,
65 "invalid authorization header".to_string(),
66 )
67 })?;
68
69 let state = req.extensions().get::<Arc<LlmState>>().unwrap();
70 match LlmTokenClaims::validate(&token, &state.config) {
71 Ok(claims) => {
72 req.extensions_mut().insert(claims);
73 Ok::<_, Error>(next.run(req).await.into_response())
74 }
75 Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
76 StatusCode::UNAUTHORIZED,
77 "unauthorized".to_string(),
78 [(
79 HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
80 HeaderValue::from_static("true"),
81 )]
82 .into_iter()
83 .collect(),
84 )),
85 Err(_err) => Err(Error::http(
86 StatusCode::UNAUTHORIZED,
87 "unauthorized".to_string(),
88 )),
89 }
90}
91
92async fn perform_completion(
93 Extension(state): Extension<Arc<LlmState>>,
94 Extension(_claims): Extension<LlmTokenClaims>,
95 Json(params): Json<PerformCompletionParams>,
96) -> Result<impl IntoResponse> {
97 match params.provider {
98 LanguageModelProvider::Anthropic => {
99 let api_key = state
100 .config
101 .anthropic_api_key
102 .as_ref()
103 .context("no Anthropic AI API key configured on the server")?;
104 let chunks = anthropic::stream_completion(
105 &state.http_client,
106 anthropic::ANTHROPIC_API_URL,
107 api_key,
108 serde_json::from_str(¶ms.provider_request.get())?,
109 None,
110 )
111 .await?;
112
113 let stream = chunks.map(|event| {
114 let mut buffer = Vec::new();
115 event.map(|chunk| {
116 buffer.clear();
117 serde_json::to_writer(&mut buffer, &chunk).unwrap();
118 buffer.push(b'\n');
119 buffer
120 })
121 });
122
123 Ok(Response::new(Body::wrap_stream(stream)))
124 }
125 LanguageModelProvider::OpenAi => {
126 let api_key = state
127 .config
128 .openai_api_key
129 .as_ref()
130 .context("no OpenAI API key configured on the server")?;
131 let chunks = open_ai::stream_completion(
132 &state.http_client,
133 open_ai::OPEN_AI_API_URL,
134 api_key,
135 serde_json::from_str(¶ms.provider_request.get())?,
136 None,
137 )
138 .await?;
139
140 let stream = chunks.map(|event| {
141 let mut buffer = Vec::new();
142 event.map(|chunk| {
143 buffer.clear();
144 serde_json::to_writer(&mut buffer, &chunk).unwrap();
145 buffer.push(b'\n');
146 buffer
147 })
148 });
149
150 Ok(Response::new(Body::wrap_stream(stream)))
151 }
152 LanguageModelProvider::Google => {
153 let api_key = state
154 .config
155 .google_ai_api_key
156 .as_ref()
157 .context("no Google AI API key configured on the server")?;
158 let chunks = google_ai::stream_generate_content(
159 &state.http_client,
160 google_ai::API_URL,
161 api_key,
162 serde_json::from_str(¶ms.provider_request.get())?,
163 )
164 .await?;
165
166 let stream = chunks.map(|event| {
167 let mut buffer = Vec::new();
168 event.map(|chunk| {
169 buffer.clear();
170 serde_json::to_writer(&mut buffer, &chunk).unwrap();
171 buffer.push(b'\n');
172 buffer
173 })
174 });
175
176 Ok(Response::new(Body::wrap_stream(stream)))
177 }
178 LanguageModelProvider::Zed => {
179 let api_key = state
180 .config
181 .qwen2_7b_api_key
182 .as_ref()
183 .context("no Qwen2-7B API key configured on the server")?;
184 let api_url = state
185 .config
186 .qwen2_7b_api_url
187 .as_ref()
188 .context("no Qwen2-7B URL configured on the server")?;
189 let chunks = open_ai::stream_completion(
190 &state.http_client,
191 &api_url,
192 api_key,
193 serde_json::from_str(¶ms.provider_request.get())?,
194 None,
195 )
196 .await?;
197
198 let stream = chunks.map(|event| {
199 let mut buffer = Vec::new();
200 event.map(|chunk| {
201 buffer.clear();
202 serde_json::to_writer(&mut buffer, &chunk).unwrap();
203 buffer.push(b'\n');
204 buffer
205 })
206 });
207
208 Ok(Response::new(Body::wrap_stream(stream)))
209 }
210 }
211}