language_models: Improve token counting for providers (#32853)

Umesh Yadav created

We push usage data whenever we receive it from the provider, so the token counts remain correct after the turn has ended (a sketch of the shared mapping follows the provider list below).

- [x] Ollama 
- [x] Copilot 
- [x] Mistral 
- [x] OpenRouter 
- [x] LMStudio
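
Each provider-specific mapper funnels its usage payload into the same `LanguageModelCompletionEvent::UsageUpdate` event. A minimal sketch of that shared mapping, assuming the `language_model` types are in scope (`ProviderUsage` is a hypothetical stand-in for the provider payload, using the OpenAI-style field names most of these APIs share):

use language_model::{LanguageModelCompletionEvent, TokenUsage};

// Stand-in for a provider's usage payload; the field names follow the
// OpenAI-style convention most of these APIs share.
struct ProviderUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

// The mapping each provider in this PR performs when usage arrives.
fn usage_update(usage: ProviderUsage) -> LanguageModelCompletionEvent {
    LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
        input_tokens: usage.prompt_tokens,
        output_tokens: usage.completion_tokens,
        // None of these providers report cache accounting here,
        // so both cache fields stay zero.
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: 0,
    })
}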

I've put all the changes into a single PR, but I'm open to moving these into separate PRs
if that makes the review and testing easier.

Release Notes:

- N/A

Change summary

crates/copilot/src/copilot_chat.rs                  | 14 ++++++++++++++
crates/language_models/src/provider/copilot_chat.rs | 13 ++++++++++++-
crates/language_models/src/provider/lmstudio.rs     | 11 ++++++++++-
crates/language_models/src/provider/mistral.rs      | 11 ++++++++++-
crates/language_models/src/provider/ollama.rs       |  8 +++++++-
crates/language_models/src/provider/open_router.rs  | 14 ++++++++++++--
crates/mistral/src/mistral.rs                       |  1 +
crates/ollama/src/ollama.rs                         |  2 ++
crates/open_router/src/open_router.rs               |  6 ++++++
9 files changed, 74 insertions(+), 6 deletions(-)

Detailed changes

crates/copilot/src/copilot_chat.rs 🔗

@@ -311,6 +311,20 @@ pub struct FunctionContent {
 pub struct ResponseEvent {
     pub choices: Vec<ResponseChoice>,
     pub id: String,
+    pub usage: Option<Usage>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub completion_tokens: u32,
+    pub prompt_tokens: u32,
+    pub prompt_tokens_details: PromptTokensDetails,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct PromptTokensDetails {
+    pub cached_tokens: u32,
 }
 
 #[derive(Debug, Deserialize)]
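
For reference, a quick round-trip of the new `Usage` struct against an illustrative payload (the shape matches the struct above; the numbers are invented), assuming `serde_json` and the struct are in scope:

#[test]
fn usage_deserializes_from_stream_payload() {
    // Illustrative usage block as Copilot Chat would stream it.
    let json = r#"{
        "completion_tokens": 42,
        "prompt_tokens": 1337,
        "prompt_tokens_details": { "cached_tokens": 1024 },
        "total_tokens": 1379
    }"#;
    let usage: Usage = serde_json::from_str(json).unwrap();
    assert_eq!(usage.prompt_tokens, 1337);
    assert_eq!(usage.prompt_tokens_details.cached_tokens, 1024);
}

Note that `cached_tokens` is parsed here, but the copilot_chat provider mapper (next section) still zeroes both cache fields in `TokenUsage`.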

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -24,7 +24,7 @@ use language_model::{
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use settings::SettingsStore;
 use std::time::Duration;
@@ -378,6 +378,17 @@ pub fn map_to_language_model_completion_events(
                             }
                         }
 
+                        if let Some(usage) = event.usage {
+                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                                TokenUsage {
+                                    input_tokens: usage.prompt_tokens,
+                                    output_tokens: usage.completion_tokens,
+                                    cache_creation_input_tokens: 0,
+                                    cache_read_input_tokens: 0,
+                                },
+                            )));
+                        }
+
                         match choice.finish_reason.as_deref() {
                             Some("stop") => {
                                 events.push(Ok(LanguageModelCompletionEvent::Stop(

crates/language_models/src/provider/lmstudio.rs 🔗

@@ -7,7 +7,7 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -528,6 +528,15 @@ impl LmStudioEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));

crates/language_models/src/provider/mistral.rs 🔗

@@ -13,7 +13,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -626,6 +626,15 @@ impl MistralEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         if let Some(finish_reason) = choice.finish_reason.as_deref() {
             match finish_reason {
                 "stop" => {

crates/language_models/src/provider/ollama.rs 🔗

@@ -8,7 +8,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
+    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
 };
 use ollama::{
     ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
@@ -507,6 +507,12 @@ fn map_to_language_model_completion_events(
             };
 
             if delta.done {
+                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    input_tokens: delta.prompt_eval_count.unwrap_or(0),
+                    output_tokens: delta.eval_count.unwrap_or(0),
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                })));
                 if state.used_tools {
                     state.used_tools = false;
                     events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
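
Ollama reports token counts only on the final delta (`done: true`), which is why the `UsageUpdate` is emitted there; the counts can also be absent on some runs, hence the `Option` fields and `unwrap_or(0)`. An illustrative final chunk, parsed generically (values invented, assuming `serde_json`):

#[test]
fn final_delta_carries_token_counts() {
    // Abridged final chunk from an Ollama /api/chat stream.
    let final_chunk = r#"{
        "model": "llama3",
        "done": true,
        "done_reason": "stop",
        "prompt_eval_count": 26,
        "eval_count": 298
    }"#;
    let chunk: serde_json::Value = serde_json::from_str(final_chunk).unwrap();
    assert_eq!(chunk["prompt_eval_count"], 26);
    assert_eq!(chunk["eval_count"], 298);
}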

crates/language_models/src/provider/open_router.rs 🔗

@@ -12,7 +12,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
 use schemars::JsonSchema;
@@ -467,6 +467,7 @@ pub fn into_open_router(
         } else {
             None
         },
+        usage: open_router::RequestUsage { include: true },
         tools: request
             .tools
             .into_iter()
@@ -581,6 +582,15 @@ impl OpenRouterEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@@ -609,7 +619,7 @@ impl OpenRouterEventMapper {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
             }
             Some(stop_reason) => {
-                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
             }
             None => {}

crates/mistral/src/mistral.rs 🔗

@@ -379,6 +379,7 @@ pub struct StreamResponse {
     pub created: u64,
     pub model: String,
     pub choices: Vec<StreamChoice>,
+    pub usage: Option<Usage>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]

crates/ollama/src/ollama.rs 🔗

@@ -183,6 +183,8 @@ pub struct ChatResponseDelta {
     pub done_reason: Option<String>,
     #[allow(unused)]
     pub done: bool,
+    pub prompt_eval_count: Option<u32>,
+    pub eval_count: Option<u32>,
 }
 
 #[derive(Serialize, Deserialize)]

crates/open_router/src/open_router.rs 🔗

@@ -127,6 +127,12 @@ pub struct Request {
     pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
+    pub usage: RequestUsage,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct RequestUsage {
+    pub include: bool,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
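
The new `usage` field opts the request into OpenRouter's usage accounting, which (as I understand the API) asks OpenRouter to attach a usage object to the final chunk of the stream; the mapper in the provider section above turns that into the `UsageUpdate` event. The wire shape of the flag, as a sketch assuming `serde_json`:

#[test]
fn request_usage_serializes_to_opt_in_flag() {
    let usage = RequestUsage { include: true };
    // Serializes to the opt-in flag OpenRouter expects.
    assert_eq!(serde_json::to_string(&usage).unwrap(), r#"{"include":true}"#);
}

Since the `usage` field has no `skip_serializing_if` attribute, the flag is now serialized on every OpenRouter request.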