@@ -169,7 +169,10 @@ async fn perform_completion(
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
- let model = normalize_model_name(params.provider, params.model);
+ let model = normalize_model_name(
+ state.db.model_names_for_provider(params.provider),
+ params.model,
+ );
authorize_access_to_language_model(
&state.config,
@@ -200,14 +203,18 @@ async fn perform_completion(
let mut request: anthropic::Request =
serde_json::from_str(¶ms.provider_request.get())?;
- // Parse the model, throw away the version that was included, and then set a specific
- // version that we control on the server.
+ // Override the model on the request with the latest version of the model that is
+ // known to the server.
+ //
// Right now, we use the version that's defined in `model.id()`, but we will likely
// want to change this code once a new version of an Anthropic model is released,
// so that users can use the new version, without having to update Zed.
- request.model = match anthropic::Model::from_id(&request.model) {
- Ok(model) => model.id().to_string(),
- Err(_) => request.model,
+ request.model = match model.as_str() {
+ "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
+ "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
+ "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
+ "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
+ _ => request.model,
};
let chunks = anthropic::stream_completion(
@@ -369,31 +376,13 @@ async fn perform_completion(
})))
}
-fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
- let prefixes: &[_] = match provider {
- LanguageModelProvider::Anthropic => &[
- "claude-3-5-sonnet",
- "claude-3-haiku",
- "claude-3-opus",
- "claude-3-sonnet",
- ],
- LanguageModelProvider::OpenAi => &[
- "gpt-3.5-turbo",
- "gpt-4-turbo-preview",
- "gpt-4o-mini",
- "gpt-4o",
- "gpt-4",
- ],
- LanguageModelProvider::Google => &[],
- LanguageModelProvider::Zed => &[],
- };
-
- if let Some(prefix) = prefixes
+fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
+ if let Some(known_model_name) = known_models
.iter()
- .filter(|&&prefix| name.starts_with(prefix))
- .max_by_key(|&&prefix| prefix.len())
+ .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
+ .max_by_key(|known_model_name| known_model_name.len())
{
- prefix.to_string()
+ known_model_name.to_string()
} else {
name
}
@@ -26,9 +26,7 @@ fn authorize_access_to_model(
}
match (provider, model) {
- (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
- Ok(())
- }
+ (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
_ => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
@@ -67,6 +67,21 @@ impl LlmDatabase {
Ok(())
}
+ /// Returns the names of the known models for the given [`LanguageModelProvider`].
+ pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
+ self.models
+ .keys()
+ .filter_map(|(model_provider, model_name)| {
+ if model_provider == &provider {
+ Some(model_name)
+ } else {
+ None
+ }
+ })
+ .cloned()
+ .collect::<Vec<_>>()
+ }
+
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
Ok(self
.models