diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index 9b9d6e19b8de86fd0ee7e6fe6bf57d6d91da19da..ccd8f09613eec54f2d30b619f142d111bf2a3497 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -128,6 +128,8 @@ struct ModelCapabilities {
     supports: ModelSupportedFeatures,
     #[serde(rename = "type")]
     model_type: String,
+    #[serde(default)]
+    tokenizer: Option<String>,
 }
 
 #[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -166,6 +168,9 @@ pub enum ModelVendor {
     Anthropic,
     #[serde(rename = "xAI")]
     XAI,
+    /// Unknown vendor that we don't explicitly support yet
+    #[serde(other)]
+    Unknown,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
@@ -214,6 +219,10 @@ impl Model {
     pub fn supports_parallel_tool_calls(&self) -> bool {
         self.capabilities.supports.parallel_tool_calls
     }
+
+    pub fn tokenizer(&self) -> Option<&str> {
+        self.capabilities.tokenizer.as_deref()
+    }
 }
 
 #[derive(Serialize, Deserialize)]
@@ -901,4 +910,45 @@ mod tests {
         assert_eq!(schema.data[0].id, "gpt-4");
         assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
     }
+
+    #[test]
+    fn test_unknown_vendor_resilience() {
+        let json = r#"{
+            "data": [
+                {
+                    "billing": {
+                        "is_premium": false,
+                        "multiplier": 1
+                    },
+                    "capabilities": {
+                        "family": "future-model",
+                        "limits": {
+                            "max_context_window_tokens": 128000,
+                            "max_output_tokens": 8192,
+                            "max_prompt_tokens": 120000
+                        },
+                        "object": "model_capabilities",
+                        "supports": { "streaming": true, "tool_calls": true },
+                        "type": "chat"
+                    },
+                    "id": "future-model-v1",
+                    "is_chat_default": false,
+                    "is_chat_fallback": false,
+                    "model_picker_enabled": true,
+                    "name": "Future Model v1",
+                    "object": "model",
+                    "preview": false,
+                    "vendor": "SomeNewVendor",
+                    "version": "v1.0"
+                }
+            ],
+            "object": "list"
+        }"#;
+
+        let schema: ModelSchema = serde_json::from_str(json).unwrap();
+
+        assert_eq!(schema.data.len(), 1);
+        assert_eq!(schema.data[0].id, "future-model-v1");
+        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
+    }
 }
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 071424eabe3c1ad3436de201860d6220ab664a06..b7ece55fed70beae543b9bd55e7635fa6a3fc04d 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -28,12 +28,6 @@ use settings::SettingsStore;
 use ui::{CommonAnimationExt, prelude::*};
 use util::debug_panic;
 
-use crate::provider::x_ai::count_xai_tokens;
-
-use super::anthropic::count_anthropic_tokens;
-use super::google::count_google_tokens;
-use super::open_ai::count_open_ai_tokens;
-
 const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
 const PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("GitHub Copilot Chat");
@@ -191,6 +185,25 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     }
 }
 
+fn collect_tiktoken_messages(
+    request: LanguageModelRequest,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+    request
+        .messages
+        .into_iter()
+        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+            role: match message.role {
+                Role::User => "user".into(),
+                Role::Assistant => "assistant".into(),
+                Role::System => "system".into(),
+            },
+            content: Some(message.string_contents()),
+            name: None,
+            function_call: None,
+        })
+        .collect::<Vec<_>>()
+}
+
 pub struct CopilotChatLanguageModel {
     model: CopilotChatModel,
     request_limiter: RateLimiter,
@@ -226,7 +239,7 @@ impl LanguageModel for CopilotChatLanguageModel {
             ModelVendor::OpenAI | ModelVendor::Anthropic => {
                 LanguageModelToolSchemaFormat::JsonSchema
             }
-            ModelVendor::Google | ModelVendor::XAI => {
+            ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => {
                 LanguageModelToolSchemaFormat::JsonSchemaSubset
             }
         }
@@ -253,18 +266,20 @@ impl LanguageModel for CopilotChatLanguageModel {
         request: LanguageModelRequest,
         cx: &App,
     ) -> BoxFuture<'static, Result<u64>> {
-        match self.model.vendor() {
-            ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
-            ModelVendor::Google => count_google_tokens(request, cx),
-            ModelVendor::XAI => {
-                let model = x_ai::Model::from_id(self.model.id()).unwrap_or_default();
-                count_xai_tokens(request, model, cx)
-            }
-            ModelVendor::OpenAI => {
-                let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
-                count_open_ai_tokens(request, model, cx)
-            }
-        }
+        let model = self.model.clone();
+        cx.background_spawn(async move {
+            let messages = collect_tiktoken_messages(request);
+            // Copilot uses the OpenAI tiktoken tokenizer for all its models, irrespective of the underlying provider (vendor).
+            let tokenizer_model = match model.tokenizer() {
+                Some("o200k_base") => "gpt-4o",
+                Some("cl100k_base") => "gpt-4",
+                _ => "gpt-4o",
+            };
+
+            tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
+                .map(|tokens| tokens as u64)
+        })
+        .boxed()
     }
 
     fn stream_completion(
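
Note for reviewers: the new count_tokens path hands an OpenAI-style model name to tiktoken_rs, which resolves the encoder from that string internally (o200k_base for "gpt-4o", cl100k_base for "gpt-4"). Below is a minimal standalone sketch of that behavior, using only the tiktoken_rs public API already exercised by the diff; the message text, the assertion, and the assumption that anyhow is available are illustrative, not part of the change.

// Sketch: count tokens for a chat message list the same way the new
// count_tokens implementation does, minus the GPUI background task.
fn sketch() -> anyhow::Result<()> {
    let messages = vec![tiktoken_rs::ChatCompletionRequestMessage {
        role: "user".into(),
        content: Some("Hello, Copilot!".into()), // illustrative message text
        name: None,
        function_call: None,
    }];

    // A missing or unrecognized tokenizer capability falls back to "gpt-4o"
    // (o200k_base), matching the `_ => "gpt-4o"` arm in the diff.
    let tokens = tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)?;
    assert!(tokens > 0);
    Ok(())
}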