acp: Initial support for ACP usage (#53894)

Ben Brandt created

Adds initial beta support for ACP usage stats (token usage and session
cost). This is behind a feature flag for now while we work on
standardizing the usage values.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- N/A

Change summary

Cargo.lock                                           |   1 
crates/acp_thread/Cargo.toml                         |   1 
crates/acp_thread/src/acp_thread.rs                  | 208 +++++++++++++
crates/agent_ui/src/conversation_view/thread_view.rs |  34 ++
4 files changed, 243 insertions(+), 1 deletion(-)

Detailed changes

Cargo.lock 🔗

@@ -14,6 +14,7 @@ dependencies = [
  "chrono",
  "collections",
  "env_logger 0.11.8",
+ "feature_flags",
  "file_icons",
  "futures 0.3.32",
  "gpui",

crates/acp_thread/Cargo.toml 🔗

@@ -23,6 +23,7 @@ anyhow.workspace = true
 buffer_diff.workspace = true
 chrono.workspace = true
 collections.workspace = true
+feature_flags.workspace = true
 multi_buffer.workspace = true
 file_icons.workspace = true
 futures.workspace = true

crates/acp_thread/src/acp_thread.rs 🔗

@@ -8,6 +8,7 @@ use anyhow::{Context as _, Result, anyhow};
 use collections::HashSet;
 pub use connection::*;
 pub use diff::*;
+use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _};
 use futures::{FutureExt, channel::oneshot, future::BoxFuture};
 use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
 use itertools::Itertools;
@@ -972,7 +973,7 @@ impl PlanEntry {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
 pub struct TokenUsage {
     pub max_tokens: u64,
     pub used_tokens: u64,
@@ -981,6 +982,12 @@ pub struct TokenUsage {
     pub max_output_tokens: Option<u64>,
 }
 
+#[derive(Debug, Clone)]
+pub struct SessionCost {
+    pub amount: f64,
+    pub currency: SharedString,
+}
+
 pub const TOKEN_USAGE_WARNING_THRESHOLD: f32 = 0.8;
 
 impl TokenUsage {
@@ -1043,6 +1050,7 @@ pub struct AcpThread {
     running_turn: Option<RunningTurn>,
     connection: Rc<dyn AgentConnection>,
     token_usage: Option<TokenUsage>,
+    cost: Option<SessionCost>,
     prompt_capabilities: acp::PromptCapabilities,
     available_commands: Vec<acp::AvailableCommand>,
     _observe_prompt_capabilities: Task<anyhow::Result<()>>,
@@ -1232,6 +1240,7 @@ impl AcpThread {
             connection,
             session_id,
             token_usage: None,
+            cost: None,
             prompt_capabilities,
             available_commands: Vec::new(),
             _observe_prompt_capabilities: task,
@@ -1348,6 +1357,10 @@ impl AcpThread {
         self.token_usage.as_ref()
     }
 
+    pub fn cost(&self) -> Option<&SessionCost> {
+        self.cost.as_ref()
+    }
+
     pub fn has_pending_edit_tool_calls(&self) -> bool {
         for entry in self.entries.iter().rev() {
             match entry {
@@ -1463,6 +1476,18 @@ impl AcpThread {
                 config_options,
                 ..
             }) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
+            acp::SessionUpdate::UsageUpdate(update) if cx.has_flag::<AcpBetaFeatureFlag>() => {
+                let usage = self.token_usage.get_or_insert_with(Default::default);
+                usage.max_tokens = update.size;
+                usage.used_tokens = update.used;
+                if let Some(cost) = update.cost {
+                    self.cost = Some(SessionCost {
+                        amount: cost.amount,
+                        currency: cost.currency.into(),
+                    });
+                }
+                cx.emit(AcpThreadEvent::TokenUsageUpdated);
+            }
             _ => {}
         }
         Ok(())
@@ -1759,6 +1784,9 @@ impl AcpThread {
     }
 
     pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
+        if usage.is_none() {
+            self.cost = None;
+        }
         self.token_usage = usage;
         cx.emit(AcpThreadEvent::TokenUsageUpdated);
     }
@@ -2340,6 +2368,15 @@ impl AcpThread {
                             }
                         }
 
+                        if cx.has_flag::<AcpBetaFeatureFlag>()
+                            && let Some(response_usage) = &r.usage
+                        {
+                            let usage = this.token_usage.get_or_insert_with(Default::default);
+                            usage.input_tokens = response_usage.input_tokens;
+                            usage.output_tokens = response_usage.output_tokens;
+                            cx.emit(AcpThreadEvent::TokenUsageUpdated);
+                        }
+
                         cx.emit(AcpThreadEvent::Stopped(r.stop_reason));
                         Ok(Some(r))
                     }
@@ -5297,4 +5334,173 @@ mod tests {
             "session info title update should not propagate back to the connection"
         );
     }
+
+    #[gpui::test]
+    async fn test_usage_update_populates_token_usage_and_cost(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        thread.update(cx, |thread, cx| {
+            thread
+                .handle_session_update(
+                    acp::SessionUpdate::UsageUpdate(
+                        acp::UsageUpdate::new(5000, 10000).cost(acp::Cost::new(0.42, "USD")),
+                    ),
+                    cx,
+                )
+                .unwrap();
+        });
+
+        thread.read_with(cx, |thread, _| {
+            let usage = thread.token_usage().expect("token_usage should be set");
+            assert_eq!(usage.max_tokens, 10000);
+            assert_eq!(usage.used_tokens, 5000);
+
+            let cost = thread.cost().expect("cost should be set");
+            assert!((cost.amount - 0.42).abs() < f64::EPSILON);
+            assert_eq!(cost.currency.as_ref(), "USD");
+        });
+    }
+
+    #[gpui::test]
+    async fn test_usage_update_without_cost_preserves_existing_cost(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        thread.update(cx, |thread, cx| {
+            thread
+                .handle_session_update(
+                    acp::SessionUpdate::UsageUpdate(
+                        acp::UsageUpdate::new(1000, 10000).cost(acp::Cost::new(0.10, "USD")),
+                    ),
+                    cx,
+                )
+                .unwrap();
+
+            thread
+                .handle_session_update(
+                    acp::SessionUpdate::UsageUpdate(acp::UsageUpdate::new(2000, 10000)),
+                    cx,
+                )
+                .unwrap();
+        });
+
+        thread.read_with(cx, |thread, _| {
+            let usage = thread.token_usage().expect("token_usage should be set");
+            assert_eq!(usage.used_tokens, 2000);
+
+            let cost = thread.cost().expect("cost should be preserved");
+            assert!((cost.amount - 0.10).abs() < f64::EPSILON);
+        });
+    }
+
+    #[gpui::test]
+    async fn test_response_usage_does_not_clobber_session_usage(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
+            move |_, thread, mut cx| {
+                async move {
+                    thread.update(&mut cx, |thread, cx| {
+                        thread
+                            .handle_session_update(
+                                acp::SessionUpdate::UsageUpdate(
+                                    acp::UsageUpdate::new(3000, 10000)
+                                        .cost(acp::Cost::new(0.05, "EUR")),
+                                ),
+                                cx,
+                            )
+                            .unwrap();
+                    })?;
+                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)
+                        .usage(acp::Usage::new(500, 200, 300)))
+                }
+                .boxed_local()
+            },
+        ));
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        thread
+            .update(cx, |thread, cx| thread.send_raw("hello", cx))
+            .await
+            .unwrap();
+
+        thread.read_with(cx, |thread, _| {
+            let usage = thread.token_usage().expect("token_usage should be set");
+            assert_eq!(usage.max_tokens, 10000, "max_tokens from UsageUpdate");
+            assert_eq!(usage.used_tokens, 3000, "used_tokens from UsageUpdate");
+            assert_eq!(usage.input_tokens, 200, "input_tokens from response usage");
+            assert_eq!(
+                usage.output_tokens, 300,
+                "output_tokens from response usage"
+            );
+
+            let cost = thread.cost().expect("cost should be set");
+            assert!((cost.amount - 0.05).abs() < f64::EPSILON);
+            assert_eq!(cost.currency.as_ref(), "EUR");
+        });
+    }
+
+    #[gpui::test]
+    async fn test_clearing_token_usage_also_clears_cost(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        thread.update(cx, |thread, cx| {
+            thread
+                .handle_session_update(
+                    acp::SessionUpdate::UsageUpdate(
+                        acp::UsageUpdate::new(1000, 10000).cost(acp::Cost::new(0.25, "USD")),
+                    ),
+                    cx,
+                )
+                .unwrap();
+
+            assert!(thread.token_usage().is_some());
+            assert!(thread.cost().is_some());
+
+            thread.update_token_usage(None, cx);
+
+            assert!(thread.token_usage().is_none());
+            assert!(
+                thread.cost().is_none(),
+                "cost should be cleared when token usage is cleared"
+            );
+        });
+    }
 }

