From 30056254f31ce2095059b9ee122bcd140ed8e4f2 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 3 Sep 2024 11:41:32 -0400 Subject: [PATCH] collab: Add `GET /models` endpoint to LLM service (#17307) This PR adds a `GET /models` endpoint to the LLM service. This endpoint returns the models that the authenticated user has access to. This is the first step towards populating the models for the hosted service from the server. Release Notes: - N/A --- crates/collab/src/llm.rs | 38 +++++++++++++++++++++++++- crates/collab/src/llm/authorization.rs | 30 +++++++------------- crates/collab/src/llm/db.rs | 8 ++++++ crates/rpc/src/llm.rs | 13 ++++++++- 4 files changed, 67 insertions(+), 22 deletions(-) diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index cf98604e20a5a6b2b06880159b3b06d1d230c4fe..320d7418eeb7335ce3982a1634f27351a55dc220 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -9,6 +9,7 @@ use crate::{ }; use anyhow::{anyhow, Context as _}; use authorization::authorize_access_to_language_model; +use axum::routing::get; use axum::{ body::Body, http::{self, HeaderName, HeaderValue, Request, StatusCode}, @@ -22,6 +23,7 @@ use collections::HashMap; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; use futures::{Stream, StreamExt as _}; use http_client::IsahcHttpClient; +use rpc::ListModelsResponse; use rpc::{ proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, }; @@ -114,6 +116,7 @@ impl LlmState { pub fn routes() -> Router<(), Body> { Router::new() + .route("/models", get(list_models)) .route("/completion", post(perform_completion)) .layer(middleware::from_fn(validate_api_token)) } @@ -173,6 +176,37 @@ async fn validate_api_token(mut req: Request, next: Next) -> impl IntoR } } +async fn list_models( + Extension(state): Extension>, + Extension(claims): Extension, + country_code_header: Option>, +) -> Result> { + let country_code = country_code_header.map(|header| 
header.to_string()); + + let mut accessible_models = Vec::new(); + + for (provider, model) in state.db.all_models() { + let authorize_result = authorize_access_to_language_model( + &state.config, + &claims, + country_code.as_deref(), + provider, + &model.name, + ); + + if authorize_result.is_ok() { + accessible_models.push(rpc::LanguageModel { + provider, + name: model.name, + }); + } + } + + Ok(Json(ListModelsResponse { + models: accessible_models, + })) +} + async fn perform_completion( Extension(state): Extension>, Extension(claims): Extension, @@ -187,7 +221,9 @@ async fn perform_completion( authorize_access_to_language_model( &state.config, &claims, - country_code_header.map(|header| header.to_string()), + country_code_header + .map(|header| header.to_string()) + .as_deref(), params.provider, &model, )?; diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs index 98ee1b7c6a73c20ce1af68a7f4d050c46f0d6b79..f6acff268542d14d923b01dccf75bf91ab7d93e4 100644 --- a/crates/collab/src/llm/authorization.rs +++ b/crates/collab/src/llm/authorization.rs @@ -7,7 +7,7 @@ use crate::{Config, Error, Result}; pub fn authorize_access_to_language_model( config: &Config, claims: &LlmTokenClaims, - country_code: Option, + country_code: Option<&str>, provider: LanguageModelProvider, model: &str, ) -> Result<()> { @@ -49,7 +49,7 @@ fn authorize_access_to_model( fn authorize_access_for_country( config: &Config, - country_code: Option, + country_code: Option<&str>, provider: LanguageModelProvider, ) -> Result<()> { // In development we won't have the `CF-IPCountry` header, so we can't check @@ -62,7 +62,7 @@ fn authorize_access_for_country( } // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry - let country_code = match country_code.as_deref() { + let country_code = match country_code { // `XX` - Used for clients without country code data. 
None | Some("XX") => Err(Error::http( StatusCode::BAD_REQUEST, @@ -128,7 +128,7 @@ mod tests { authorize_access_to_language_model( &config, &claims, - Some(country_code.into()), + Some(country_code), provider, "the-model", ) @@ -178,7 +178,7 @@ mod tests { let error_response = authorize_access_to_language_model( &config, &claims, - Some(country_code.into()), + Some(country_code), provider, "the-model", ) @@ -223,7 +223,7 @@ mod tests { let error_response = authorize_access_to_language_model( &config, &claims, - Some(country_code.into()), + Some(country_code), provider, "the-model", ) @@ -278,13 +278,8 @@ mod tests { ..Default::default() }; - let result = authorize_access_to_language_model( - &config, - &claims, - Some("US".into()), - provider, - model, - ); + let result = + authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); if expected_access { assert!( @@ -324,13 +319,8 @@ mod tests { ]; for (provider, model) in test_cases { - let result = authorize_access_to_language_model( - &config, - &claims, - Some("US".into()), - provider, - model, - ); + let result = + authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); assert!( result.is_ok(), diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index f76a722471e760f976506b715fb836e85f4fd98f..cd370b14b146cb8a6ad123f531c9947a7e5d5b84 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -67,6 +67,14 @@ impl LlmDatabase { Ok(()) } + /// Returns the list of all known models, with their [`LanguageModelProvider`]. + pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> { + self.models + .iter() + .map(|((model_provider, _model_name), model)| (*model_provider, model.clone())) + .collect::>() + } + /// Returns the names of the known models for the given [`LanguageModelProvider`]. 
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> { self.models diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index 7f97b02df7855222053aa4bd80596f00e5b6ed1d..6cae54b3090d564d333856b5e596ebb2c510c9d4 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -15,7 +15,18 @@ pub enum LanguageModelProvider { Zed, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] +pub struct LanguageModel { + pub provider: LanguageModelProvider, + pub name: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListModelsResponse { + pub models: Vec<LanguageModel>, +} + +#[derive(Debug, Serialize, Deserialize)] pub struct PerformCompletionParams { pub provider: LanguageModelProvider, pub model: String,