language_models: Emit UsageUpdate events for token usage in DeepSeek and OpenAI (#33242)

Created by Umesh Yadav

Closes #ISSUE

Release Notes:

- N/A

Change summary

crates/deepseek/src/deepseek.rs                 | 11 ++++++-----
crates/language_models/src/provider/deepseek.rs | 11 ++++++++++-
crates/language_models/src/provider/open_ai.rs  | 15 ++++++++++++---
crates/open_ai/src/open_ai.rs                   |  6 +++---
4 files changed, 31 insertions(+), 12 deletions(-)

Detailed changes

crates/deepseek/src/deepseek.rs

@@ -201,13 +201,13 @@ pub struct Response {
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct Usage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
+    pub prompt_tokens: u64,
+    pub completion_tokens: u64,
+    pub total_tokens: u64,
     #[serde(default)]
-    pub prompt_cache_hit_tokens: u32,
+    pub prompt_cache_hit_tokens: u64,
     #[serde(default)]
-    pub prompt_cache_miss_tokens: u32,
+    pub prompt_cache_miss_tokens: u64,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -224,6 +224,7 @@ pub struct StreamResponse {
     pub created: u64,
     pub model: String,
     pub choices: Vec<StreamChoice>,
+    pub usage: Option<Usage>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]

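Because the new `usage` field is an `Option`, chunks that omit it still deserialize cleanly; typically only DeepSeek's final streamed chunk carries the totals. A minimal sketch of that behavior, using trimmed stand-ins for `StreamResponse` and `Usage` rather than the full structs from this crate:

```rust
use serde::Deserialize;

// Trimmed stand-ins for deepseek::StreamResponse and Usage, kept just
// to show how the new optional field behaves during deserialization.
#[derive(Deserialize, Debug)]
struct StreamResponse {
    model: String,
    usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
struct Usage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

fn main() {
    // Mid-stream chunk: no `usage` key, so the Option parses as None.
    let mid: StreamResponse =
        serde_json::from_str(r#"{"model":"deepseek-chat"}"#).unwrap();
    assert!(mid.usage.is_none());

    // Final chunk: the usage totals arrive once generation finishes.
    let last: StreamResponse = serde_json::from_str(
        r#"{"model":"deepseek-chat","usage":{"prompt_tokens":12,"completion_tokens":34}}"#,
    )
    .unwrap();
    assert_eq!(last.usage.unwrap().completion_tokens, 34);
}
```
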
crates/language_models/src/provider/deepseek.rs

@@ -14,7 +14,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -513,6 +513,15 @@ impl DeepSeekEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));

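Downstream consumers can fold these `UsageUpdate` events into a usage total. A hedged sketch with trimmed stand-ins for the `language_model` types (the real `LanguageModelCompletionEvent` carries more variants, and `TokenUsage` is defined in that crate, not here):

```rust
// Trimmed stand-ins for language_model::LanguageModelCompletionEvent
// and TokenUsage; the real definitions carry more variants and fields.
#[derive(Clone, Copy, Debug, Default)]
struct TokenUsage {
    input_tokens: u64,
    output_tokens: u64,
    cache_creation_input_tokens: u64,
    cache_read_input_tokens: u64,
}

enum LanguageModelCompletionEvent {
    Text(String),
    UsageUpdate(TokenUsage),
}

// Keep the last usage snapshot seen in the stream; the usage chunk
// reports totals for the whole request, so the final update wins.
fn final_usage(events: &[LanguageModelCompletionEvent]) -> TokenUsage {
    events
        .iter()
        .filter_map(|event| match event {
            LanguageModelCompletionEvent::UsageUpdate(usage) => Some(*usage),
            _ => None,
        })
        .last()
        .unwrap_or_default()
}

fn main() {
    let events = vec![
        LanguageModelCompletionEvent::Text("hello".into()),
        LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
            input_tokens: 12,
            output_tokens: 34,
            ..TokenUsage::default()
        }),
    ];
    assert_eq!(final_usage(&events).output_tokens, 34);
}
```
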
crates/language_models/src/provider/open_ai.rs

@@ -12,7 +12,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use menu;
 use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
@@ -528,11 +528,20 @@ impl OpenAiEventMapper {
         &mut self,
         event: ResponseStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let mut events = Vec::new();
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         let Some(choice) = event.choices.first() else {
-            return Vec::new();
+            return events;
         };
 
-        let mut events = Vec::new();
         if let Some(content) = choice.delta.content.clone() {
             events.push(Ok(LanguageModelCompletionEvent::Text(content)));
         }

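Note the reordering here: `events` is now built before the empty-`choices` early return. With OpenAI's streaming API, the trailing usage chunk (sent when `stream_options.include_usage` is set) carries an empty `choices` array, so the old `return Vec::new()` would have discarded it. A hedged sketch of the opt-in request body, per the standard OpenAI wire format (whether and where this crate sets the flag is outside this diff):

```rust
use serde_json::json;

fn main() {
    // Sketch of the streaming request; model name and message are
    // placeholders.
    let body = json!({
        "model": "gpt-4o",
        "stream": true,
        // Requests the trailing chunk whose `choices` is empty but
        // whose `usage` holds the totals for the whole request.
        "stream_options": { "include_usage": true },
        "messages": [{ "role": "user", "content": "Hello" }]
    });
    println!("{body}");
}
```
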
crates/open_ai/src/open_ai.rs

@@ -364,9 +364,9 @@ pub struct FunctionChunk {
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct Usage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
+    pub prompt_tokens: u64,
+    pub completion_tokens: u64,
+    pub total_tokens: u64,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
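
A hedged note on the `u32` to `u64` widening in both providers: `TokenUsage` stores its counters as `u64` (the mapper hunks above assign the fields directly, so the widths must match), and keeping the wire structs at `u32` would force a cast at every assignment site. Trimmed illustration with stand-in types:

```rust
// Stand-ins: wire-format usage vs. the internal counters. With both
// at u64 the mapping is a plain move; at u32 each field would need a
// u64::from(...) conversion at every assignment site.
struct Usage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

struct TokenUsage {
    input_tokens: u64,
    output_tokens: u64,
}

fn to_token_usage(usage: &Usage) -> TokenUsage {
    TokenUsage {
        input_tokens: usage.prompt_tokens,
        output_tokens: usage.completion_tokens,
    }
}

fn main() {
    let wire = Usage { prompt_tokens: 7, completion_tokens: 9 };
    assert_eq!(to_token_usage(&wire).input_tokens, 7);
}
```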