Track cumulative token usage in assistant2 when using anthropic API (#26738)

Michael Sloan created

Release Notes:

- N/A

Change summary

Cargo.lock                                       |  1 
crates/anthropic/src/anthropic.rs                |  2 
crates/assistant2/src/thread.rs                  | 17 +++
crates/assistant_context_editor/src/context.rs   |  1 
crates/language_model/src/language_model.rs      | 44 ++++++++
crates/language_models/Cargo.toml                |  1 
crates/language_models/src/provider/anthropic.rs | 91 +++++++++++++----
crates/util/src/serde.rs                         |  4 
8 files changed, 136 insertions(+), 25 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -7173,6 +7173,7 @@ dependencies = [
  "http_client",
  "language_model",
  "lmstudio",
+ "log",
  "menu",
  "mistral",
  "ollama",

crates/anthropic/src/anthropic.rs 🔗

@@ -553,7 +553,7 @@ pub struct Metadata {
     pub user_id: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Default)]
 pub struct Usage {
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub input_tokens: Option<u32>,

crates/assistant2/src/thread.rs 🔗

@@ -11,7 +11,7 @@ use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
-    Role, StopReason,
+    Role, StopReason, TokenUsage,
 };
 use project::Project;
 use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
@@ -81,6 +81,7 @@ pub struct Thread {
     tool_use: ToolUseState,
     scripting_session: Entity<ScriptingSession>,
     scripting_tool_use: ToolUseState,
+    cumulative_token_usage: TokenUsage,
 }
 
 impl Thread {
@@ -109,6 +110,7 @@ impl Thread {
             tool_use: ToolUseState::new(),
             scripting_session,
             scripting_tool_use: ToolUseState::new(),
+            cumulative_token_usage: TokenUsage::default(),
         }
     }
 
@@ -158,6 +160,8 @@ impl Thread {
             tool_use,
             scripting_session,
             scripting_tool_use,
+            // TODO: persist token usage?
+            cumulative_token_usage: TokenUsage::default(),
         }
     }
 
@@ -490,6 +494,7 @@ impl Thread {
             let stream_completion = async {
                 let mut events = stream.await?;
                 let mut stop_reason = StopReason::EndTurn;
+                let mut current_token_usage = TokenUsage::default();
 
                 while let Some(event) = events.next().await {
                     let event = event?;
@@ -502,6 +507,12 @@ impl Thread {
                             LanguageModelCompletionEvent::Stop(reason) => {
                                 stop_reason = reason;
                             }
+                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
+                                thread.cumulative_token_usage =
+                                    thread.cumulative_token_usage.clone() + token_usage.clone()
+                                        - current_token_usage.clone();
+                                current_token_usage = token_usage;
+                            }
                             LanguageModelCompletionEvent::Text(chunk) => {
                                 if let Some(last_message) = thread.messages.last_mut() {
                                     if last_message.role == Role::Assistant {
@@ -843,6 +854,10 @@ impl Thread {
 
         Ok(String::from_utf8_lossy(&markdown).to_string())
     }
+
+    pub fn cumulative_token_usage(&self) -> TokenUsage {
+        self.cumulative_token_usage.clone()
+    }
 }
 
 #[derive(Debug, Clone)]

crates/language_model/src/language_model.rs 🔗

@@ -17,9 +17,11 @@ use proto::Plan;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::fmt;
+use std::ops::{Add, Sub};
 use std::{future::Future, sync::Arc};
 use thiserror::Error;
 use ui::IconName;
+use util::serde::is_default;
 
 pub use crate::model::*;
 pub use crate::rate_limiter::*;
@@ -59,6 +61,7 @@ pub enum LanguageModelCompletionEvent {
     Text(String),
     ToolUse(LanguageModelToolUse),
     StartMessage { message_id: String },
+    UsageUpdate(TokenUsage),
 }
 
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
@@ -69,6 +72,46 @@ pub enum StopReason {
     ToolUse,
 }
 
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)]
+pub struct TokenUsage {
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub input_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub output_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub cache_creation_input_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub cache_read_input_tokens: u32,
+}
+
+impl Add<TokenUsage> for TokenUsage {
+    type Output = Self;
+
+    fn add(self, other: Self) -> Self {
+        Self {
+            input_tokens: self.input_tokens + other.input_tokens,
+            output_tokens: self.output_tokens + other.output_tokens,
+            cache_creation_input_tokens: self.cache_creation_input_tokens
+                + other.cache_creation_input_tokens,
+            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
+        }
+    }
+}
+
+impl Sub<TokenUsage> for TokenUsage {
+    type Output = Self;
+
+    fn sub(self, other: Self) -> Self {
+        Self {
+            input_tokens: self.input_tokens - other.input_tokens,
+            output_tokens: self.output_tokens - other.output_tokens,
+            cache_creation_input_tokens: self.cache_creation_input_tokens
+                - other.cache_creation_input_tokens,
+            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
+        }
+    }
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUseId(Arc<str>);
 
@@ -176,6 +219,7 @@ pub trait LanguageModel: Send + Sync {
                         Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                         Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                         Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                        Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None,
                         Err(err) => Some(Err(err)),
                     }
                 }))

crates/language_models/Cargo.toml 🔗

@@ -33,6 +33,7 @@ gpui_tokio.workspace = true
 http_client.workspace = true
 language_model.workspace = true
 lmstudio = { workspace = true, features = ["schemars"] }
+log.workspace = true
 menu.workspace = true
 mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }

