agent2: Token count (#36496)

Bennet Bo Fenner and Agus Zubiaga created

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/acp_thread/src/acp_thread.rs    |  19 +++
crates/acp_thread/src/connection.rs    |   2 
crates/agent2/Cargo.toml               |   1 
crates/agent2/src/agent.rs             |  44 ++++++--
crates/agent2/src/db.rs                |  21 +++
crates/agent2/src/tests/mod.rs         | 144 +++++++++++++++++++++++++++
crates/agent2/src/thread.rs            |  74 +++++++++++++-
crates/agent_ui/src/acp/thread_view.rs |  41 +++++++
crates/agent_ui/src/agent_diff.rs      |   1 
9 files changed, 321 insertions(+), 26 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -6,6 +6,7 @@ mod terminal;
 pub use connection::*;
 pub use diff::*;
 pub use mention::*;
+use serde::{Deserialize, Serialize};
 pub use terminal::*;
 
 use action_log::ActionLog;
@@ -664,6 +665,12 @@ impl PlanEntry {
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct TokenUsage {
+    pub max_tokens: u64,
+    pub used_tokens: u64,
+}
+
 #[derive(Debug, Clone)]
 pub struct RetryStatus {
     pub last_error: SharedString,
@@ -683,12 +690,14 @@ pub struct AcpThread {
     send_task: Option<Task<()>>,
     connection: Rc<dyn AgentConnection>,
     session_id: acp::SessionId,
+    token_usage: Option<TokenUsage>,
 }
 
 #[derive(Debug)]
 pub enum AcpThreadEvent {
     NewEntry,
     TitleUpdated,
+    TokenUsageUpdated,
     EntryUpdated(usize),
     EntriesRemoved(Range<usize>),
     ToolAuthorizationRequired,
@@ -748,6 +757,7 @@ impl AcpThread {
             send_task: None,
             connection,
             session_id,
+            token_usage: None,
         }
     }
 
@@ -787,6 +797,10 @@ impl AcpThread {
         }
     }
 
+    pub fn token_usage(&self) -> Option<&TokenUsage> {
+        self.token_usage.as_ref()
+    }
+
     pub fn has_pending_edit_tool_calls(&self) -> bool {
         for entry in self.entries.iter().rev() {
             match entry {
@@ -937,6 +951,11 @@ impl AcpThread {
         Ok(())
     }
 
+    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
+        self.token_usage = usage;
+        cx.emit(AcpThreadEvent::TokenUsageUpdated);
+    }
+
     pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
         cx.emit(AcpThreadEvent::Retry(status));
     }

crates/acp_thread/src/connection.rs 🔗

@@ -10,7 +10,7 @@ use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 use ui::{App, IconName};
 use uuid::Uuid;
 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 pub struct UserMessageId(Arc<str>);
 
 impl UserMessageId {

crates/agent2/Cargo.toml 🔗

@@ -66,6 +66,7 @@ zstd.workspace = true
 
 [dev-dependencies]
 agent = { workspace = true, "features" = ["test-support"] }
+assistant_context = { workspace = true, "features" = ["test-support"] }
 ctor.workspace = true
 client = { workspace = true, "features" = ["test-support"] }
 clock = { workspace = true, "features" = ["test-support"] }

crates/agent2/src/agent.rs 🔗

@@ -1,8 +1,8 @@
+use crate::HistoryStore;
 use crate::{
-    ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent,
-    templates::Templates,
+    ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
+    UserMessageContent, templates::Templates,
 };
-use crate::{HistoryStore, ThreadsDatabase};
 use acp_thread::{AcpThread, AgentModelSelector};
 use action_log::ActionLog;
 use agent_client_protocol as acp;
@@ -673,6 +673,11 @@ impl NativeAgentConnection {
                                     thread.update_tool_call(update, cx)
                                 })??;
                             }
+                            ThreadEvent::TokenUsageUpdate(usage) => {
+                                acp_thread.update(cx, |thread, cx| {
+                                    thread.update_token_usage(Some(usage), cx)
+                                })?;
+                            }
                             ThreadEvent::TitleUpdate(title) => {
                                 acp_thread
                                     .update(cx, |thread, cx| thread.update_title(title, cx))??;
@@ -895,10 +900,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         cx: &mut App,
     ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
         self.0.update(cx, |agent, _cx| {
-            agent
-                .sessions
-                .get(session_id)
-                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
+            agent.sessions.get(session_id).map(|session| {
+                Rc::new(NativeAgentSessionEditor {
+                    thread: session.thread.clone(),
+                    acp_thread: session.acp_thread.clone(),
+                }) as _
+            })
         })
     }
 
@@ -907,14 +914,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
     }
 }
 
