//! llm.rs — routes for authenticating and proxying LLM completion requests.

  1mod authorization;
  2mod token;
  3
  4use crate::api::CloudflareIpCountryHeader;
  5use crate::llm::authorization::authorize_access_to_language_model;
  6use crate::{executor::Executor, Config, Error, Result};
  7use anyhow::Context as _;
  8use axum::TypedHeader;
  9use axum::{
 10    body::Body,
 11    http::{self, HeaderName, HeaderValue, Request, StatusCode},
 12    middleware::{self, Next},
 13    response::{IntoResponse, Response},
 14    routing::post,
 15    Extension, Json, Router,
 16};
 17use futures::StreamExt as _;
 18use http_client::IsahcHttpClient;
 19use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 20use std::sync::Arc;
 21
 22pub use token::*;
 23
/// Shared state for the LLM API routes.
pub struct LlmState {
    /// Server configuration (provider API keys/URLs, token-validation settings).
    pub config: Config,
    /// Background-task executor. NOTE(review): stored but not used in this
    /// file's visible code — presumably used by other modules; confirm.
    pub executor: Executor,
    /// HTTP client used to forward requests to the upstream model providers.
    pub http_client: IsahcHttpClient,
}
 29
 30impl LlmState {
 31    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
 32        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 33        let http_client = IsahcHttpClient::builder()
 34            .default_header("User-Agent", user_agent)
 35            .build()
 36            .context("failed to construct http client")?;
 37
 38        let this = Self {
 39            config,
 40            executor,
 41            http_client,
 42        };
 43
 44        Ok(Arc::new(this))
 45    }
 46}
 47
 48pub fn routes() -> Router<(), Body> {
 49    Router::new()
 50        .route("/completion", post(perform_completion))
 51        .layer(middleware::from_fn(validate_api_token))
 52}
 53
 54async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 55    let token = req
 56        .headers()
 57        .get(http::header::AUTHORIZATION)
 58        .and_then(|header| header.to_str().ok())
 59        .ok_or_else(|| {
 60            Error::http(
 61                StatusCode::BAD_REQUEST,
 62                "missing authorization header".to_string(),
 63            )
 64        })?
 65        .strip_prefix("Bearer ")
 66        .ok_or_else(|| {
 67            Error::http(
 68                StatusCode::BAD_REQUEST,
 69                "invalid authorization header".to_string(),
 70            )
 71        })?;
 72
 73    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
 74    match LlmTokenClaims::validate(&token, &state.config) {
 75        Ok(claims) => {
 76            req.extensions_mut().insert(claims);
 77            Ok::<_, Error>(next.run(req).await.into_response())
 78        }
 79        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
 80            StatusCode::UNAUTHORIZED,
 81            "unauthorized".to_string(),
 82            [(
 83                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
 84                HeaderValue::from_static("true"),
 85            )]
 86            .into_iter()
 87            .collect(),
 88        )),
 89        Err(_err) => Err(Error::http(
 90            StatusCode::UNAUTHORIZED,
 91            "unauthorized".to_string(),
 92        )),
 93    }
 94}
 95
 96async fn perform_completion(
 97    Extension(state): Extension<Arc<LlmState>>,
 98    Extension(claims): Extension<LlmTokenClaims>,
 99    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
100    Json(params): Json<PerformCompletionParams>,
101) -> Result<impl IntoResponse> {
102    authorize_access_to_language_model(
103        &state.config,
104        &claims,
105        country_code_header.map(|header| header.to_string()),
106        params.provider,
107        &params.model,
108    )?;
109
110    match params.provider {
111        LanguageModelProvider::Anthropic => {
112            let api_key = state
113                .config
114                .anthropic_api_key
115                .as_ref()
116                .context("no Anthropic AI API key configured on the server")?;
117            let chunks = anthropic::stream_completion(
118                &state.http_client,
119                anthropic::ANTHROPIC_API_URL,
120                api_key,
121                serde_json::from_str(&params.provider_request.get())?,
122                None,
123            )
124            .await?;
125
126            let stream = chunks.map(|event| {
127                let mut buffer = Vec::new();
128                event.map(|chunk| {
129                    buffer.clear();
130                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
131                    buffer.push(b'\n');
132                    buffer
133                })
134            });
135
136            Ok(Response::new(Body::wrap_stream(stream)))
137        }
138        LanguageModelProvider::OpenAi => {
139            let api_key = state
140                .config
141                .openai_api_key
142                .as_ref()
143                .context("no OpenAI API key configured on the server")?;
144            let chunks = open_ai::stream_completion(
145                &state.http_client,
146                open_ai::OPEN_AI_API_URL,
147                api_key,
148                serde_json::from_str(&params.provider_request.get())?,
149                None,
150            )
151            .await?;
152
153            let stream = chunks.map(|event| {
154                let mut buffer = Vec::new();
155                event.map(|chunk| {
156                    buffer.clear();
157                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
158                    buffer.push(b'\n');
159                    buffer
160                })
161            });
162
163            Ok(Response::new(Body::wrap_stream(stream)))
164        }
165        LanguageModelProvider::Google => {
166            let api_key = state
167                .config
168                .google_ai_api_key
169                .as_ref()
170                .context("no Google AI API key configured on the server")?;
171            let chunks = google_ai::stream_generate_content(
172                &state.http_client,
173                google_ai::API_URL,
174                api_key,
175                serde_json::from_str(&params.provider_request.get())?,
176            )
177            .await?;
178
179            let stream = chunks.map(|event| {
180                let mut buffer = Vec::new();
181                event.map(|chunk| {
182                    buffer.clear();
183                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
184                    buffer.push(b'\n');
185                    buffer
186                })
187            });
188
189            Ok(Response::new(Body::wrap_stream(stream)))
190        }
191        LanguageModelProvider::Zed => {
192            let api_key = state
193                .config
194                .qwen2_7b_api_key
195                .as_ref()
196                .context("no Qwen2-7B API key configured on the server")?;
197            let api_url = state
198                .config
199                .qwen2_7b_api_url
200                .as_ref()
201                .context("no Qwen2-7B URL configured on the server")?;
202            let chunks = open_ai::stream_completion(
203                &state.http_client,
204                &api_url,
205                api_key,
206                serde_json::from_str(&params.provider_request.get())?,
207                None,
208            )
209            .await?;
210
211            let stream = chunks.map(|event| {
212                let mut buffer = Vec::new();
213                event.map(|chunk| {
214                    buffer.clear();
215                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
216                    buffer.push(b'\n');
217                    buffer
218                })
219            });
220
221            Ok(Response::new(Body::wrap_stream(stream)))
222        }
223    }
224}