@@ -128,6 +128,8 @@ struct ModelCapabilities {
supports: ModelSupportedFeatures,
#[serde(rename = "type")]
model_type: String,
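+ /// Tokenizer reported by the Copilot models endpoint (e.g. "o200k_base"); defaults to `None` when the payload omits it.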
+ #[serde(default)]
+ tokenizer: Option<String>,
}
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -166,6 +168,9 @@ pub enum ModelVendor {
Anthropic,
#[serde(rename = "xAI")]
XAI,
+ /// Unknown vendor that we don't explicitly support yet
+ #[serde(other)]
+ Unknown,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
@@ -214,6 +219,10 @@ impl Model {
pub fn supports_parallel_tool_calls(&self) -> bool {
self.capabilities.supports.parallel_tool_calls
}
+
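+ /// Returns the tokenizer name advertised for this model, if the API reported one.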
+ pub fn tokenizer(&self) -> Option<&str> {
+ self.capabilities.tokenizer.as_deref()
+ }
}
#[derive(Serialize, Deserialize)]
@@ -901,4 +910,45 @@ mod tests {
assert_eq!(schema.data[0].id, "gpt-4");
assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
}
+
+ #[test]
+ fn test_unknown_vendor_resilience() {
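+ // A vendor string we don't recognize should deserialize to `Unknown` instead of failing.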
+ let json = r#"{
+ "data": [
+ {
+ "billing": {
+ "is_premium": false,
+ "multiplier": 1
+ },
+ "capabilities": {
+ "family": "future-model",
+ "limits": {
+ "max_context_window_tokens": 128000,
+ "max_output_tokens": 8192,
+ "max_prompt_tokens": 120000
+ },
+ "object": "model_capabilities",
+ "supports": { "streaming": true, "tool_calls": true },
+ "type": "chat"
+ },
+ "id": "future-model-v1",
+ "is_chat_default": false,
+ "is_chat_fallback": false,
+ "model_picker_enabled": true,
+ "name": "Future Model v1",
+ "object": "model",
+ "preview": false,
+ "vendor": "SomeNewVendor",
+ "version": "v1.0"
+ }
+ ],
+ "object": "list"
+ }"#;
+
+ let schema: ModelSchema = serde_json::from_str(json).unwrap();
+
+ assert_eq!(schema.data.len(), 1);
+ assert_eq!(schema.data[0].id, "future-model-v1");
+ assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
+ }
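+
+ // A minimal companion-test sketch using a hypothetical model payload, reusing
+ // the fixture shape from the test above: `tokenizer` should deserialize when
+ // present, while `#[serde(default)]` leaves it `None` when omitted.
+ #[test]
+ fn test_tokenizer_field_deserialization() {
+ let json = r#"{
+ "data": [
+ {
+ "billing": {
+ "is_premium": false,
+ "multiplier": 1
+ },
+ "capabilities": {
+ "family": "test-family",
+ "limits": {
+ "max_context_window_tokens": 128000,
+ "max_output_tokens": 8192,
+ "max_prompt_tokens": 120000
+ },
+ "object": "model_capabilities",
+ "supports": { "streaming": true, "tool_calls": true },
+ "tokenizer": "o200k_base",
+ "type": "chat"
+ },
+ "id": "tokenizer-test-model",
+ "is_chat_default": false,
+ "is_chat_fallback": false,
+ "model_picker_enabled": true,
+ "name": "Tokenizer Test Model",
+ "object": "model",
+ "preview": false,
+ "vendor": "SomeNewVendor",
+ "version": "v1.0"
+ }
+ ],
+ "object": "list"
+ }"#;
+
+ let schema: ModelSchema = serde_json::from_str(json).unwrap();
+
+ assert_eq!(schema.data[0].tokenizer(), Some("o200k_base"));
+ }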
}
@@ -28,12 +28,6 @@ use settings::SettingsStore;
use ui::{CommonAnimationExt, prelude::*};
use util::debug_panic;
-use crate::provider::x_ai::count_xai_tokens;
-
-use super::anthropic::count_anthropic_tokens;
-use super::google::count_google_tokens;
-use super::open_ai::count_open_ai_tokens;
-
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("GitHub Copilot Chat");
@@ -191,6 +185,25 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
}
}
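+/// Converts the request into tiktoken-rs messages. Only the string contents
+/// of each message are counted, so non-text content (tool calls, images) is
+/// approximated at best.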
+fn collect_tiktoken_messages(
+ request: LanguageModelRequest,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+ request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>()
+}
+
pub struct CopilotChatLanguageModel {
model: CopilotChatModel,
request_limiter: RateLimiter,
@@ -226,7 +239,7 @@ impl LanguageModel for CopilotChatLanguageModel {
ModelVendor::OpenAI | ModelVendor::Anthropic => {
LanguageModelToolSchemaFormat::JsonSchema
}
- ModelVendor::Google | ModelVendor::XAI => {
+ ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => {
LanguageModelToolSchemaFormat::JsonSchemaSubset
}
}
@@ -253,18 +266,20 @@ impl LanguageModel for CopilotChatLanguageModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
- match self.model.vendor() {
- ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
- ModelVendor::Google => count_google_tokens(request, cx),
- ModelVendor::XAI => {
- let model = x_ai::Model::from_id(self.model.id()).unwrap_or_default();
- count_xai_tokens(request, model, cx)
- }
- ModelVendor::OpenAI => {
- let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
- count_open_ai_tokens(request, model, cx)
- }
- }
+ let model = self.model.clone();
+ cx.background_spawn(async move {
+ let messages = collect_tiktoken_messages(request);
+ // Copilot uses an OpenAI tiktoken tokenizer for all of its models, irrespective of the underlying vendor.
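+ // tiktoken-rs selects an encoding from an OpenAI model name, so map the
+ // advertised tokenizer to a representative model.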
+ let tokenizer_model = match model.tokenizer() {
+ Some("o200k_base") => "gpt-4o",
+ Some("cl100k_base") => "gpt-4",
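+ // Fall back to the newer o200k_base encoding for unknown or unreported tokenizers.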
+ _ => "gpt-4o",
+ };
+
+ tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
+ .map(|tokens| tokens as u64)
+ })
+ .boxed()
}
fn stream_completion(