assistant: Make it easier to define custom models (#15442)

Authored by Bennet Bo Fenner and Thorsten

This PR makes it easier to specify custom models for the Google, OpenAI,
and Anthropic providers:

Before (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-google-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (google):

```json
{
  "language_models": {
    "google": {
      "available_models": [
        {
          "name": "my-custom-google-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}
```

Before (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "available_models": [
        {
          "custom": {
            "name": "my-custom-anthropic-model",
            "max_tokens": 12345
          }
        }
      ]
    }
  }
}
```

After (anthropic):

```json
{
  "language_models": {
    "anthropic": {
      "version": "1",
      "available_models": [
        {
          "name": "my-custom-anthropic-model",
          "max_tokens": 12345
        }
      ]
    }
  }
}
```

The settings will be auto-upgraded, so the old format will continue to
work (except for Google, since that provider has not been released yet).

/cc @as-cii 

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>

Change summary

assets/settings/default.json                    |   8 
crates/assistant/src/assistant_settings.rs      |  40 +++-
crates/language_model/Cargo.toml                |   1 
crates/language_model/src/language_model.rs     |   5 
crates/language_model/src/provider/anthropic.rs |  19 +
crates/language_model/src/provider/google.rs    |  18 +
crates/language_model/src/provider/open_ai.rs   |  19 +
crates/language_model/src/settings.rs           | 189 ++++++++++++++++--
crates/zed/src/main.rs                          |   2 
crates/zed/src/zed.rs                           |   2 
10 files changed, 256 insertions(+), 47 deletions(-)

Detailed changes

assets/settings/default.json 🔗

@@ -865,16 +865,18 @@
   // Different settings for specific language models.
   "language_models": {
     "anthropic": {
+      "version": "1",
       "api_url": "https://api.anthropic.com"
     },
-    "openai": {
-      "api_url": "https://api.openai.com/v1"
-    },
     "google": {
       "api_url": "https://generativelanguage.googleapis.com"
     },
     "ollama": {
       "api_url": "http://localhost:11434"
+    },
+    "openai": {
+      "version": "1",
+      "api_url": "https://api.openai.com/v1"
     }
   },
   // Zed's Prettier integration settings.

crates/assistant/src/assistant_settings.rs 🔗

@@ -110,11 +110,15 @@ impl AssistantSettingsContent {
                             move |content, _| {
                                 if content.anthropic.is_none() {
                                     content.anthropic =
-                                        Some(language_model::settings::AnthropicSettingsContent {
-                                            api_url,
-                                            low_speed_timeout_in_seconds,
-                                            ..Default::default()
-                                        });
+                                        Some(language_model::settings::AnthropicSettingsContent::Versioned(
+                                            language_model::settings::VersionedAnthropicSettingsContent::V1(
+                                                language_model::settings::AnthropicSettingsContentV1 {
+                                                    api_url,
+                                                    low_speed_timeout_in_seconds,
+                                                    available_models: None
+                                                }
+                                            )
+                                        ));
                                 }
                             },
                         ),
@@ -145,12 +149,27 @@ impl AssistantSettingsContent {
                             cx,
                             move |content, _| {
                                 if content.openai.is_none() {
+                                    let available_models = available_models.map(|models| {
+                                        models
+                                            .into_iter()
+                                            .filter_map(|model| match model {
+                                                open_ai::Model::Custom { name, max_tokens } => {
+                                                    Some(language_model::provider::open_ai::AvailableModel { name, max_tokens })
+                                                }
+                                                _ => None,
+                                            })
+                                            .collect::<Vec<_>>()
+                                    });
                                     content.openai =
-                                        Some(language_model::settings::OpenAiSettingsContent {
-                                            api_url,
-                                            low_speed_timeout_in_seconds,
-                                            available_models,
-                                        });
+                                        Some(language_model::settings::OpenAiSettingsContent::Versioned(
+                                            language_model::settings::VersionedOpenAiSettingsContent::V1(
+                                                language_model::settings::OpenAiSettingsContentV1 {
+                                                    api_url,
+                                                    low_speed_timeout_in_seconds,
+                                                    available_models
+                                                }
+                                            )
+                                        ));
                                 }
                             },
                         ),
@@ -377,6 +396,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
     schemars::schema::SchemaObject {
         enum_values: Some(vec![
             "anthropic".into(),
+            "google".into(),
             "ollama".into(),
             "openai".into(),
             "zed.dev".into(),

crates/language_model/Cargo.toml 🔗

@@ -37,6 +37,7 @@ menu.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 proto = { workspace = true, features = ["test-support"] }
+project.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true

crates/language_model/src/language_model.rs 🔗

@@ -13,14 +13,15 @@ use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
 
 pub use model::*;
+use project::Fs;
 pub use registry::*;
 pub use request::*;
 pub use role::*;
 use schemars::JsonSchema;
 use serde::de::DeserializeOwned;
 
-pub fn init(client: Arc<Client>, cx: &mut AppContext) {
-    settings::init(cx);
+pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
+    settings::init(fs, cx);
     registry::init(client, cx);
 }
 

crates/language_model/src/provider/anthropic.rs 🔗

@@ -12,6 +12,8 @@ use gpui::{
     WhiteSpace,
 };
 use http_client::HttpClient;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
 use std::{sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
@@ -26,7 +28,14 @@ const PROVIDER_NAME: &str = "Anthropic";
 pub struct AnthropicSettings {
     pub api_url: String,
     pub low_speed_timeout: Option<Duration>,
-    pub available_models: Vec<anthropic::Model>,
+    pub available_models: Vec<AvailableModel>,
+    pub needs_setting_migration: bool,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    pub name: String,
+    pub max_tokens: usize,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -84,7 +93,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
             .available_models
             .iter()
         {
-            models.insert(model.id().to_string(), model.clone());
+            models.insert(
+                model.name.clone(),
+                anthropic::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                },
+            );
         }
 
         models

crates/language_model/src/provider/google.rs 🔗

@@ -8,6 +8,8 @@ use gpui::{
     WhiteSpace,
 };
 use http_client::HttpClient;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
 use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
@@ -28,7 +30,13 @@ const PROVIDER_NAME: &str = "Google AI";
 pub struct GoogleSettings {
     pub api_url: String,
     pub low_speed_timeout: Option<Duration>,
-    pub available_models: Vec<google_ai::Model>,
+    pub available_models: Vec<AvailableModel>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    name: String,
+    max_tokens: usize,
 }
 
 pub struct GoogleLanguageModelProvider {
@@ -86,7 +94,13 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
             .google
             .available_models
         {
-            models.insert(model.id().to_string(), model.clone());
+            models.insert(
+                model.name.clone(),
+                google_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                },
+            );
         }
 
         models

crates/language_model/src/provider/open_ai.rs 🔗

@@ -8,6 +8,8 @@ use gpui::{
 };
 use http_client::HttpClient;
 use open_ai::stream_completion;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
 use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
@@ -28,7 +30,14 @@ const PROVIDER_NAME: &str = "OpenAI";
 pub struct OpenAiSettings {
     pub api_url: String,
     pub low_speed_timeout: Option<Duration>,
-    pub available_models: Vec<open_ai::Model>,
+    pub available_models: Vec<AvailableModel>,
+    pub needs_setting_migration: bool,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    pub name: String,
+    pub max_tokens: usize,
 }
 
 pub struct OpenAiLanguageModelProvider {
@@ -86,7 +95,13 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
             .openai
             .available_models
         {
-            models.insert(model.id().to_string(), model.clone());
+            models.insert(
+                model.name.clone(),
+                open_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                },
+            );
         }
 
         models

crates/language_model/src/settings.rs 🔗

@@ -1,12 +1,14 @@
-use std::time::Duration;
+use std::{sync::Arc, time::Duration};
 
 use anyhow::Result;
 use gpui::AppContext;
+use project::Fs;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use settings::{Settings, SettingsSources};
+use settings::{update_settings_file, Settings, SettingsSources};
 
 use crate::provider::{
+    self,
     anthropic::AnthropicSettings,
     cloud::{self, ZedDotDevSettings},
     copilot_chat::CopilotChatSettings,
@@ -16,8 +18,36 @@ use crate::provider::{
 };
 
 /// Initializes the language model settings.
-pub fn init(cx: &mut AppContext) {
+pub fn init(fs: Arc<dyn Fs>, cx: &mut AppContext) {
     AllLanguageModelSettings::register(cx);
+
+    if AllLanguageModelSettings::get_global(cx)
+        .openai
+        .needs_setting_migration
+    {
+        update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, move |setting, _| {
+            if let Some(settings) = setting.openai.clone() {
+                let (newest_version, _) = settings.upgrade();
+                setting.openai = Some(OpenAiSettingsContent::Versioned(
+                    VersionedOpenAiSettingsContent::V1(newest_version),
+                ));
+            }
+        });
+    }
+
+    if AllLanguageModelSettings::get_global(cx)
+        .anthropic
+        .needs_setting_migration
+    {
+        update_settings_file::<AllLanguageModelSettings>(fs, cx, move |setting, _| {
+            if let Some(settings) = setting.anthropic.clone() {
+                let (newest_version, _) = settings.upgrade();
+                setting.anthropic = Some(AnthropicSettingsContent::Versioned(
+                    VersionedAnthropicSettingsContent::V1(newest_version),
+                ));
+            }
+        });
+    }
 }
 
 #[derive(Default)]
@@ -41,31 +71,129 @@ pub struct AllLanguageModelSettingsContent {
     pub copilot_chat: Option<CopilotChatSettingsContent>,
 }
 
-#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
-pub struct AnthropicSettingsContent {
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+#[serde(untagged)]
+pub enum AnthropicSettingsContent {
+    Legacy(LegacyAnthropicSettingsContent),
+    Versioned(VersionedAnthropicSettingsContent),
+}
+
+impl AnthropicSettingsContent {
+    pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) {
+        match self {
+            AnthropicSettingsContent::Legacy(content) => (
+                AnthropicSettingsContentV1 {
+                    api_url: content.api_url,
+                    low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
+                    available_models: content.available_models.map(|models| {
+                        models
+                            .into_iter()
+                            .filter_map(|model| match model {
+                                anthropic::Model::Custom { name, max_tokens } => {
+                                    Some(provider::anthropic::AvailableModel { name, max_tokens })
+                                }
+                                _ => None,
+                            })
+                            .collect()
+                    }),
+                },
+                true,
+            ),
+            AnthropicSettingsContent::Versioned(content) => match content {
+                VersionedAnthropicSettingsContent::V1(content) => (content, false),
+            },
+        }
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct LegacyAnthropicSettingsContent {
     pub api_url: Option<String>,
     pub low_speed_timeout_in_seconds: Option<u64>,
     pub available_models: Option<Vec<anthropic::Model>>,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+#[serde(tag = "version")]
+pub enum VersionedAnthropicSettingsContent {
+    #[serde(rename = "1")]
+    V1(AnthropicSettingsContentV1),
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct AnthropicSettingsContentV1 {
+    pub api_url: Option<String>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
+    pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
+}
+
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct OllamaSettingsContent {
     pub api_url: Option<String>,
     pub low_speed_timeout_in_seconds: Option<u64>,
 }
 
-#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
-pub struct OpenAiSettingsContent {
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+#[serde(untagged)]
+pub enum OpenAiSettingsContent {
+    Legacy(LegacyOpenAiSettingsContent),
+    Versioned(VersionedOpenAiSettingsContent),
+}
+
+impl OpenAiSettingsContent {
+    pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) {
+        match self {
+            OpenAiSettingsContent::Legacy(content) => (
+                OpenAiSettingsContentV1 {
+                    api_url: content.api_url,
+                    low_speed_timeout_in_seconds: content.low_speed_timeout_in_seconds,
+                    available_models: content.available_models.map(|models| {
+                        models
+                            .into_iter()
+                            .filter_map(|model| match model {
+                                open_ai::Model::Custom { name, max_tokens } => {
+                                    Some(provider::open_ai::AvailableModel { name, max_tokens })
+                                }
+                                _ => None,
+                            })
+                            .collect()
+                    }),
+                },
+                true,
+            ),
+            OpenAiSettingsContent::Versioned(content) => match content {
+                VersionedOpenAiSettingsContent::V1(content) => (content, false),
+            },
+        }
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct LegacyOpenAiSettingsContent {
     pub api_url: Option<String>,
     pub low_speed_timeout_in_seconds: Option<u64>,
     pub available_models: Option<Vec<open_ai::Model>>,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+#[serde(tag = "version")]
+pub enum VersionedOpenAiSettingsContent {
+    #[serde(rename = "1")]
+    V1(OpenAiSettingsContentV1),
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct OpenAiSettingsContentV1 {
+    pub api_url: Option<String>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
+    pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
+}
+
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct GoogleSettingsContent {
     pub api_url: Option<String>,
     pub low_speed_timeout_in_seconds: Option<u64>,
-    pub available_models: Option<Vec<google_ai::Model>>,
+    pub available_models: Option<Vec<provider::google::AvailableModel>>,
 }
 
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -81,6 +209,8 @@ pub struct CopilotChatSettingsContent {
 impl settings::Settings for AllLanguageModelSettings {
     const KEY: Option<&'static str> = Some("language_models");
 
+    const PRESERVED_KEYS: Option<&'static [&'static str]> = Some(&["version"]);
+
     type FileContent = AllLanguageModelSettingsContent;
 
     fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
@@ -93,12 +223,21 @@ impl settings::Settings for AllLanguageModelSettings {
         let mut settings = AllLanguageModelSettings::default();
 
         for value in sources.defaults_and_customizations() {
+            // Anthropic
+            let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) {
+                Some((content, upgraded)) => (Some(content), upgraded),
+                None => (None, false),
+            };
+
+            if upgraded {
+                settings.anthropic.needs_setting_migration = true;
+            }
+
             merge(
                 &mut settings.anthropic.api_url,
-                value.anthropic.as_ref().and_then(|s| s.api_url.clone()),
+                anthropic.as_ref().and_then(|s| s.api_url.clone()),
             );
-            if let Some(low_speed_timeout_in_seconds) = value
-                .anthropic
+            if let Some(low_speed_timeout_in_seconds) = anthropic
                 .as_ref()
                 .and_then(|s| s.low_speed_timeout_in_seconds)
             {
@@ -107,10 +246,7 @@ impl settings::Settings for AllLanguageModelSettings {
             }
             merge(
                 &mut settings.anthropic.available_models,
-                value
-                    .anthropic
-                    .as_ref()
-                    .and_then(|s| s.available_models.clone()),
+                anthropic.as_ref().and_then(|s| s.available_models.clone()),
             );
 
             merge(
@@ -126,24 +262,29 @@ impl settings::Settings for AllLanguageModelSettings {
                     Some(Duration::from_secs(low_speed_timeout_in_seconds));
             }
 
+            // OpenAI
+            let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
+                Some((content, upgraded)) => (Some(content), upgraded),
+                None => (None, false),
+            };
+
+            if upgraded {
+                settings.openai.needs_setting_migration = true;
+            }
+
             merge(
                 &mut settings.openai.api_url,
-                value.openai.as_ref().and_then(|s| s.api_url.clone()),
+                openai.as_ref().and_then(|s| s.api_url.clone()),
             );
-            if let Some(low_speed_timeout_in_seconds) = value
-                .openai
-                .as_ref()
-                .and_then(|s| s.low_speed_timeout_in_seconds)
+            if let Some(low_speed_timeout_in_seconds) =
+                openai.as_ref().and_then(|s| s.low_speed_timeout_in_seconds)
             {
                 settings.openai.low_speed_timeout =
                     Some(Duration::from_secs(low_speed_timeout_in_seconds));
             }
             merge(
                 &mut settings.openai.available_models,
-                value
-                    .openai
-                    .as_ref()
-                    .and_then(|s| s.available_models.clone()),
+                openai.as_ref().and_then(|s| s.available_models.clone()),
             );
 
             merge(

crates/zed/src/main.rs 🔗

@@ -174,7 +174,7 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) {
         cx,
     );
     supermaven::init(app_state.client.clone(), cx);
-    language_model::init(app_state.client.clone(), cx);
+    language_model::init(app_state.client.clone(), app_state.fs.clone(), cx);
     snippet_provider::init(cx);
     inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
     assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);

crates/zed/src/zed.rs 🔗

@@ -3461,7 +3461,7 @@ mod tests {
                 app_state.client.http_client().clone(),
                 cx,
             );
-            language_model::init(app_state.client.clone(), cx);
+            language_model::init(app_state.client.clone(), app_state.fs.clone(), cx);
             assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
             repl::init(
                 app_state.fs.clone(),