Add support for gpt-4o when using zed.dev as the model provider (#11794)

Created by Antonio Scandurra

Release Notes:

- N/A

Change summary

crates/assistant/src/assistant_panel.rs             |  5 +++--
crates/assistant/src/assistant_settings.rs          |  9 +++++++--
crates/assistant/src/completion_provider/open_ai.rs | 11 ++++++++++-
crates/assistant/src/completion_provider/zed.rs     |  1 +
4 files changed, 21 insertions(+), 5 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs

@@ -803,12 +803,13 @@ impl AssistantPanel {
             LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
                 ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
                 ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
-                ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus,
+                ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni,
+                ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus,
                 ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
                 ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
                 ZedDotDevModel::Claude3Haiku => {
                     match CompletionProvider::global(cx).default_model() {
-                        LanguageModel::ZedDotDev(custom) => custom,
+                        LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom,
                         _ => ZedDotDevModel::Gpt3Point5Turbo,
                     }
                 }
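
For reference, the cycling this hunk adjusts can be read as a simple "advance to the next model" step, with gpt-4o slotted between GPT-4 Turbo and Claude 3 Opus. A trimmed-down sketch under that reading (the enum and next method below are illustrative stand-ins, not the panel's actual types):

    // Illustrative stand-in types, not the panel's actual ones.
    #[derive(Clone, Debug, PartialEq)]
    enum CycledModel {
        Gpt3Point5Turbo,
        Gpt4,
        Gpt4Turbo,
        Gpt4Omni,
        Claude3Opus,
    }

    impl CycledModel {
        // Advance to the next model; the real panel keeps going through the
        // Claude models and then wraps back to the configured default.
        fn next(&self) -> CycledModel {
            match self {
                CycledModel::Gpt3Point5Turbo => CycledModel::Gpt4,
                CycledModel::Gpt4 => CycledModel::Gpt4Turbo,
                CycledModel::Gpt4Turbo => CycledModel::Gpt4Omni,
                CycledModel::Gpt4Omni => CycledModel::Claude3Opus,
                CycledModel::Claude3Opus => CycledModel::Gpt3Point5Turbo,
            }
        }
    }

The tightened custom @ ZedDotDevModel::Custom(_) pattern also means the wrap-around step only jumps to the configured default model when that default is a custom one; for any built-in default it now wraps back to gpt-3.5-turbo.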

crates/assistant/src/assistant_settings.rs

@@ -16,8 +16,9 @@ use settings::{Settings, SettingsSources};
 pub enum ZedDotDevModel {
     Gpt3Point5Turbo,
     Gpt4,
-    #[default]
     Gpt4Turbo,
+    #[default]
+    Gpt4Omni,
     Claude3Opus,
     Claude3Sonnet,
     Claude3Haiku,
@@ -55,6 +56,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
                     "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
                     "gpt-4" => Ok(ZedDotDevModel::Gpt4),
                     "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
+                    "gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni),
                     _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
                 }
             }
@@ -74,6 +76,7 @@ impl JsonSchema for ZedDotDevModel {
             "gpt-3.5-turbo".to_owned(),
             "gpt-4".to_owned(),
             "gpt-4-turbo-preview".to_owned(),
+            "gpt-4o".to_owned(),
         ];
         Schema::Object(SchemaObject {
             instance_type: Some(InstanceType::String.into()),
@@ -100,6 +103,7 @@ impl ZedDotDevModel {
             Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
             Self::Gpt4 => "gpt-4",
             Self::Gpt4Turbo => "gpt-4-turbo-preview",
+            Self::Gpt4Omni => "gpt-4o",
             Self::Claude3Opus => "claude-3-opus",
             Self::Claude3Sonnet => "claude-3-sonnet",
             Self::Claude3Haiku => "claude-3-haiku",
@@ -112,6 +116,7 @@ impl ZedDotDevModel {
             Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
             Self::Gpt4 => "GPT 4",
             Self::Gpt4Turbo => "GPT 4 Turbo",
+            Self::Gpt4Omni => "GPT 4 Omni",
             Self::Claude3Opus => "Claude 3 Opus",
             Self::Claude3Sonnet => "Claude 3 Sonnet",
             Self::Claude3Haiku => "Claude 3 Haiku",
@@ -123,7 +128,7 @@ impl ZedDotDevModel {
         match self {
             Self::Gpt3Point5Turbo => 2048,
             Self::Gpt4 => 4096,
-            Self::Gpt4Turbo => 128000,
+            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
             Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
             Self::Custom(_) => 4096, // TODO: Make this configurable
         }
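
The deserializer change follows the existing pattern in this file: known model ids map to dedicated variants, and any other string falls back to the Custom variant. A minimal, self-contained sketch of that pattern, assuming nothing beyond serde (the type and visitor names here are illustrative, not the crate's actual ones):

    use std::fmt;

    use serde::de::{self, Deserializer, Visitor};

    #[derive(Debug, PartialEq)]
    enum ModelId {
        Gpt4Omni,
        Custom(String),
    }

    impl<'de> serde::Deserialize<'de> for ModelId {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct IdVisitor;

            impl<'de> Visitor<'de> for IdVisitor {
                type Value = ModelId;

                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                    f.write_str("a model identifier string")
                }

                // Known ids become dedicated variants; everything else is Custom.
                fn visit_str<E: de::Error>(self, value: &str) -> Result<ModelId, E> {
                    Ok(match value {
                        "gpt-4o" => ModelId::Gpt4Omni,
                        other => ModelId::Custom(other.to_owned()),
                    })
                }
            }

            deserializer.deserialize_str(IdVisitor)
        }
    }

With this shape, the string "gpt-4o" deserializes to the new variant, while unrecognized ids still round-trip as Custom, which keeps user-specified custom models working.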

crates/assistant/src/completion_provider/open_ai.rs

@@ -1,3 +1,4 @@
+use crate::assistant_settings::ZedDotDevModel;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -202,7 +203,15 @@ pub fn count_open_ai_tokens(
                 })
                 .collect::<Vec<_>>();
 
-            tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
+            match request.model {
+                LanguageModel::OpenAi(OpenAiModel::FourOmni)
+                | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) => {
+                    // Tiktoken doesn't yet support gpt-4o, so we manually use the
+                    // same tokenizer as GPT-4.
+                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
+                }
+                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
+            }
         })
         .boxed()
 }
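
The fallback in count_open_ai_tokens can be pictured with the rough sketch below, assuming the tiktoken_rs and anyhow crates. The actual change counts chat messages via tiktoken_rs::num_tokens_from_messages; this sketch counts raw text for brevity, and tokenizer_model_id and count_tokens are hypothetical helpers, not part of this PR. Counts for gpt-4o are only estimates, since the GPT-4 encoding is substituted, as the comment in the hunk notes.

    // Hypothetical helpers, assuming the tiktoken_rs and anyhow crates.
    fn tokenizer_model_id(model_id: &str) -> &str {
        match model_id {
            // Approximation: reuse the GPT-4 tokenizer until tiktoken_rs has a
            // dedicated gpt-4o entry.
            "gpt-4o" => "gpt-4",
            other => other,
        }
    }

    fn count_tokens(model_id: &str, text: &str) -> anyhow::Result<usize> {
        let bpe = tiktoken_rs::get_bpe_from_model(tokenizer_model_id(model_id))?;
        Ok(bpe.encode_with_special_tokens(text).len())
    }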

crates/assistant/src/completion_provider/zed.rs

@@ -81,6 +81,7 @@ impl ZedDotDevCompletionProvider {
             LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
             LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
             | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
+            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
             | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
                 count_open_ai_tokens(request, cx.background_executor())
             }