llm.rs

mod authorization;
pub mod db;
mod token;

use crate::api::CloudflareIpCountryHeader;
use crate::llm::authorization::authorize_access_to_language_model;
use crate::llm::db::LlmDatabase;
use crate::{executor::Executor, Config, Error, Result};
use anyhow::{anyhow, Context as _};
use axum::TypedHeader;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router,
};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;

pub use token::*;

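/// Shared state for the LLM service: the server configuration, a background
/// executor, an optional handle to the LLM database (development-only for
/// now), and the HTTP client used to reach the upstream model providers.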
pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Option<Arc<LlmDatabase>>,
    pub http_client: IsahcHttpClient,
}

impl LlmState {
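    /// Constructs the shared LLM state, connecting to the LLM database (in
    /// development only, for now) and building the HTTP client used for
    /// requests to the upstream model providers.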
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        // TODO: This is temporary until we have the LLM database stood up.
        let db = if config.is_development() {
            let database_url = config
                .llm_database_url
                .as_ref()
                .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
            let max_connections = config
                .llm_database_max_connections
                .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

            let mut db_options = db::ConnectOptions::new(database_url);
            db_options.max_connections(max_connections);
            let db = LlmDatabase::new(db_options, executor.clone()).await?;

            Some(Arc::new(db))
        } else {
            None
        };

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let this = Self {
            config,
            executor,
            db,
            http_client,
        };

        Ok(Arc::new(this))
    }
}

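/// Returns the router for the LLM service.
///
/// A minimal sketch of how this might be mounted into the server, assuming
/// the app provides the [`LlmState`] extension that [`validate_api_token`]
/// relies on (illustrative wiring, not the actual setup):
///
/// ```ignore
/// let llm_state = LlmState::new(config, executor).await?;
/// let app = Router::new()
///     .nest("/llm", llm::routes())
///     .layer(Extension(llm_state));
/// ```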
pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

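/// Middleware that validates the `Authorization: Bearer <token>` header.
/// On success, the decoded [`LlmTokenClaims`] are attached as a request
/// extension for downstream handlers. An expired token is reported with
/// the `EXPIRED_LLM_TOKEN_HEADER_NAME` response header set so the client
/// knows to refresh its token rather than treat the failure as fatal.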
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    // The `LlmState` extension is installed when the routes are mounted, so
    // it is always present here.
    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(token, &state.config) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

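/// Handles `POST /completion`: authorizes the requesting user for the given
/// provider and model, forwards the request to that provider, and streams
/// the response back to the client as newline-delimited JSON.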
async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &params.model,
    )?;

    match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = state
                .config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic API key configured on the server")?;

            let mut request: anthropic::Request =
                serde_json::from_str(params.provider_request.get())?;

            // Parse the model, discard whatever version the client sent, and
            // pin a version we control on the server. For now that's the
            // version baked into `model.id()`; once Anthropic releases a new
            // model version, we'll likely want to update this mapping so
            // users can adopt it without updating Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await?;

            // Re-serialize each upstream chunk as one newline-delimited JSON
            // line for the client.
            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Zed => {
            // The Zed provider currently proxies to a Qwen2-7B deployment
            // behind an OpenAI-compatible API.
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B API URL configured on the server")?;
            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(params.provider_request.get())?,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
    }
}
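
// The four provider arms above duplicate the "re-serialize each chunk as
// newline-delimited JSON" adapter. A minimal sketch of how that could be
// factored out; this helper is hypothetical and not wired in anywhere, and
// it assumes each provider's chunk type implements `serde::Serialize`:
#[allow(dead_code)]
fn to_ndjson_stream<T, E>(
    chunks: impl futures::Stream<Item = std::result::Result<T, E>>,
) -> impl futures::Stream<Item = std::result::Result<Vec<u8>, E>>
where
    T: serde::Serialize,
{
    chunks.map(|event| {
        event.map(|chunk| {
            // Serializing a provider chunk back to JSON should not fail.
            let mut buffer = serde_json::to_vec(&chunk).unwrap();
            buffer.push(b'\n');
            buffer
        })
    })
}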