agent2: Fix token count not updating when changing model/toggling burn mode (#36562)

Bennet Bo Fenner and Antonio Scandurra created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/agent2/Cargo.toml    |  1 
crates/agent2/src/agent.rs  | 25 +++++++++---
crates/agent2/src/thread.rs | 76 +++++++++++++++++++++++++-------------
3 files changed, 69 insertions(+), 33 deletions(-)

Detailed changes

crates/agent2/Cargo.toml 🔗

@@ -26,6 +26,7 @@ assistant_context.workspace = true
 assistant_tool.workspace = true
 assistant_tools.workspace = true
 chrono.workspace = true
+client.workspace = true
 cloud_llm_client.workspace = true
 collections.workspace = true
 context_server.workspace = true

crates/agent2/src/agent.rs 🔗

@@ -1,8 +1,8 @@
-use crate::HistoryStore;
 use crate::{
     ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
     UserMessageContent, templates::Templates,
 };
+use crate::{HistoryStore, TokenUsageUpdated};
 use acp_thread::{AcpThread, AgentModelSelector};
 use action_log::ActionLog;
 use agent_client_protocol as acp;
@@ -253,6 +253,7 @@ impl NativeAgent {
             cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                 this.sessions.remove(acp_thread.session_id());
             }),
+            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
             cx.observe(&thread_handle, move |this, thread, cx| {
                 this.save_thread(thread.clone(), cx)
             }),
@@ -440,6 +441,23 @@ impl NativeAgent {
         })
     }
 
+    fn handle_thread_token_usage_updated(
+        &mut self,
+        thread: Entity<Thread>,
+        usage: &TokenUsageUpdated,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
+            return;
+        };
+        session
+            .acp_thread
+            .update(cx, |acp_thread, cx| {
+                acp_thread.update_token_usage(usage.0.clone(), cx);
+            })
+            .ok();
+    }
+
     fn handle_project_event(
         &mut self,
         _project: Entity<Project>,
@@ -695,11 +713,6 @@ impl NativeAgentConnection {
                                     thread.update_tool_call(update, cx)
                                 })??;
                             }
-                            ThreadEvent::TokenUsageUpdate(usage) => {
-                                acp_thread.update(cx, |thread, cx| {
-                                    thread.update_token_usage(Some(usage), cx)
-                                })?;
-                            }
                             ThreadEvent::TitleUpdate(title) => {
                                 acp_thread
                                     .update(cx, |thread, cx| thread.update_title(title, cx))??;

crates/agent2/src/thread.rs 🔗

@@ -15,7 +15,8 @@ use agent_settings::{
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use chrono::{DateTime, Utc};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
+use client::{ModelRequestUsage, RequestUsage};
+use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
 use collections::{HashMap, IndexMap};
 use fs::Fs;
 use futures::{
@@ -25,7 +26,9 @@ use futures::{
     stream::FuturesUnordered,
 };
 use git::repository::DiffType;
-use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
+use gpui::{
+    App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
+};
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
     LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
@@ -484,7 +487,6 @@ pub enum ThreadEvent {
     ToolCall(acp::ToolCall),
     ToolCallUpdate(acp_thread::ToolCallUpdate),
     ToolCallAuthorization(ToolCallAuthorization),
-    TokenUsageUpdate(acp_thread::TokenUsage),
     TitleUpdate(SharedString),
     Retry(acp_thread::RetryStatus),
     Stop(acp::StopReason),
@@ -873,7 +875,12 @@ impl Thread {
     }
 
     pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
+        let old_usage = self.latest_token_usage();
         self.model = Some(model);
+        let new_usage = self.latest_token_usage();
+        if old_usage != new_usage {
+            cx.emit(TokenUsageUpdated(new_usage));
+        }
         cx.notify()
     }
 
@@ -891,7 +898,12 @@ impl Thread {
     }
 
     pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
+        let old_usage = self.latest_token_usage();
         self.completion_mode = mode;
+        let new_usage = self.latest_token_usage();
+        if old_usage != new_usage {
+            cx.emit(TokenUsageUpdated(new_usage));
+        }
         cx.notify()
     }
 
@@ -953,13 +965,15 @@ impl Thread {
         self.flush_pending_message(cx);
     }
 
-    pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
+    fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
         let Some(last_user_message) = self.last_user_message() else {
             return;
         };
 
         self.request_token_usage
             .insert(last_user_message.id.clone(), update);
+        cx.emit(TokenUsageUpdated(self.latest_token_usage()));
+        cx.notify();
     }
 
     pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
@@ -1180,20 +1194,15 @@ impl Thread {
                     )) => {
                         *tool_use_limit_reached = true;
                     }
+                    Ok(LanguageModelCompletionEvent::StatusUpdate(
+                        CompletionRequestStatus::UsageUpdated { amount, limit },
+                    )) => {
+                        this.update(cx, |this, cx| {
+                            this.update_model_request_usage(amount, limit, cx)
+                        })?;
+                    }
                     Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
-                        let usage = acp_thread::TokenUsage {
-                            max_tokens: model.max_token_count_for_mode(
-                                request
-                                    .mode
-                                    .unwrap_or(cloud_llm_client::CompletionMode::Normal),
-                            ),
-                            used_tokens: token_usage.total_tokens(),
-                        };
-
-                        this.update(cx, |this, _cx| this.update_token_usage(token_usage))
-                            .ok();
-
-                        event_stream.send_token_usage_update(usage);
+                        this.update(cx, |this, cx| this.update_token_usage(token_usage, cx))?;
                     }
                     Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
                         *refusal = true;
@@ -1214,8 +1223,7 @@ impl Thread {
                                 event_stream,
                                 cx,
                             ));
-                        })
-                        .ok();
+                        })?;
                     }
                     Err(error) => {
                         let completion_mode =
@@ -1325,8 +1333,8 @@ impl Thread {
                     json_parse_error,
                 )));
             }
-            UsageUpdate(_) | StatusUpdate(_) => {}
-            Stop(_) => unreachable!(),
+            StatusUpdate(_) => {}
+            UsageUpdate(_) | Stop(_) => unreachable!(),
         }
 
         None
@@ -1506,6 +1514,21 @@ impl Thread {
         }
     }
 
+    fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
+        self.project
+            .read(cx)
+            .user_store()
+            .update(cx, |user_store, cx| {
+                user_store.update_model_request_usage(
+                    ModelRequestUsage(RequestUsage {
+                        amount: amount as i32,
+                        limit,
+                    }),
+                    cx,
+                )
+            });
+    }
+
     pub fn title(&self) -> SharedString {
         self.title.clone().unwrap_or("New Thread".into())
     }
@@ -1636,6 +1659,7 @@ impl Thread {
             })
         }))
     }
+
     fn last_user_message(&self) -> Option<&UserMessage> {
         self.messages
             .iter()
@@ -1934,6 +1958,10 @@ impl RunningTurn {
     }
 }
 
+pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);
+
+impl EventEmitter<TokenUsageUpdated> for Thread {}
+
 pub trait AgentTool
 where
     Self: 'static + Sized,
@@ -2166,12 +2194,6 @@ impl ThreadEventStream {
             .ok();
     }
 
-    fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
-        self.0
-            .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
-            .ok();
-    }
-
     fn send_retry(&self, status: acp_thread::RetryStatus) {
         self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
     }