// llm.rs

  1mod token;
  2
  3use crate::{executor::Executor, Config, Error, Result};
  4use anyhow::Context as _;
  5use axum::{
  6    body::Body,
  7    http::{self, HeaderName, HeaderValue, Request, StatusCode},
  8    middleware::{self, Next},
  9    response::{IntoResponse, Response},
 10    routing::post,
 11    Extension, Json, Router,
 12};
 13use futures::StreamExt as _;
 14use http_client::IsahcHttpClient;
 15use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 16use std::sync::Arc;
 17
 18pub use token::*;
 19
/// Shared state for the LLM proxy routes: configuration (including
/// provider API keys/URLs), a task executor, and the HTTP client used
/// to reach upstream model providers.
pub struct LlmState {
    /// Server configuration; `perform_completion` reads the per-provider
    /// API keys (and the Qwen2-7B URL) from here.
    pub config: Config,
    /// Task executor; stored here but not used directly within this file.
    pub executor: Executor,
    /// HTTP client (with a `Zed Server/<version>` User-Agent) used for all
    /// upstream provider requests.
    pub http_client: IsahcHttpClient,
}
 25
 26impl LlmState {
 27    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
 28        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 29        let http_client = IsahcHttpClient::builder()
 30            .default_header("User-Agent", user_agent)
 31            .build()
 32            .context("failed to construct http client")?;
 33
 34        let this = Self {
 35            config,
 36            executor,
 37            http_client,
 38        };
 39
 40        Ok(Arc::new(this))
 41    }
 42}
 43
 44pub fn routes() -> Router<(), Body> {
 45    Router::new()
 46        .route("/completion", post(perform_completion))
 47        .layer(middleware::from_fn(validate_api_token))
 48}
 49
 50async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 51    let token = req
 52        .headers()
 53        .get(http::header::AUTHORIZATION)
 54        .and_then(|header| header.to_str().ok())
 55        .ok_or_else(|| {
 56            Error::http(
 57                StatusCode::BAD_REQUEST,
 58                "missing authorization header".to_string(),
 59            )
 60        })?
 61        .strip_prefix("Bearer ")
 62        .ok_or_else(|| {
 63            Error::http(
 64                StatusCode::BAD_REQUEST,
 65                "invalid authorization header".to_string(),
 66            )
 67        })?;
 68
 69    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
 70    match LlmTokenClaims::validate(&token, &state.config) {
 71        Ok(claims) => {
 72            req.extensions_mut().insert(claims);
 73            Ok::<_, Error>(next.run(req).await.into_response())
 74        }
 75        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
 76            StatusCode::UNAUTHORIZED,
 77            "unauthorized".to_string(),
 78            [(
 79                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
 80                HeaderValue::from_static("true"),
 81            )]
 82            .into_iter()
 83            .collect(),
 84        )),
 85        Err(_err) => Err(Error::http(
 86            StatusCode::UNAUTHORIZED,
 87            "unauthorized".to_string(),
 88        )),
 89    }
 90}
 91
 92async fn perform_completion(
 93    Extension(state): Extension<Arc<LlmState>>,
 94    Extension(_claims): Extension<LlmTokenClaims>,
 95    Json(params): Json<PerformCompletionParams>,
 96) -> Result<impl IntoResponse> {
 97    match params.provider {
 98        LanguageModelProvider::Anthropic => {
 99            let api_key = state
100                .config
101                .anthropic_api_key
102                .as_ref()
103                .context("no Anthropic AI API key configured on the server")?;
104            let chunks = anthropic::stream_completion(
105                &state.http_client,
106                anthropic::ANTHROPIC_API_URL,
107                api_key,
108                serde_json::from_str(&params.provider_request.get())?,
109                None,
110            )
111            .await?;
112
113            let stream = chunks.map(|event| {
114                let mut buffer = Vec::new();
115                event.map(|chunk| {
116                    buffer.clear();
117                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
118                    buffer.push(b'\n');
119                    buffer
120                })
121            });
122
123            Ok(Response::new(Body::wrap_stream(stream)))
124        }
125        LanguageModelProvider::OpenAi => {
126            let api_key = state
127                .config
128                .openai_api_key
129                .as_ref()
130                .context("no OpenAI API key configured on the server")?;
131            let chunks = open_ai::stream_completion(
132                &state.http_client,
133                open_ai::OPEN_AI_API_URL,
134                api_key,
135                serde_json::from_str(&params.provider_request.get())?,
136                None,
137            )
138            .await?;
139
140            let stream = chunks.map(|event| {
141                let mut buffer = Vec::new();
142                event.map(|chunk| {
143                    buffer.clear();
144                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
145                    buffer.push(b'\n');
146                    buffer
147                })
148            });
149
150            Ok(Response::new(Body::wrap_stream(stream)))
151        }
152        LanguageModelProvider::Google => {
153            let api_key = state
154                .config
155                .google_ai_api_key
156                .as_ref()
157                .context("no Google AI API key configured on the server")?;
158            let chunks = google_ai::stream_generate_content(
159                &state.http_client,
160                google_ai::API_URL,
161                api_key,
162                serde_json::from_str(&params.provider_request.get())?,
163            )
164            .await?;
165
166            let stream = chunks.map(|event| {
167                let mut buffer = Vec::new();
168                event.map(|chunk| {
169                    buffer.clear();
170                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
171                    buffer.push(b'\n');
172                    buffer
173                })
174            });
175
176            Ok(Response::new(Body::wrap_stream(stream)))
177        }
178        LanguageModelProvider::Zed => {
179            let api_key = state
180                .config
181                .qwen2_7b_api_key
182                .as_ref()
183                .context("no Qwen2-7B API key configured on the server")?;
184            let api_url = state
185                .config
186                .qwen2_7b_api_url
187                .as_ref()
188                .context("no Qwen2-7B URL configured on the server")?;
189            let chunks = open_ai::stream_completion(
190                &state.http_client,
191                &api_url,
192                api_key,
193                serde_json::from_str(&params.provider_request.get())?,
194                None,
195            )
196            .await?;
197
198            let stream = chunks.map(|event| {
199                let mut buffer = Vec::new();
200                event.map(|chunk| {
201                    buffer.clear();
202                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
203                    buffer.push(b'\n');
204                    buffer
205                })
206            });
207
208            Ok(Response::new(Body::wrap_stream(stream)))
209        }
210    }
211}