Add the ability to customize available models for OpenAI-compatible services (#13276)

amtoaer created

Closes #11984, closes #11075.

Release Notes:

- Added the ability to customize available models for OpenAI-compatible
services ([#11984](https://github.com/zed-industries/zed/issues/11984))
([#11075](https://github.com/zed-industries/zed/issues/11075)).


![image](https://github.com/zed-industries/zed/assets/32017007/01057e7b-1f21-49ad-a3ad-abc5282ffaf0)
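A sketch of how this might look in `settings.json` (the `api_url` and model names are hypothetical; this assumes the `provider` object is tagged by `name`, and uses the externally tagged serde form of `Model::Custom` shown in `open_ai.rs` below):

```json
{
  "assistant": {
    "provider": {
      "name": "openai",
      "api_url": "http://localhost:1234/v1",
      "available_models": [
        "gpt-4o",
        { "custom": { "name": "mistral-7b-instruct", "max_tokens": 32000 } }
      ]
    }
  }
}
```

If `available_models` is empty or omitted, behavior is unchanged from before this PR, as described in the detailed changes below.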

Change summary

crates/assistant/src/assistant_settings.rs          | 14 +++++++
crates/assistant/src/completion_provider.rs         | 27 ++++++++++++--
crates/assistant/src/completion_provider/open_ai.rs | 26 ++++++++++++-
crates/assistant/src/inline_assistant.rs            |  3 +
crates/assistant/src/model_selector.rs              |  2 
crates/open_ai/src/open_ai.rs                       | 18 +++++++++
6 files changed, 79 insertions(+), 11 deletions(-)

Detailed changes

crates/assistant/src/assistant_settings.rs 🔗

@@ -169,6 +169,7 @@ pub enum AssistantProvider {
         model: OpenAiModel,
         api_url: String,
         low_speed_timeout_in_seconds: Option<u64>,
+        available_models: Vec<OpenAiModel>,
     },
     Anthropic {
         model: AnthropicModel,
@@ -188,6 +189,7 @@ impl Default for AssistantProvider {
             model: OpenAiModel::default(),
             api_url: open_ai::OPEN_AI_API_URL.into(),
             low_speed_timeout_in_seconds: None,
+            available_models: Default::default(),
         }
     }
 }
@@ -202,6 +204,7 @@ pub enum AssistantProviderContent {
         default_model: Option<OpenAiModel>,
         api_url: Option<String>,
         low_speed_timeout_in_seconds: Option<u64>,
+        available_models: Option<Vec<OpenAiModel>>,
     },
     #[serde(rename = "anthropic")]
     Anthropic {
@@ -272,6 +275,7 @@ impl AssistantSettingsContent {
                         default_model: settings.default_open_ai_model.clone(),
                         api_url: Some(open_ai_api_url.clone()),
                         low_speed_timeout_in_seconds: None,
+                        available_models: Some(Default::default()),
                     })
                 } else {
                     settings.default_open_ai_model.clone().map(|open_ai_model| {
@@ -279,6 +283,7 @@ impl AssistantSettingsContent {
                             default_model: Some(open_ai_model),
                             api_url: None,
                             low_speed_timeout_in_seconds: None,
+                            available_models: Some(Default::default()),
                         }
                     })
                 },
@@ -345,6 +350,7 @@ impl AssistantSettingsContent {
                                 default_model: Some(model),
                                 api_url: None,
                                 low_speed_timeout_in_seconds: None,
+                                available_models: Some(Default::default()),
                             })
                         }
                         LanguageModel::Anthropic(model) => {
@@ -489,15 +495,18 @@ impl Settings for AssistantSettings {
                             model,
                             api_url,
                             low_speed_timeout_in_seconds,
+                            available_models,
                         },
                         AssistantProviderContent::OpenAi {
                             default_model: model_override,
                             api_url: api_url_override,
                             low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
+                            available_models: available_models_override,
                         },
                     ) => {
                         merge(model, model_override);
                         merge(api_url, api_url_override);
+                        merge(available_models, available_models_override);
                         if let Some(low_speed_timeout_in_seconds_override) =
                             low_speed_timeout_in_seconds_override
                         {
@@ -558,10 +567,12 @@ impl Settings for AssistantSettings {
                                 default_model: model,
                                 api_url,
                                 low_speed_timeout_in_seconds,
+                                available_models,
                             } => AssistantProvider::OpenAi {
                                 model: model.unwrap_or_default(),
                                 api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                 low_speed_timeout_in_seconds,
+                                available_models: available_models.unwrap_or_default(),
                             },
                             AssistantProviderContent::Anthropic {
                                 default_model: model,
@@ -618,6 +629,7 @@ mod tests {
                 model: OpenAiModel::FourOmni,
                 api_url: open_ai::OPEN_AI_API_URL.into(),
                 low_speed_timeout_in_seconds: None,
+                available_models: Default::default(),
             }
         );
 
@@ -640,6 +652,7 @@ mod tests {
                 model: OpenAiModel::FourOmni,
                 api_url: "test-url".into(),
                 low_speed_timeout_in_seconds: None,
+                available_models: Default::default(),
             }
         );
         SettingsStore::update_global(cx, |store, cx| {
@@ -660,6 +673,7 @@ mod tests {
                 model: OpenAiModel::Four,
                 api_url: open_ai::OPEN_AI_API_URL.into(),
                 low_speed_timeout_in_seconds: None,
+                available_models: Default::default(),
             }
         );
 

crates/assistant/src/completion_provider.rs 🔗

@@ -24,6 +24,20 @@ use settings::{Settings, SettingsStore};
 use std::sync::Arc;
 use std::time::Duration;
 
+/// Choose which model to use for the OpenAI provider.
+/// If the model is not in the available list, use the first available model; if the list is empty, fall back to the original model.
+fn choose_openai_model(
+    model: &::open_ai::Model,
+    available_models: &[::open_ai::Model],
+) -> ::open_ai::Model {
+    available_models
+        .iter()
+        .find(|&m| m == model)
+        .or_else(|| available_models.first())
+        .unwrap_or_else(|| model)
+        .clone()
+}
+
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     let mut settings_version = 0;
     let provider = match &AssistantSettings::get_global(cx).provider {
@@ -34,8 +48,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             model,
             api_url,
             low_speed_timeout_in_seconds,
+            available_models,
         } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-            model.clone(),
+            choose_openai_model(model, available_models),
             api_url.clone(),
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -77,10 +92,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         model,
                         api_url,
                         low_speed_timeout_in_seconds,
+                        available_models,
                     },
                 ) => {
                     provider.update(
-                        model.clone(),
+                        choose_openai_model(model, available_models),
                         api_url.clone(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
                         settings_version,
@@ -136,10 +152,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         model,
                         api_url,
                         low_speed_timeout_in_seconds,
+                        available_models,
                     },
                 ) => {
                     *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-                        model.clone(),
+                        choose_openai_model(model, available_models),
                         api_url.clone(),
                         client.http_client(),
                         low_speed_timeout_in_seconds.map(Duration::from_secs),
@@ -201,10 +218,10 @@ impl CompletionProvider {
         cx.global::<Self>()
     }
 
-    pub fn available_models(&self) -> Vec<LanguageModel> {
+    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
         match self {
             CompletionProvider::OpenAi(provider) => provider
-                .available_models()
+                .available_models(cx)
                 .map(LanguageModel::OpenAi)
                 .collect(),
             CompletionProvider::Anthropic(provider) => provider

crates/assistant/src/completion_provider/open_ai.rs 🔗

@@ -1,4 +1,5 @@
 use crate::assistant_settings::CloudModel;
+use crate::assistant_settings::{AssistantProvider, AssistantSettings};
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -56,8 +57,26 @@ impl OpenAiCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
-        OpenAiModel::iter()
+    pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
+        if let AssistantProvider::OpenAi {
+            available_models, ..
+        } = &AssistantSettings::get_global(cx).provider
+        {
+            if !available_models.is_empty() {
+                // available_models is set, just return it
+                return available_models.clone().into_iter();
+            }
+        }
+        let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
+            // available_models is not set but the default model is custom; show only that model
+            vec![self.model.clone()]
+        } else {
+            // default case: use all built-in models, excluding Custom
+            OpenAiModel::iter()
+                .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
+                .collect()
+        };
+        available_models.into_iter()
     }
 
     pub fn settings_version(&self) -> usize {
@@ -213,7 +232,8 @@ pub fn count_open_ai_tokens(
                 | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
                 | LanguageModel::Cloud(CloudModel::Claude3Opus)
                 | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
-                | LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
+                | LanguageModel::Cloud(CloudModel::Claude3Haiku)
+                | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
                     // Tiktoken doesn't yet support these models, so we manually use the
                     // same tokenizer as GPT-4.
                     tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
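To summarize my reading of the new `available_models` logic above: if `available_models` is set in the settings, the picker shows exactly that list; if it is unset and the default model is `Custom`, only that custom model is shown; otherwise all built-in models are shown and `Custom` is filtered out.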

crates/assistant/src/inline_assistant.rs 🔗

@@ -1298,7 +1298,8 @@ impl Render for PromptEditor {
                         PopoverMenu::new("model-switcher")
                             .menu(move |cx| {
                                 ContextMenu::build(cx, |mut menu, cx| {
-                                    for model in CompletionProvider::global(cx).available_models() {
+                                    for model in CompletionProvider::global(cx).available_models(cx)
+                                    {
                                         menu = menu.custom_entry(
                                             {
                                                 let model = model.clone();

crates/assistant/src/model_selector.rs 🔗

@@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector {
             .with_handle(self.handle)
             .menu(move |cx| {
                 ContextMenu::build(cx, |mut menu, cx| {
-                    for model in CompletionProvider::global(cx).available_models() {
+                    for model in CompletionProvider::global(cx).available_models(cx) {
                         menu = menu.custom_entry(
                             {
                                 let model = model.clone();

crates/open_ai/src/open_ai.rs 🔗

@@ -55,6 +55,8 @@ pub enum Model {
     #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
     #[default]
     FourOmni,
+    #[serde(rename = "custom")]
+    Custom { name: String, max_tokens: usize },
 }
 
 impl Model {
@@ -74,15 +76,17 @@ impl Model {
             Self::Four => "gpt-4",
             Self::FourTurbo => "gpt-4-turbo-preview",
             Self::FourOmni => "gpt-4o",
+            Self::Custom { .. } => "custom",
         }
     }
 
-    pub fn display_name(&self) -> &'static str {
+    pub fn display_name(&self) -> &str {
         match self {
             Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
             Self::Four => "gpt-4",
             Self::FourTurbo => "gpt-4-turbo",
             Self::FourOmni => "gpt-4o",
+            Self::Custom { name, .. } => name,
         }
     }
 
@@ -92,12 +96,24 @@ impl Model {
             Model::Four => 8192,
             Model::FourTurbo => 128000,
             Model::FourOmni => 128000,
+            Model::Custom { max_tokens, .. } => *max_tokens,
         }
     }
 }
 
+fn serialize_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: serde::Serializer,
+{
+    match model {
+        Model::Custom { name, .. } => serializer.serialize_str(name),
+        _ => serializer.serialize_str(model.id()),
+    }
+}
+
 #[derive(Debug, Serialize)]
 pub struct Request {
+    #[serde(serialize_with = "serialize_model")]
     pub model: Model,
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
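One consequence worth calling out: with `serialize_with = "serialize_model"`, a request built with a `Custom` model carries the user-supplied name on the wire instead of the literal `"custom"` id, so an OpenAI-compatible server sees roughly the following (hypothetical model name; message shape illustrative, remaining `Request` fields omitted):

```json
{
  "model": "mistral-7b-instruct",
  "messages": [{ "role": "user", "content": "Hello" }],
  "stream": true
}
```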