//! llm.rs — routes, middleware, and shared state for the LLM API.

  1mod authorization;
  2pub mod db;
  3mod token;
  4
  5use crate::api::CloudflareIpCountryHeader;
  6use crate::llm::authorization::authorize_access_to_language_model;
  7use crate::llm::db::LlmDatabase;
  8use crate::{executor::Executor, Config, Error, Result};
  9use anyhow::{anyhow, Context as _};
 10use axum::TypedHeader;
 11use axum::{
 12    body::Body,
 13    http::{self, HeaderName, HeaderValue, Request, StatusCode},
 14    middleware::{self, Next},
 15    response::{IntoResponse, Response},
 16    routing::post,
 17    Extension, Json, Router,
 18};
 19use futures::StreamExt as _;
 20use http_client::IsahcHttpClient;
 21use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 22use std::sync::Arc;
 23
 24pub use token::*;
 25
/// Shared state for the LLM API routes, installed as an axum extension.
pub struct LlmState {
    /// Server configuration (provider API keys, database settings, etc.).
    pub config: Config,
    /// Executor handed to the database layer for background work.
    pub executor: Executor,
    /// Handle to the LLM database. Per `LlmState::new`, this is only
    /// populated in development until the LLM database is stood up.
    pub db: Option<Arc<LlmDatabase>>,
    /// HTTP client used to reach the upstream model providers.
    pub http_client: IsahcHttpClient,
}
 32
 33impl LlmState {
 34    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
 35        // TODO: This is temporary until we have the LLM database stood up.
 36        let db = if config.is_development() {
 37            let database_url = config
 38                .llm_database_url
 39                .as_ref()
 40                .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
 41            let max_connections = config
 42                .llm_database_max_connections
 43                .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
 44
 45            let mut db_options = db::ConnectOptions::new(database_url);
 46            db_options.max_connections(max_connections);
 47            let db = LlmDatabase::new(db_options, executor.clone()).await?;
 48
 49            Some(Arc::new(db))
 50        } else {
 51            None
 52        };
 53
 54        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 55        let http_client = IsahcHttpClient::builder()
 56            .default_header("User-Agent", user_agent)
 57            .build()
 58            .context("failed to construct http client")?;
 59
 60        let this = Self {
 61            config,
 62            executor,
 63            db,
 64            http_client,
 65        };
 66
 67        Ok(Arc::new(this))
 68    }
 69}
 70
 71pub fn routes() -> Router<(), Body> {
 72    Router::new()
 73        .route("/completion", post(perform_completion))
 74        .layer(middleware::from_fn(validate_api_token))
 75}
 76
 77async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 78    let token = req
 79        .headers()
 80        .get(http::header::AUTHORIZATION)
 81        .and_then(|header| header.to_str().ok())
 82        .ok_or_else(|| {
 83            Error::http(
 84                StatusCode::BAD_REQUEST,
 85                "missing authorization header".to_string(),
 86            )
 87        })?
 88        .strip_prefix("Bearer ")
 89        .ok_or_else(|| {
 90            Error::http(
 91                StatusCode::BAD_REQUEST,
 92                "invalid authorization header".to_string(),
 93            )
 94        })?;
 95
 96    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
 97    match LlmTokenClaims::validate(&token, &state.config) {
 98        Ok(claims) => {
 99            req.extensions_mut().insert(claims);
100            Ok::<_, Error>(next.run(req).await.into_response())
101        }
102        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
103            StatusCode::UNAUTHORIZED,
104            "unauthorized".to_string(),
105            [(
106                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
107                HeaderValue::from_static("true"),
108            )]
109            .into_iter()
110            .collect(),
111        )),
112        Err(_err) => Err(Error::http(
113            StatusCode::UNAUTHORIZED,
114            "unauthorized".to_string(),
115        )),
116    }
117}
118
119async fn perform_completion(
120    Extension(state): Extension<Arc<LlmState>>,
121    Extension(claims): Extension<LlmTokenClaims>,
122    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
123    Json(params): Json<PerformCompletionParams>,
124) -> Result<impl IntoResponse> {
125    authorize_access_to_language_model(
126        &state.config,
127        &claims,
128        country_code_header.map(|header| header.to_string()),
129        params.provider,
130        &params.model,
131    )?;
132
133    match params.provider {
134        LanguageModelProvider::Anthropic => {
135            let api_key = state
136                .config
137                .anthropic_api_key
138                .as_ref()
139                .context("no Anthropic AI API key configured on the server")?;
140            let chunks = anthropic::stream_completion(
141                &state.http_client,
142                anthropic::ANTHROPIC_API_URL,
143                api_key,
144                serde_json::from_str(&params.provider_request.get())?,
145                None,
146            )
147            .await?;
148
149            let stream = chunks.map(|event| {
150                let mut buffer = Vec::new();
151                event.map(|chunk| {
152                    buffer.clear();
153                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
154                    buffer.push(b'\n');
155                    buffer
156                })
157            });
158
159            Ok(Response::new(Body::wrap_stream(stream)))
160        }
161        LanguageModelProvider::OpenAi => {
162            let api_key = state
163                .config
164                .openai_api_key
165                .as_ref()
166                .context("no OpenAI API key configured on the server")?;
167            let chunks = open_ai::stream_completion(
168                &state.http_client,
169                open_ai::OPEN_AI_API_URL,
170                api_key,
171                serde_json::from_str(&params.provider_request.get())?,
172                None,
173            )
174            .await?;
175
176            let stream = chunks.map(|event| {
177                let mut buffer = Vec::new();
178                event.map(|chunk| {
179                    buffer.clear();
180                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
181                    buffer.push(b'\n');
182                    buffer
183                })
184            });
185
186            Ok(Response::new(Body::wrap_stream(stream)))
187        }
188        LanguageModelProvider::Google => {
189            let api_key = state
190                .config
191                .google_ai_api_key
192                .as_ref()
193                .context("no Google AI API key configured on the server")?;
194            let chunks = google_ai::stream_generate_content(
195                &state.http_client,
196                google_ai::API_URL,
197                api_key,
198                serde_json::from_str(&params.provider_request.get())?,
199            )
200            .await?;
201
202            let stream = chunks.map(|event| {
203                let mut buffer = Vec::new();
204                event.map(|chunk| {
205                    buffer.clear();
206                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
207                    buffer.push(b'\n');
208                    buffer
209                })
210            });
211
212            Ok(Response::new(Body::wrap_stream(stream)))
213        }
214        LanguageModelProvider::Zed => {
215            let api_key = state
216                .config
217                .qwen2_7b_api_key
218                .as_ref()
219                .context("no Qwen2-7B API key configured on the server")?;
220            let api_url = state
221                .config
222                .qwen2_7b_api_url
223                .as_ref()
224                .context("no Qwen2-7B URL configured on the server")?;
225            let chunks = open_ai::stream_completion(
226                &state.http_client,
227                &api_url,
228                api_key,
229                serde_json::from_str(&params.provider_request.get())?,
230                None,
231            )
232            .await?;
233
234            let stream = chunks.map(|event| {
235                let mut buffer = Vec::new();
236                event.map(|chunk| {
237                    buffer.clear();
238                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
239                    buffer.push(b'\n');
240                    buffer
241                })
242            });
243
244            Ok(Response::new(Body::wrap_stream(stream)))
245        }
246    }
247}