crates/agent_ui/src/conversation_view/thread_view.rs 🔗

@@ -7,6 +7,7 @@ use std::cell::RefCell;
 use acp_thread::{ContentBlock, PlanEntry};
 use cloud_api_types::{SubmitAgentThreadFeedbackBody, SubmitAgentThreadFeedbackCommentsBody};
 use editor::actions::OpenExcerpts;
+use feature_flags::AcpBetaFeatureFlag;
 
 use crate::StartThreadIn;
 use crate::message_editor::SharedSessionCapabilities;
@@ -3547,6 +3548,19 @@ impl ThreadView {
         let usage = thread.token_usage()?;
         let show_split = self.supports_split_token_display(cx);
 
+        let cost_label = if cx.has_flag::<AcpBetaFeatureFlag>() {
+            thread.cost().map(|cost| {
+                let precision = if cost.amount > 0.0 && cost.amount < 0.01 {
+                    4
+                } else {
+                    2
+                };
+                format!("{:.prec$} {}", cost.amount, cost.currency, prec = precision)
+            })
+        } else {
+            None
+        };
+
         let progress_color = |ratio: f32| -> Hsla {
             if ratio >= 0.85 {
                 cx.theme().status().warning
@@ -3617,6 +3631,7 @@ impl ThreadView {
                 let output_max_label = output_max_label.clone();
                 let project_entry_ids = project_entry_ids.clone();
                 let workspace = workspace.clone();
+                let cost_label = cost_label.clone();
                 cx.new(move |_cx| TokenUsageTooltip {
                     percentage,
                     used,
@@ -3626,6 +3641,7 @@ impl ThreadView {
                     input_max: input_max_label,
                     output_max: output_max_label,
                     show_split,
+                    cost_label,
                     separator_color: tooltip_separator_color,
                     user_rules_count,
                     first_user_rules_id,
@@ -4273,6 +4289,7 @@ struct TokenUsageTooltip {
     input_max: String,
     output_max: String,
     show_split: bool,
+    cost_label: Option<String>,
     separator_color: Color,
     user_rules_count: usize,
     first_user_rules_id: Option<uuid::Uuid>,
@@ -4292,6 +4309,7 @@ impl Render for TokenUsageTooltip {
         let input_max = self.input_max.clone();
         let output_max = self.output_max.clone();
         let show_split = self.show_split;
+        let cost_label = self.cost_label.clone();
         let user_rules_count = self.user_rules_count;
         let first_user_rules_id = self.first_user_rules_id;
         let project_rules_count = self.project_rules_count;
@@ -4339,6 +4357,22 @@ impl Render for TokenUsageTooltip {
                             ),
                     )
                 })
+                .when_some(cost_label, |this, cost_label| {
+                    this.child(
+                        v_flex()
+                            .mt_1p5()
+                            .pt_1p5()
+                            .gap_0p5()
+                            .border_t_1()
+                            .border_color(cx.theme().colors().border_variant)
+                            .child(
+                                Label::new("Cost")
+                                    .color(Color::Muted)
+                                    .size(LabelSize::Small),
+                            )
+                            .child(Label::new(cost_label)),
+                    )
+                })
                 .when(
                     user_rules_count > 0 || project_rules_count > 0,
                     move |this| {