-struct NativeAgentSessionEditor(Entity<Thread>);
+struct NativeAgentSessionEditor {
+    thread: Entity<Thread>,
+    acp_thread: WeakEntity<AcpThread>,
+}
 
 impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
     fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
-        Task::ready(
-            self.0
-                .update(cx, |thread, cx| thread.truncate(message_id, cx)),
-        )
+        match self.thread.update(cx, |thread, cx| {
+            thread.truncate(message_id.clone(), cx)?;
+            Ok(thread.latest_token_usage())
+        }) {
+            Ok(usage) => {
+                self.acp_thread
+                    .update(cx, |thread, cx| {
+                        thread.update_token_usage(usage, cx);
+                    })
+                    .ok();
+                Task::ready(Ok(()))
+            }
+            Err(error) => Task::ready(Err(error)),
+        }
     }
 }
 

crates/agent2/src/db.rs 🔗

@@ -1,4 +1,5 @@
 use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
+use acp_thread::UserMessageId;
 use agent::thread_store;
 use agent_client_protocol as acp;
 use agent_settings::{AgentProfileId, CompletionMode};
@@ -42,7 +43,7 @@ pub struct DbThread {
     #[serde(default)]
     pub cumulative_token_usage: language_model::TokenUsage,
     #[serde(default)]
-    pub request_token_usage: Vec<language_model::TokenUsage>,
+    pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
     #[serde(default)]
     pub model: Option<DbLanguageModel>,
     #[serde(default)]
@@ -67,7 +68,10 @@ impl DbThread {
 
     fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
         let mut messages = Vec::new();
-        for msg in thread.messages {
+        let mut request_token_usage = HashMap::default();
+
+        let mut last_user_message_id = None;
+        for (ix, msg) in thread.messages.into_iter().enumerate() {
             let message = match msg.role {
                 language_model::Role::User => {
                     let mut content = Vec::new();
@@ -93,9 +97,12 @@ impl DbThread {
                         content.push(UserMessageContent::Text(msg.context));
                     }
 
+                    let id = UserMessageId::new();
+                    last_user_message_id = Some(id.clone());
+
                     crate::Message::User(UserMessage {
                         // MessageId from old format can't be meaningfully converted, so generate a new one
-                        id: acp_thread::UserMessageId::new(),
+                        id,
                         content,
                     })
                 }
@@ -154,6 +161,12 @@ impl DbThread {
                         );
                     }
 
+                    if let Some(last_user_message_id) = &last_user_message_id
+                        && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
+                    {
+                        request_token_usage.insert(last_user_message_id.clone(), token_usage);
+                    }
+
                     crate::Message::Agent(AgentMessage {
                         content,
                         tool_results,
@@ -175,7 +188,7 @@ impl DbThread {
             summary: thread.detailed_summary_state,
             initial_project_snapshot: thread.initial_project_snapshot,
             cumulative_token_usage: thread.cumulative_token_usage,
-            request_token_usage: thread.request_token_usage,
+            request_token_usage,
             model: thread.model,
             completion_mode: thread.completion_mode,
             profile: thread.profile,

crates/agent2/src/tests/mod.rs 🔗

@@ -1117,7 +1117,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
 }
 
 #[gpui::test]
-async fn test_truncate(cx: &mut TestAppContext) {
+async fn test_truncate_first_message(cx: &mut TestAppContext) {
     let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
     let fake_model = model.as_fake();
 
@@ -1137,9 +1137,18 @@ async fn test_truncate(cx: &mut TestAppContext) {
                 Hello
             "}
         );
+        assert_eq!(thread.latest_token_usage(), None);
     });
 
     fake_model.send_last_completion_stream_text_chunk("Hey!");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 32_000,
+            output_tokens: 16_000,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
     cx.run_until_parked();
     thread.read_with(cx, |thread, _| {
         assert_eq!(
@@ -1154,6 +1163,13 @@ async fn test_truncate(cx: &mut TestAppContext) {
                 Hey!
             "}
         );
+        assert_eq!(
+            thread.latest_token_usage(),
+            Some(acp_thread::TokenUsage {
+                used_tokens: 32_000 + 16_000,
+                max_tokens: 1_000_000,
+            })
+        );
     });
 
     thread
@@ -1162,6 +1178,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
     cx.run_until_parked();
     thread.read_with(cx, |thread, _| {
         assert_eq!(thread.to_markdown(), "");
+        assert_eq!(thread.latest_token_usage(), None);
     });
 
     // Ensure we can still send a new message after truncation.
@@ -1182,6 +1199,14 @@ async fn test_truncate(cx: &mut TestAppContext) {
     });
     cx.run_until_parked();
     fake_model.send_last_completion_stream_text_chunk("Ahoy!");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 40_000,
+            output_tokens: 20_000,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
     cx.run_until_parked();
     thread.read_with(cx, |thread, _| {
         assert_eq!(
@@ -1196,7 +1221,124 @@ async fn test_truncate(cx: &mut TestAppContext) {
                 Ahoy!
             "}
         );
+
+        assert_eq!(
+            thread.latest_token_usage(),
+            Some(acp_thread::TokenUsage {
+                used_tokens: 40_000 + 20_000,
+                max_tokens: 1_000_000,
+            })
+        );
+    });
+}
+
+#[gpui::test]
+async fn test_truncate_second_message(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Message 1"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 32_000,
+            output_tokens: 16_000,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let assert_first_message_state = |cx: &mut TestAppContext| {
+        thread.clone().read_with(cx, |thread, _| {
+            assert_eq!(
+                thread.to_markdown(),
+                indoc! {"
+                    ## User
+
+                    Message 1
+
+                    ## Assistant
+
+                    Message 1 response
+                "}
+            );
+
+            assert_eq!(
+                thread.latest_token_usage(),
+                Some(acp_thread::TokenUsage {
+                    used_tokens: 32_000 + 16_000,
+                    max_tokens: 1_000_000,
+                })
+            );
+        });
+    };
+
+    assert_first_message_state(cx);
+
+    let second_message_id = UserMessageId::new();
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(second_message_id.clone(), ["Message 2"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 40_000,
+            output_tokens: 20_000,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Message 1
+
+                ## Assistant
+
+                Message 1 response
+
+                ## User
+
+                Message 2
+
+                ## Assistant
+
+                Message 2 response
+            "}
+        );
+
+        assert_eq!(
+            thread.latest_token_usage(),
+            Some(acp_thread::TokenUsage {
+                used_tokens: 40_000 + 20_000,
+                max_tokens: 1_000_000,
+            })
+        );
     });
+
+    thread
+        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
+        .unwrap();
+    cx.run_until_parked();
+
+    assert_first_message_state(cx);
 }
 
 #[gpui::test]

crates/agent2/src/thread.rs 🔗

@@ -13,7 +13,7 @@ use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use chrono::{DateTime, Utc};
 use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
-use collections::IndexMap;
+use collections::{HashMap, IndexMap};
 use fs::Fs;
 use futures::{
     FutureExt,
@@ -24,8 +24,8 @@ use futures::{
 use git::repository::DiffType;
 use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
 use language_model::{
-    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
-    LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
+    LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
     LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
@@ -481,6 +481,7 @@ pub enum ThreadEvent {
     ToolCall(acp::ToolCall),
     ToolCallUpdate(acp_thread::ToolCallUpdate),
     ToolCallAuthorization(ToolCallAuthorization),
+    TokenUsageUpdate(acp_thread::TokenUsage),
     TitleUpdate(SharedString),
     Retry(acp_thread::RetryStatus),
     Stop(acp::StopReason),
@@ -509,8 +510,7 @@ pub struct Thread {
     pending_message: Option<AgentMessage>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
     tool_use_limit_reached: bool,
-    #[allow(unused)]
-    request_token_usage: Vec<TokenUsage>,
+    request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
     #[allow(unused)]
     cumulative_token_usage: TokenUsage,
     #[allow(unused)]
@@ -548,7 +548,7 @@ impl Thread {
             pending_message: None,
             tools: BTreeMap::default(),
             tool_use_limit_reached: false,
-            request_token_usage: Vec::new(),
+            request_token_usage: HashMap::default(),
             cumulative_token_usage: TokenUsage::default(),
             initial_project_snapshot: {
                 let project_snapshot = Self::project_snapshot(project.clone(), cx);
@@ -951,6 +951,15 @@ impl Thread {
         self.flush_pending_message(cx);
     }
 
+    pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
+        let Some(last_user_message) = self.last_user_message() else {
+            return;
+        };
+
+        self.request_token_usage
+            .insert(last_user_message.id.clone(), update);
+    }
+
     pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
         self.cancel(cx);
         let Some(position) = self.messages.iter().position(
@@ -958,11 +967,31 @@ impl Thread {
         ) else {
             return Err(anyhow!("Message not found"));
         };
-        self.messages.truncate(position);
+
+        for message in self.messages.drain(position..) {
+            match message {
+                Message::User(message) => {
+                    self.request_token_usage.remove(&message.id);
+                }
+                Message::Agent(_) | Message::Resume => {}
+            }
+        }
+
         cx.notify();
         Ok(())
     }
 
+    pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
+        let last_user_message = self.last_user_message()?;
+        let tokens = self.request_token_usage.get(&last_user_message.id)?;
+        let model = self.model.clone()?;
+
+        Some(acp_thread::TokenUsage {
+            max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
+            used_tokens: tokens.total_tokens(),
+        })
+    }
+
     pub fn resume(
         &mut self,
         cx: &mut Context<Self>,
@@ -1148,6 +1177,21 @@ impl Thread {
                     )) => {
                         *tool_use_limit_reached = true;
                     }
+                    Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
+                        let usage = acp_thread::TokenUsage {
+                            max_tokens: model.max_token_count_for_mode(
+                                request
+                                    .mode
+                                    .unwrap_or(cloud_llm_client::CompletionMode::Normal),
+                            ),
+                            used_tokens: token_usage.total_tokens(),
+                        };
+
+                        this.update(cx, |this, _cx| this.update_token_usage(token_usage))
+                            .ok();
+
+                        event_stream.send_token_usage_update(usage);
+                    }
                     Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
                         *refusal = true;
                         return Ok(FuturesUnordered::default());
@@ -1532,6 +1576,16 @@ impl Thread {
             })
         }))
     }
+    fn last_user_message(&self) -> Option<&UserMessage> {
+        self.messages
+            .iter()
+            .rev()
+            .find_map(|message| match message {
+                Message::User(user_message) => Some(user_message),
+                Message::Agent(_) => None,
+                Message::Resume => None,
+            })
+    }
 
     fn pending_message(&mut self) -> &mut AgentMessage {
         self.pending_message.get_or_insert_default()
@@ -2051,6 +2105,12 @@ impl ThreadEventStream {
             .ok();
     }
 
+    fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
+        self.0
+            .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
+            .ok();
+    }
+
     fn send_retry(&self, status: acp_thread::RetryStatus) {
         self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
     }

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -816,7 +816,7 @@ impl AcpThreadView {
                 self.thread_retry_status.take();
                 self.thread_state = ThreadState::ServerExited { status: *status };
             }
-            AcpThreadEvent::TitleUpdated => {}
+            AcpThreadEvent::TitleUpdated | AcpThreadEvent::TokenUsageUpdated => {}
         }
         cx.notify();
     }
@@ -2794,6 +2794,7 @@ impl AcpThreadView {
                     .child(
                         h_flex()
                             .gap_1()
+                            .children(self.render_token_usage(cx))
                             .children(self.profile_selector.clone())
                             .children(self.model_selector.clone())
                             .child(self.render_send_button(cx)),
@@ -2816,6 +2817,44 @@ impl AcpThreadView {
             .thread(acp_thread.session_id(), cx)
     }
 
+    fn render_token_usage(&self, cx: &mut Context<Self>) -> Option<Div> {
+        let thread = self.thread()?.read(cx);
+        let usage = thread.token_usage()?;
+        let is_generating = thread.status() != ThreadStatus::Idle;
+
+        let used = crate::text_thread_editor::humanize_token_count(usage.used_tokens);
+        let max = crate::text_thread_editor::humanize_token_count(usage.max_tokens);
+
+        Some(
+            h_flex()
+                .flex_shrink_0()
+                .gap_0p5()
+                .mr_1()
+                .child(
+                    Label::new(used)
+                        .size(LabelSize::Small)
+                        .color(Color::Muted)
+                        .map(|label| {
+                            if is_generating {
+                                label
+                                    .with_animation(
+                                        "used-tokens-label",
+                                        Animation::new(Duration::from_secs(2))
+                                            .repeat()
+                                            .with_easing(pulsating_between(0.6, 1.)),
+                                        |label, delta| label.alpha(delta),
+                                    )
+                                    .into_any()
+                            } else {
+                                label.into_any_element()
+                            }
+                        }),
+                )
+                .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
+                .child(Label::new(max).size(LabelSize::Small).color(Color::Muted)),
+        )
+    }
+
     fn toggle_burn_mode(
         &mut self,
         _: &ToggleBurnMode,

crates/agent_ui/src/agent_diff.rs 🔗

@@ -1526,6 +1526,7 @@ impl AgentDiff {
                 self.update_reviewing_editors(workspace, window, cx);
             }
             AcpThreadEvent::TitleUpdated
+            | AcpThreadEvent::TokenUsageUpdated
             | AcpThreadEvent::EntriesRemoved(_)
             | AcpThreadEvent::ToolAuthorizationRequired
             | AcpThreadEvent::Retry(_) => {}