diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index bfddba0e2f8a41e3ed234b21ee52454d104c9dd2..9b9d6e19b8de86fd0ee7e6fe6bf57d6d91da19da 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -164,6 +164,8 @@ pub enum ModelVendor { OpenAI, Google, Anthropic, + #[serde(rename = "xAI")] + XAI, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index d48c12aa4b5de713c0130320f7c9e61a733dc33e..bd284eb72b207dee90048f06dc44a8e21ae8d34f 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -32,6 +32,8 @@ use std::time::Duration; use ui::prelude::*; use util::debug_panic; +use crate::provider::x_ai::count_xai_tokens; + use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; use super::open_ai::count_open_ai_tokens; @@ -228,7 +230,9 @@ impl LanguageModel for CopilotChatLanguageModel { ModelVendor::OpenAI | ModelVendor::Anthropic => { LanguageModelToolSchemaFormat::JsonSchema } - ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset, + ModelVendor::Google | ModelVendor::XAI => { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } } } @@ -256,6 +260,10 @@ impl LanguageModel for CopilotChatLanguageModel { match self.model.vendor() { ModelVendor::Anthropic => count_anthropic_tokens(request, cx), ModelVendor::Google => count_google_tokens(request, cx), + ModelVendor::XAI => { + let model = x_ai::Model::from_id(self.model.id()).unwrap_or_default(); + count_xai_tokens(request, model, cx) + } ModelVendor::OpenAI => { let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default(); count_open_ai_tokens(request, model, cx)