Detailed changes
@@ -9,6 +9,7 @@ use crate::{
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
+use axum::routing::get;
use axum::{
body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode},
@@ -22,6 +23,7 @@ use collections::HashMap;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
+use rpc::ListModelsResponse;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
@@ -114,6 +116,7 @@ impl LlmState {
pub fn routes() -> Router<(), Body> {
Router::new()
+ .route("/models", get(list_models))
.route("/completion", post(perform_completion))
.layer(middleware::from_fn(validate_api_token))
}
@@ -173,6 +176,37 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
}
}
+async fn list_models(
+ Extension(state): Extension<Arc<LlmState>>,
+ Extension(claims): Extension<LlmTokenClaims>,
+ country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
+) -> Result<Json<ListModelsResponse>> {
+ let country_code = country_code_header.map(|header| header.to_string());
+
+ let mut accessible_models = Vec::new();
+
+ for (provider, model) in state.db.all_models() {
+ let authorize_result = authorize_access_to_language_model(
+ &state.config,
+ &claims,
+ country_code.as_deref(),
+ provider,
+ &model.name,
+ );
+
+ if authorize_result.is_ok() {
+ accessible_models.push(rpc::LanguageModel {
+ provider,
+ name: model.name,
+ });
+ }
+ }
+
+ Ok(Json(ListModelsResponse {
+ models: accessible_models,
+ }))
+}
+
async fn perform_completion(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
@@ -187,7 +221,9 @@ async fn perform_completion(
authorize_access_to_language_model(
&state.config,
&claims,
- country_code_header.map(|header| header.to_string()),
+ country_code_header
+ .map(|header| header.to_string())
+ .as_deref(),
params.provider,
&model,
)?;
@@ -7,7 +7,7 @@ use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
claims: &LlmTokenClaims,
- country_code: Option<String>,
+ country_code: Option<&str>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
@@ -49,7 +49,7 @@ fn authorize_access_to_model(
fn authorize_access_for_country(
config: &Config,
- country_code: Option<String>,
+ country_code: Option<&str>,
provider: LanguageModelProvider,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
@@ -62,7 +62,7 @@ fn authorize_access_for_country(
}
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
- let country_code = match country_code.as_deref() {
+ let country_code = match country_code {
// `XX` - Used for clients without country code data.
None | Some("XX") => Err(Error::http(
StatusCode::BAD_REQUEST,
@@ -128,7 +128,7 @@ mod tests {
authorize_access_to_language_model(
&config,
&claims,
- Some(country_code.into()),
+ Some(country_code),
provider,
"the-model",
)
@@ -178,7 +178,7 @@ mod tests {
let error_response = authorize_access_to_language_model(
&config,
&claims,
- Some(country_code.into()),
+ Some(country_code),
provider,
"the-model",
)
@@ -223,7 +223,7 @@ mod tests {
let error_response = authorize_access_to_language_model(
&config,
&claims,
- Some(country_code.into()),
+ Some(country_code),
provider,
"the-model",
)
@@ -278,13 +278,8 @@ mod tests {
..Default::default()
};
- let result = authorize_access_to_language_model(
- &config,
- &claims,
- Some("US".into()),
- provider,
- model,
- );
+ let result =
+ authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
if expected_access {
assert!(
@@ -324,13 +319,8 @@ mod tests {
];
for (provider, model) in test_cases {
- let result = authorize_access_to_language_model(
- &config,
- &claims,
- Some("US".into()),
- provider,
- model,
- );
+ let result =
+ authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
assert!(
result.is_ok(),
@@ -67,6 +67,14 @@ impl LlmDatabase {
Ok(())
}
+ /// Returns the list of all known models, with their [`LanguageModelProvider`].
+ pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
+ self.models
+ .iter()
+ .map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
+ .collect::<Vec<_>>()
+ }
+
/// Returns the names of the known models for the given [`LanguageModelProvider`].
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
self.models
@@ -15,7 +15,18 @@ pub enum LanguageModelProvider {
Zed,
}
-#[derive(Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct LanguageModel {
+ pub provider: LanguageModelProvider,
+ pub name: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ListModelsResponse {
+ pub models: Vec<LanguageModel>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
pub struct PerformCompletionParams {
pub provider: LanguageModelProvider,
pub model: String,