From ed4b29f80c971df820ad1477102ed4ccf034450f Mon Sep 17 00:00:00 2001
From: Umesh Yadav <23421535+imumesh18@users.noreply.github.com>
Date: Tue, 17 Jun 2025 16:16:29 +0530
Subject: [PATCH] language_models: Improve token counting for providers (#32853)

We push usage data whenever we receive it from the provider, so the
token counts are correct once the turn has ended.

- [x] Ollama
- [x] Copilot
- [x] Mistral
- [x] OpenRouter
- [x] LMStudio

I've put all the changes into a single PR, but I'm open to moving them
into separate PRs if that makes review and testing easier.
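For reviewers: the mapping pattern is the same across all five providers —
when a streamed chunk carries a usage payload, emit a `UsageUpdate` event
built from the provider's own numbers. Below is a minimal, self-contained
sketch of that pattern; the `Usage`, `TokenUsage`, and event types here are
simplified stand-ins for the real crate types, and `usage_event` is a
hypothetical helper, not the actual implementation:

```rust
// Simplified stand-in for a provider's streamed usage payload.
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

// Simplified stand-in for `language_model::TokenUsage`.
#[derive(Debug, PartialEq)]
struct TokenUsage {
    input_tokens: u32,
    output_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
}

#[derive(Debug, PartialEq)]
enum LanguageModelCompletionEvent {
    UsageUpdate(TokenUsage),
}

// The shared pattern: most chunks carry no usage, so they produce no event;
// whenever usage does arrive, the latest provider-reported value is pushed.
fn usage_event(usage: Option<Usage>) -> Option<LanguageModelCompletionEvent> {
    usage.map(|usage| {
        LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
            input_tokens: usage.prompt_tokens,
            output_tokens: usage.completion_tokens,
            // None of these providers report cache-token details in this
            // payload, so the cache counters stay zero.
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        })
    })
}

fn main() {
    // Mid-stream chunk without usage: nothing is emitted.
    assert!(usage_event(None).is_none());

    // Chunk with usage: a UsageUpdate event is emitted.
    let event = usage_event(Some(Usage { prompt_tokens: 42, completion_tokens: 7 }));
    assert_eq!(
        event,
        Some(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
            input_tokens: 42,
            output_tokens: 7,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        }))
    );
}
```

Ollama is the one variation: it reports `prompt_eval_count`/`eval_count` on
its final `done` message rather than an OpenAI-style `usage` object, so its
mapper reads those fields (defaulting to zero) instead.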
Release Notes:

- N/A
---
 crates/copilot/src/copilot_chat.rs                  | 14 ++++++++++++++
 .../language_models/src/provider/copilot_chat.rs    | 13 ++++++++++++-
 crates/language_models/src/provider/lmstudio.rs     | 11 ++++++++++-
 crates/language_models/src/provider/mistral.rs      | 11 ++++++++++-
 crates/language_models/src/provider/ollama.rs       |  8 +++++++-
 crates/language_models/src/provider/open_router.rs  | 14 ++++++++++++--
 crates/mistral/src/mistral.rs                       |  1 +
 crates/ollama/src/ollama.rs                         |  2 ++
 crates/open_router/src/open_router.rs               |  6 ++++++
 9 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index 0f81df2e081b96a244f8f1ca572eeb21a635d022..c89e4cdb980443da5fcbd31efb1fcfd43a1ac13f 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -311,6 +311,20 @@ pub struct FunctionContent {
 pub struct ResponseEvent {
     pub choices: Vec<ResponseChoice>,
     pub id: String,
+    pub usage: Option<Usage>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub completion_tokens: u32,
+    pub prompt_tokens: u32,
+    pub prompt_tokens_details: PromptTokensDetails,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct PromptTokensDetails {
+    pub cached_tokens: u32,
 }
 
 #[derive(Debug, Deserialize)]
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 475f77c3180640e2a9207a45071df02ebf1163b9..e0ccbcbae6acf6da69b4503843abe29c1b4e27c8 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -24,7 +24,7 @@ use language_model::{
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use settings::SettingsStore;
 use std::time::Duration;
@@ -378,6 +378,17 @@ pub fn map_to_language_model_completion_events(
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                TokenUsage {
+                    input_tokens: usage.prompt_tokens,
+                    output_tokens: usage.completion_tokens,
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                },
+            )));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index 8cb0829c2a2393290f09b25305011acbc14f8135..0a75ef2f88bfdd33ad06a127b1f58e09bed493a6 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -7,7 +7,7 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -528,6 +528,15 @@ impl LmStudioEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index d00af8ecd6acce9138f7afb2e2178bd653c3b5c1..84b7131c7dc7c4a8cd82922bb1ea6eb236790c7d 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -13,7 +13,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -626,6 +626,15 @@ impl MistralEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         if let Some(finish_reason) = choice.finish_reason.as_deref() {
             match finish_reason {
                 "stop" => {
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index fca78e4791f9c3c0a76f2f83710b8268b3dc0e9b..42ccd970891557639f3761e52f9c608de15ee8a7 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -8,7 +8,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
+    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
 };
 use ollama::{
     ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
@@ -507,6 +507,12 @@ fn map_to_language_model_completion_events(
         };
 
         if delta.done {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: delta.prompt_eval_count.unwrap_or(0),
+                output_tokens: delta.eval_count.unwrap_or(0),
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
             if state.used_tools {
                 state.used_tools = false;
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
index 3d1cefa07f07a169f78fbf27c3454f5371ce21b6..450d56a1b204df94cade86c3d7a04e847d62d32f 100644
--- a/crates/language_models/src/provider/open_router.rs
+++ b/crates/language_models/src/provider/open_router.rs
@@ -12,7 +12,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
 use schemars::JsonSchema;
@@ -467,6 +467,7 @@ pub fn into_open_router(
         } else {
             None
         },
+        usage: open_router::RequestUsage { include: true },
         tools: request
             .tools
             .into_iter()
@@ -581,6 +582,15 @@ impl OpenRouterEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@@ -609,7 +619,7 @@ impl OpenRouterEventMapper {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
             }
             Some(stop_reason) => {
-                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
             }
             None => {}
diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs
index 7ad3b1c2948e91dfa73e9992db22cf50d174964b..4fc976860c34da6eb851c3c587f54e20e59909fb 100644
--- a/crates/mistral/src/mistral.rs
+++ b/crates/mistral/src/mistral.rs
@@ -379,6 +379,7 @@ pub struct StreamResponse {
     pub created: u64,
     pub model: String,
     pub choices: Vec<StreamChoice>,
+    pub usage: Option<Usage>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
index 95a7ded680683c014b7353b312fb8736ee9ea8cd..e17b08cde62c1be353865bb95e975140e808c2a5 100644
--- a/crates/ollama/src/ollama.rs
+++ b/crates/ollama/src/ollama.rs
@@ -183,6 +183,8 @@ pub struct ChatResponseDelta {
     pub done_reason: Option<String>,
     #[allow(unused)]
     pub done: bool,
+    pub prompt_eval_count: Option<u32>,
+    pub eval_count: Option<u32>,
 }
 
 #[derive(Serialize, Deserialize)]
diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs
index ad3009b48ff184b664fb9237af5af179af2ad837..407ed416ecdb72ff1d631593d7461982989abc0f 100644
--- a/crates/open_router/src/open_router.rs
+++ b/crates/open_router/src/open_router.rs
@@ -127,6 +127,12 @@ pub struct Request {
     pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
+    pub usage: RequestUsage,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct RequestUsage {
+    pub include: bool,
 }
 
 #[derive(Debug, Serialize, Deserialize)]