From 2e36e9782ee2235e97d765387ceb0fed57a406af Mon Sep 17 00:00:00 2001
From: Umesh Yadav <23421535+imumesh18@users.noreply.github.com>
Date: Wed, 10 Sep 2025 00:58:26 +0530
Subject: [PATCH] language_models: Make Copilot Chat resilient to new model
 vendors and add tokenizer-based token counting (#37118)

While working on a fix for #37116, I realized the current GitHub Copilot
implementation is not truly resilient to upstream changes.

This PR makes GitHub Copilot Chat forward-compatible with new AI model
vendors and improves token-counting accuracy by using the vendor-specific
tokenizers reported by the GitHub Copilot API. Previously, deserialization
failed whenever GitHub added a new model vendor (such as xAI), and token
counting ignored the tokenizer information the API already provides.

The solution:

- adds an `Unknown` variant to the `ModelVendor` enum, marked
  `#[serde(other)]`, so any new vendor GitHub introduces is handled
  gracefully;
- implements tokenizer-aware token counting that uses the model's declared
  tokenizer, mapping `o200k_base` to `gpt-4o` and `cl100k_base` to `gpt-4`,
  with `gpt-4o` as the fallback;
- adds explicit support for xAI models with the proper tool input format;
- adds test coverage for the unknown-vendor scenario.

Key changes: a `tokenizer` field on the model capabilities, a `tokenizer()`
method on `Model`, tool input format logic that handles unknown vendors, and
simplified token counting that uses the vendor's declared tokenizer or falls
back to `gpt-4o`. This keeps Zed's Copilot Chat integration robust and
accurate as GitHub continues to expand its model provider ecosystem.

Release Notes:

- Enhanced model vendor compatibility to automatically support future AI
  providers, and improved token-counting accuracy using vendor-specific
  tokenizers from the GitHub Copilot API.

---------

Signed-off-by: Umesh Yadav
---
 crates/copilot/src/copilot_chat.rs               | 50 +++++++++++++++++
 .../src/provider/copilot_chat.rs                 | 53 ++++++++++++-------
 2 files changed, 84 insertions(+), 19 deletions(-)

diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index 9b9d6e19b8de86fd0ee7e6fe6bf57d6d91da19da..ccd8f09613eec54f2d30b619f142d111bf2a3497 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -128,6 +128,8 @@ struct ModelCapabilities {
     supports: ModelSupportedFeatures,
     #[serde(rename = "type")]
     model_type: String,
+    #[serde(default)]
+    tokenizer: Option<String>,
 }
 
 #[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -166,6 +168,9 @@ pub enum ModelVendor {
     Anthropic,
     #[serde(rename = "xAI")]
     XAI,
+    /// Unknown vendor that we don't explicitly support yet
+    #[serde(other)]
+    Unknown,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
@@ -214,6 +219,10 @@ impl Model {
     pub fn supports_parallel_tool_calls(&self) -> bool {
         self.capabilities.supports.parallel_tool_calls
     }
+
+    pub fn tokenizer(&self) -> Option<&str> {
+        self.capabilities.tokenizer.as_deref()
+    }
 }
 
 #[derive(Serialize, Deserialize)]
@@ -901,4 +910,45 @@ mod tests {
         assert_eq!(schema.data[0].id, "gpt-4");
         assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
     }
+
+    #[test]
+    fn test_unknown_vendor_resilience() {
+        let json = r#"{
+            "data": [
+                {
+                    "billing": {
+                        "is_premium": false,
+                        "multiplier": 1
+                    },
+                    "capabilities": {
+                        "family": "future-model",
+                        "limits": {
+                            "max_context_window_tokens": 128000,
+                            "max_output_tokens": 8192,
+                            "max_prompt_tokens": 120000
+                        },
+                        "object": "model_capabilities",
+                        "supports": { "streaming": true, "tool_calls": true },
+                        "type": "chat"
+                    },
+                    "id": "future-model-v1",
+                    "is_chat_default": false,
+                    "is_chat_fallback": false,
+                    "model_picker_enabled": true,
+                    "name": "Future Model v1",
+                    "object": "model",
+                    "preview": false,
+                    "vendor": "SomeNewVendor",
+                    "version": "v1.0"
+                }
+            ],
+            "object": "list"
+        }"#;
+
+        let schema: ModelSchema = serde_json::from_str(json).unwrap();
+
+        assert_eq!(schema.data.len(), 1);
+        assert_eq!(schema.data[0].id, "future-model-v1");
+        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
+    }
 }
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 071424eabe3c1ad3436de201860d6220ab664a06..b7ece55fed70beae543b9bd55e7635fa6a3fc04d 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -28,12 +28,6 @@ use settings::SettingsStore;
 use ui::{CommonAnimationExt, prelude::*};
 use util::debug_panic;
 
-use crate::provider::x_ai::count_xai_tokens;
-
-use super::anthropic::count_anthropic_tokens;
-use super::google::count_google_tokens;
-use super::open_ai::count_open_ai_tokens;
-
 const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
 const PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("GitHub Copilot Chat");
@@ -191,6 +185,25 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     }
 }
 
+fn collect_tiktoken_messages(
+    request: LanguageModelRequest,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+    request
+        .messages
+        .into_iter()
+        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+            role: match message.role {
+                Role::User => "user".into(),
+                Role::Assistant => "assistant".into(),
+                Role::System => "system".into(),
+            },
+            content: Some(message.string_contents()),
+            name: None,
+            function_call: None,
+        })
+        .collect::<Vec<_>>()
+}
+
 pub struct CopilotChatLanguageModel {
     model: CopilotChatModel,
     request_limiter: RateLimiter,
@@ -226,7 +239,7 @@ impl LanguageModel for CopilotChatLanguageModel {
             ModelVendor::OpenAI | ModelVendor::Anthropic => {
                 LanguageModelToolSchemaFormat::JsonSchema
             }
-            ModelVendor::Google | ModelVendor::XAI => {
+            ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => {
                 LanguageModelToolSchemaFormat::JsonSchemaSubset
             }
         }
@@ -253,18 +266,20 @@ impl LanguageModel for CopilotChatLanguageModel {
         request: LanguageModelRequest,
         cx: &App,
     ) -> BoxFuture<'static, Result<u64>> {
-        match self.model.vendor() {
-            ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
-            ModelVendor::Google => count_google_tokens(request, cx),
-            ModelVendor::XAI => {
-                let model = x_ai::Model::from_id(self.model.id()).unwrap_or_default();
-                count_xai_tokens(request, model, cx)
-            }
-            ModelVendor::OpenAI => {
-                let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
-                count_open_ai_tokens(request, model, cx)
-            }
-        }
+        let model = self.model.clone();
+        cx.background_spawn(async move {
+            let messages = collect_tiktoken_messages(request);
+            // Copilot uses OpenAI's tiktoken tokenizers for all of its models, irrespective of the underlying provider (vendor).
+            let tokenizer_model = match model.tokenizer() {
+                Some("o200k_base") => "gpt-4o",
+                Some("cl100k_base") => "gpt-4",
+                _ => "gpt-4o",
+            };
+
+            tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
+                .map(|tokens| tokens as u64)
+        })
+        .boxed()
     }
 
     fn stream_completion(