collab: Rework model name checks (#16365)

Marshall Bowers created

This PR reworks how we do checks for model names in the LLM service.

We now normalize the model names using the models defined in the
database.

Release Notes:

- N/A

Change summary

crates/collab/src/llm.rs               | 47 ++++++++++-----------------
crates/collab/src/llm/authorization.rs |  4 -
crates/collab/src/llm/db.rs            | 15 ++++++++
3 files changed, 34 insertions(+), 32 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -169,7 +169,10 @@ async fn perform_completion(
     country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
-    let model = normalize_model_name(params.provider, params.model);
+    let model = normalize_model_name(
+        state.db.model_names_for_provider(params.provider),
+        params.model,
+    );
 
     authorize_access_to_language_model(
         &state.config,
@@ -200,14 +203,18 @@ async fn perform_completion(
             let mut request: anthropic::Request =
                 serde_json::from_str(&params.provider_request.get())?;
 
-            // Parse the model, throw away the version that was included, and then set a specific
-            // version that we control on the server.
+            // Override the model on the request with the latest version of the model that is
+            // known to the server.
+            //
             // Right now, we use the version that's defined in `model.id()`, but we will likely
             // want to change this code once a new version of an Anthropic model is released,
             // so that users can use the new version, without having to update Zed.
-            request.model = match anthropic::Model::from_id(&request.model) {
-                Ok(model) => model.id().to_string(),
-                Err(_) => request.model,
+            request.model = match model.as_str() {
+                "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
+                "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
+                "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
+                "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
+                _ => request.model,
             };
 
             let chunks = anthropic::stream_completion(
@@ -369,31 +376,13 @@ async fn perform_completion(
     })))
 }
 
-fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
-    let prefixes: &[_] = match provider {
-        LanguageModelProvider::Anthropic => &[
-            "claude-3-5-sonnet",
-            "claude-3-haiku",
-            "claude-3-opus",
-            "claude-3-sonnet",
-        ],
-        LanguageModelProvider::OpenAi => &[
-            "gpt-3.5-turbo",
-            "gpt-4-turbo-preview",
-            "gpt-4o-mini",
-            "gpt-4o",
-            "gpt-4",
-        ],
-        LanguageModelProvider::Google => &[],
-        LanguageModelProvider::Zed => &[],
-    };
-
-    if let Some(prefix) = prefixes
+fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
+    if let Some(known_model_name) = known_models
         .iter()
-        .filter(|&&prefix| name.starts_with(prefix))
-        .max_by_key(|&&prefix| prefix.len())
+        .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
+        .max_by_key(|known_model_name| known_model_name.len())
     {
-        prefix.to_string()
+        known_model_name.to_string()
     } else {
         name
     }

crates/collab/src/llm/authorization.rs 🔗

@@ -26,9 +26,7 @@ fn authorize_access_to_model(
     }
 
     match (provider, model) {
-        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
-            Ok(())
-        }
+        (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
         _ => Err(Error::http(
             StatusCode::FORBIDDEN,
             format!("access to model {model:?} is not included in your plan"),

crates/collab/src/llm/db.rs 🔗

@@ -67,6 +67,21 @@ impl LlmDatabase {
         Ok(())
     }
 
+    /// Returns the names of the known models for the given [`LanguageModelProvider`].
+    pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
+        self.models
+            .keys()
+            .filter_map(|(model_provider, model_name)| {
+                if model_provider == &provider {
+                    Some(model_name)
+                } else {
+                    None
+                }
+            })
+            .cloned()
+            .collect::<Vec<_>>()
+    }
+
     pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
         Ok(self
             .models