crates/language_models/src/provider/anthropic.rs 🔗

@@ -1,6 +1,6 @@
 use crate::ui::InstructionListItem;
 use crate::AllLanguageModelSettings;
-use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
@@ -582,12 +582,16 @@ pub fn map_to_language_model_completion_events(
     struct State {
         events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
         tool_uses_by_index: HashMap<usize, RawToolUse>,
+        usage: Usage,
+        stop_reason: StopReason,
     }
 
     futures::stream::unfold(
         State {
             events,
             tool_uses_by_index: HashMap::default(),
+            usage: Usage::default(),
+            stop_reason: StopReason::EndTurn,
         },
         |mut state| async move {
             while let Some(event) = state.events.next().await {
@@ -599,7 +603,7 @@ pub fn map_to_language_model_completion_events(
                         } => match content_block {
                             ResponseContent::Text { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
@@ -612,28 +616,25 @@ pub fn map_to_language_model_completion_events(
                                         input_json: String::new(),
                                     },
                                 );
-
-                                return Some((None, state));
                             }
                         },
                         Event::ContentBlockDelta { index, delta } => match delta {
                             ContentDelta::TextDelta { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
                             ContentDelta::InputJsonDelta { partial_json } => {
                                 if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
                                     tool_use.input_json.push_str(&partial_json);
-                                    return Some((None, state));
                                 }
                             }
                         },
                         Event::ContentBlockStop { index } => {
                             if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
                                 return Some((
-                                    Some(maybe!({
+                                    vec![maybe!({
                                         Ok(LanguageModelCompletionEvent::ToolUse(
                                             LanguageModelToolUse {
                                                 id: tool_use.id.into(),
@@ -650,44 +651,63 @@ pub fn map_to_language_model_completion_events(
                                                 },
                                             },
                                         ))
-                                    })),
+                                    })],
                                     state,
                                 ));
                             }
                         }
                         Event::MessageStart { message } => {
+                            update_usage(&mut state.usage, &message.usage);
                             return Some((
-                                Some(Ok(LanguageModelCompletionEvent::StartMessage {
-                                    message_id: message.id,
-                                })),
+                                vec![
+                                    Ok(LanguageModelCompletionEvent::StartMessage {
+                                        message_id: message.id,
+                                    }),
+                                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+                                        &state.usage,
+                                    ))),
+                                ],
                                 state,
-                            ))
+                            ));
                         }
-                        Event::MessageDelta { delta, .. } => {
+                        Event::MessageDelta { delta, usage } => {
+                            update_usage(&mut state.usage, &usage);
                             if let Some(stop_reason) = delta.stop_reason.as_deref() {
-                                let stop_reason = match stop_reason {
+                                state.stop_reason = match stop_reason {
                                     "end_turn" => StopReason::EndTurn,
                                     "max_tokens" => StopReason::MaxTokens,
                                     "tool_use" => StopReason::ToolUse,
-                                    _ => StopReason::EndTurn,
+                                    _ => {
+                                        log::error!(
+                                            "Unexpected anthropic stop_reason: {stop_reason}"
+                                        );
+                                        StopReason::EndTurn
+                                    }
                                 };
-
-                                return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
-                                    state,
-                                ));
                             }
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+                                    convert_usage(&state.usage),
+                                ))],
+                                state,
+                            ));
+                        }
+                        Event::MessageStop => {
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
+                                state,
+                            ));
                         }
                         Event::Error { error } => {
                             return Some((
-                                Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+                                vec![Err(anyhow!(AnthropicError::ApiError(error)))],
                                 state,
                             ));
                         }
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((Some(Err(anyhow!(err))), state));
+                        return Some((vec![Err(anyhow!(err))], state));
                     }
                 }
             }
@@ -695,7 +715,32 @@ pub fn map_to_language_model_completion_events(
             None
         },
     )
-    .filter_map(|event| async move { event })
+    .flat_map(futures::stream::iter)
+}
+
+/// Updates usage data by preferring counts from `new`.
+fn update_usage(usage: &mut Usage, new: &Usage) {
+    if let Some(input_tokens) = new.input_tokens {
+        usage.input_tokens = Some(input_tokens);
+    }
+    if let Some(output_tokens) = new.output_tokens {
+        usage.output_tokens = Some(output_tokens);
+    }
+    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
+        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
+    }
+    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
+        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
+    }
+}
+
+fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
+    language_model::TokenUsage {
+        input_tokens: usage.input_tokens.unwrap_or(0),
+        output_tokens: usage.output_tokens.unwrap_or(0),
+        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+    }
 }
 
 struct ConfigurationView {

crates/util/src/serde.rs 🔗

@@ -1,3 +1,7 @@
 pub const fn default_true() -> bool {
     true
 }
+
+pub fn is_default<T: Default + PartialEq>(value: &T) -> bool {
+    *value == T::default()
+}