diff --git a/crates/deepseek/src/deepseek.rs b/crates/deepseek/src/deepseek.rs index 22bde8e5943f1a82c7441354a916f980405582c2..c49270febe3b2b3702b808e2219f6e45d7252267 100644 --- a/crates/deepseek/src/deepseek.rs +++ b/crates/deepseek/src/deepseek.rs @@ -201,13 +201,13 @@ pub struct Response { #[derive(Serialize, Deserialize, Debug)] pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub total_tokens: u64, #[serde(default)] - pub prompt_cache_hit_tokens: u32, + pub prompt_cache_hit_tokens: u64, #[serde(default)] - pub prompt_cache_miss_tokens: u32, + pub prompt_cache_miss_tokens: u64, } #[derive(Serialize, Deserialize, Debug)] @@ -224,6 +224,7 @@ pub struct StreamResponse { pub created: u64, pub model: String, pub choices: Vec<StreamChoice>, + pub usage: Option<Usage>, } #[derive(Serialize, Deserialize, Debug)] diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 10030c909109e03c3aeac4e2472c5879740290a4..99a1ca70c6e9ced064c76d4ede427e3b2f5ace0f 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -14,7 +14,7 @@ use language_model::{ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, + RateLimiter, Role, StopReason, TokenUsage, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -513,6 +513,15 @@ impl DeepSeekEventMapper { } } + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + match 
choice.finish_reason.as_deref() { Some("stop") => { events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index f6e1ea559a3efc73de0b104dbc874e0452393b14..3fa5334eb055196e620fc4370d06e4956c6e576b 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -12,7 +12,7 @@ use language_model::{ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, + RateLimiter, Role, StopReason, TokenUsage, }; use menu; use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion}; @@ -528,11 +528,20 @@ impl OpenAiEventMapper { &mut self, event: ResponseStreamEvent, ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { + let mut events = Vec::new(); + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + let Some(choice) = event.choices.first() else { - return Vec::new(); + return events; }; - let mut events = Vec::new(); if let Some(content) = choice.delta.content.clone() { events.push(Ok(LanguageModelCompletionEvent::Text(content))); } diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 034b4b358a0bb8f89b0c33b65266eefe4a6cca69..5b09aa5cbc17a0c48e4a1fadcbdd0b44cba98e1c 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -364,9 +364,9 @@ pub struct FunctionChunk { #[derive(Serialize, Deserialize, Debug)] pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, + pub prompt_tokens: u64, + pub 
completion_tokens: u64, + pub total_tokens: u64, } #[derive(Serialize, Deserialize, Debug)]