Detailed changes
@@ -6,6 +6,7 @@ mod terminal;
pub use connection::*;
pub use diff::*;
pub use mention::*;
+use serde::{Deserialize, Serialize};
pub use terminal::*;
use action_log::ActionLog;
@@ -664,6 +665,12 @@ impl PlanEntry {
}
}
+/// Snapshot of how many tokens a thread has consumed relative to its limit.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct TokenUsage {
+    /// Upper bound on the number of tokens.
+    pub max_tokens: u64,
+    /// Number of tokens consumed so far.
+    pub used_tokens: u64,
+}
+
#[derive(Debug, Clone)]
pub struct RetryStatus {
pub last_error: SharedString,
@@ -683,12 +690,14 @@ pub struct AcpThread {
send_task: Option<Task<()>>,
connection: Rc<dyn AgentConnection>,
session_id: acp::SessionId,
+ token_usage: Option<TokenUsage>,
}
#[derive(Debug)]
pub enum AcpThreadEvent {
NewEntry,
TitleUpdated,
+ TokenUsageUpdated,
EntryUpdated(usize),
EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
@@ -748,6 +757,7 @@ impl AcpThread {
send_task: None,
connection,
session_id,
+ token_usage: None,
}
}
@@ -787,6 +797,10 @@ impl AcpThread {
}
}
+ /// Returns the token usage last stored on this thread, or `None` when no usage has been
+ /// stored (the initial state) or it has been cleared.
+ pub fn token_usage(&self) -> Option<&TokenUsage> {
+     self.token_usage.as_ref()
+ }
+
pub fn has_pending_edit_tool_calls(&self) -> bool {
for entry in self.entries.iter().rev() {
match entry {
@@ -937,6 +951,11 @@ impl AcpThread {
Ok(())
}
+ /// Replaces the stored token usage with `usage` (`None` clears it) and emits
+ /// [`AcpThreadEvent::TokenUsageUpdated`] so observers can react to the change.
+ ///
+ /// The event is emitted on every call, even when `usage` equals the previous value.
+ pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
+     self.token_usage = usage;
+     cx.emit(AcpThreadEvent::TokenUsageUpdated);
+ }
+
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::Retry(status));
}
@@ -10,7 +10,7 @@ use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct UserMessageId(Arc<str>);
impl UserMessageId {
@@ -66,6 +66,7 @@ zstd.workspace = true
[dev-dependencies]
agent = { workspace = true, "features" = ["test-support"] }
+assistant_context = { workspace = true, "features" = ["test-support"] }
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
@@ -1,8 +1,8 @@
+use crate::HistoryStore;
use crate::{
- ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent,
- templates::Templates,
+ ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
+ UserMessageContent, templates::Templates,
};
-use crate::{HistoryStore, ThreadsDatabase};
use acp_thread::{AcpThread, AgentModelSelector};
use action_log::ActionLog;
use agent_client_protocol as acp;
@@ -673,6 +673,11 @@ impl NativeAgentConnection {
thread.update_tool_call(update, cx)
})??;
}
+ ThreadEvent::TokenUsageUpdate(usage) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.update_token_usage(Some(usage), cx)
+ })?;
+ }
ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;
@@ -895,10 +900,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
self.0.update(cx, |agent, _cx| {
- agent
- .sessions
- .get(session_id)
- .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
+ agent.sessions.get(session_id).map(|session| {
+ Rc::new(NativeAgentSessionEditor {
+ thread: session.thread.clone(),
+ acp_thread: session.acp_thread.clone(),
+ }) as _
+ })
})
}
@@ -907,14 +914,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}
}
-struct NativeAgentSessionEditor(Entity<Thread>);
+/// Editor handle for a single native agent session.
+struct NativeAgentSessionEditor {
+    /// The agent-side thread whose message history gets edited.
+    thread: Entity<Thread>,
+    /// The ACP thread of the same session, held weakly; it receives token usage updates
+    /// after an edit.
+    acp_thread: WeakEntity<AcpThread>,
+}
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
- Task::ready(
- self.0
- .update(cx, |thread, cx| thread.truncate(message_id, cx)),
- )
+ match self.thread.update(cx, |thread, cx| {
+ thread.truncate(message_id.clone(), cx)?;
+ Ok(thread.latest_token_usage())
+ }) {
+ Ok(usage) => {
+ self.acp_thread
+ .update(cx, |thread, cx| {
+ thread.update_token_usage(usage, cx);
+ })
+ .ok();
+ Task::ready(Ok(()))
+ }
+ Err(error) => Task::ready(Err(error)),
+ }
}
}
@@ -1,4 +1,5 @@
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
+use acp_thread::UserMessageId;
use agent::thread_store;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode};
@@ -42,7 +43,7 @@ pub struct DbThread {
#[serde(default)]
pub cumulative_token_usage: language_model::TokenUsage,
#[serde(default)]
- pub request_token_usage: Vec<language_model::TokenUsage>,
+ pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
#[serde(default)]
pub model: Option<DbLanguageModel>,
#[serde(default)]
@@ -67,7 +68,10 @@ impl DbThread {
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
let mut messages = Vec::new();
- for msg in thread.messages {
+ let mut request_token_usage = HashMap::default();
+
+ let mut last_user_message_id = None;
+ for (ix, msg) in thread.messages.into_iter().enumerate() {
let message = match msg.role {
language_model::Role::User => {
let mut content = Vec::new();
@@ -93,9 +97,12 @@ impl DbThread {
content.push(UserMessageContent::Text(msg.context));
}
+ let id = UserMessageId::new();
+ last_user_message_id = Some(id.clone());
+
crate::Message::User(UserMessage {
// MessageId from old format can't be meaningfully converted, so generate a new one
- id: acp_thread::UserMessageId::new(),
+ id,
content,
})
}
@@ -154,6 +161,12 @@ impl DbThread {
);
}
+ if let Some(last_user_message_id) = &last_user_message_id
+ && let Some(token_usage) = thread.request_token_usage.get(ix).copied()
+ {
+ request_token_usage.insert(last_user_message_id.clone(), token_usage);
+ }
+
crate::Message::Agent(AgentMessage {
content,
tool_results,
@@ -175,7 +188,7 @@ impl DbThread {
summary: thread.detailed_summary_state,
initial_project_snapshot: thread.initial_project_snapshot,
cumulative_token_usage: thread.cumulative_token_usage,
- request_token_usage: thread.request_token_usage,
+ request_token_usage,
model: thread.model,
completion_mode: thread.completion_mode,
profile: thread.profile,
@@ -1117,7 +1117,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
}
#[gpui::test]
-async fn test_truncate(cx: &mut TestAppContext) {
+async fn test_truncate_first_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
@@ -1137,9 +1137,18 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hello
"}
);
+ assert_eq!(thread.latest_token_usage(), None);
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 32_000,
+ output_tokens: 16_000,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@@ -1154,6 +1163,13 @@ async fn test_truncate(cx: &mut TestAppContext) {
Hey!
"}
);
+ assert_eq!(
+ thread.latest_token_usage(),
+ Some(acp_thread::TokenUsage {
+ used_tokens: 32_000 + 16_000,
+ max_tokens: 1_000_000,
+ })
+ );
});
thread
@@ -1162,6 +1178,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.to_markdown(), "");
+ assert_eq!(thread.latest_token_usage(), None);
});
// Ensure we can still send a new message after truncation.
@@ -1182,6 +1199,14 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 40_000,
+ output_tokens: 20_000,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@@ -1196,7 +1221,124 @@ async fn test_truncate(cx: &mut TestAppContext) {
Ahoy!
"}
);
+
+ assert_eq!(
+ thread.latest_token_usage(),
+ Some(acp_thread::TokenUsage {
+ used_tokens: 40_000 + 20_000,
+ max_tokens: 1_000_000,
+ })
+ );
+ });
+}
+
+/// Truncating at the second user message must restore both the transcript and the reported
+/// token usage to their state after the first exchange.
+#[gpui::test]
+async fn test_truncate_second_message(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    // Builds a usage event with the given input/output counts and no cache traffic.
+    let usage_update = |input_tokens, output_tokens| {
+        LanguageModelCompletionEvent::UsageUpdate(language_model::TokenUsage {
+            input_tokens,
+            output_tokens,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        })
+    };
+    // The usage the thread is expected to report for a given number of used tokens.
+    let expected_usage = |used_tokens| {
+        Some(acp_thread::TokenUsage {
+            used_tokens,
+            max_tokens: 1_000_000,
+        })
+    };
+
+    // First exchange.
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Message 1"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
+    fake_model.send_last_completion_stream_event(usage_update(32_000, 16_000));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Checked twice: right after the first exchange and again after truncation.
+    let assert_first_message_state = |cx: &mut TestAppContext| {
+        thread.clone().read_with(cx, |thread, _| {
+            assert_eq!(
+                thread.to_markdown(),
+                indoc! {"
+                    ## User
+
+                    Message 1
+
+                    ## Assistant
+
+                    Message 1 response
+                "}
+            );
+
+            assert_eq!(
+                thread.latest_token_usage(),
+                expected_usage(32_000 + 16_000)
+            );
+        });
+    };
+
+    assert_first_message_state(cx);
+
+    // Second exchange, with an id we keep so we can truncate at it later.
+    let second_message_id = UserMessageId::new();
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(second_message_id.clone(), ["Message 2"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
+    fake_model.send_last_completion_stream_event(usage_update(40_000, 20_000));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Message 1
+
+                ## Assistant
+
+                Message 1 response
+
+                ## User
+
+                Message 2
+
+                ## Assistant
+
+                Message 2 response
+            "}
+        );
+
+        assert_eq!(
+            thread.latest_token_usage(),
+            expected_usage(40_000 + 20_000)
+        );
+    });
+
+    // Dropping the second exchange must bring back the first exchange's state.
+    thread
+        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
+        .unwrap();
+    cx.run_until_parked();
+
+    assert_first_message_state(cx);
+}
#[gpui::test]
@@ -13,7 +13,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
-use collections::IndexMap;
+use collections::{HashMap, IndexMap};
use fs::Fs;
use futures::{
FutureExt,
@@ -24,8 +24,8 @@ use futures::{
use git::repository::DiffType;
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
- LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
- LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
+ LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
@@ -481,6 +481,7 @@ pub enum ThreadEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
+ TokenUsageUpdate(acp_thread::TokenUsage),
TitleUpdate(SharedString),
Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
@@ -509,8 +510,7 @@ pub struct Thread {
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
- #[allow(unused)]
- request_token_usage: Vec<TokenUsage>,
+ request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
#[allow(unused)]
cumulative_token_usage: TokenUsage,
#[allow(unused)]
@@ -548,7 +548,7 @@ impl Thread {
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
- request_token_usage: Vec::new(),
+ request_token_usage: HashMap::default(),
cumulative_token_usage: TokenUsage::default(),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project.clone(), cx);
@@ -951,6 +951,15 @@ impl Thread {
self.flush_pending_message(cx);
}
+ /// Records `update` as the token usage of the most recent user message, replacing any
+ /// usage previously stored for that message.
+ ///
+ /// Does nothing when the thread contains no user message.
+ pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
+     if let Some(message_id) = self.last_user_message().map(|message| message.id.clone()) {
+         self.request_token_usage.insert(message_id, update);
+     }
+ }
+
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
self.cancel(cx);
let Some(position) = self.messages.iter().position(
@@ -958,11 +967,31 @@ impl Thread {
) else {
return Err(anyhow!("Message not found"));
};
- self.messages.truncate(position);
+
+ for message in self.messages.drain(position..) {
+ match message {
+ Message::User(message) => {
+ self.request_token_usage.remove(&message.id);
+ }
+ Message::Agent(_) | Message::Resume => {}
+ }
+ }
+
cx.notify();
Ok(())
}
+ /// Returns the token usage recorded for the most recent user message, paired with the
+ /// current model's maximum token count for the thread's completion mode.
+ ///
+ /// Returns `None` when the thread has no user message, when no usage has been recorded
+ /// for the latest user message, or when no model is set.
+ pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
+     let last_user_message = self.last_user_message()?;
+     let tokens = self.request_token_usage.get(&last_user_message.id)?;
+     // The model is only read here, so borrow it rather than cloning the handle.
+     let model = self.model.as_ref()?;
+
+     Some(acp_thread::TokenUsage {
+         max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
+         used_tokens: tokens.total_tokens(),
+     })
+ }
+
pub fn resume(
&mut self,
cx: &mut Context<Self>,
@@ -1148,6 +1177,21 @@ impl Thread {
)) => {
*tool_use_limit_reached = true;
}
+ Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
+ let usage = acp_thread::TokenUsage {
+ max_tokens: model.max_token_count_for_mode(
+ request
+ .mode
+ .unwrap_or(cloud_llm_client::CompletionMode::Normal),
+ ),
+ used_tokens: token_usage.total_tokens(),
+ };
+
+ this.update(cx, |this, _cx| this.update_token_usage(token_usage))
+ .ok();
+
+ event_stream.send_token_usage_update(usage);
+ }
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
*refusal = true;
return Ok(FuturesUnordered::default());
@@ -1532,6 +1576,16 @@ impl Thread {
})
}))
}
+ /// Returns the user message closest to the end of the thread, or `None` if the thread
+ /// contains no user message.
+ fn last_user_message(&self) -> Option<&UserMessage> {
+     // Walk backwards so the first user message encountered is the latest one.
+     for message in self.messages.iter().rev() {
+         match message {
+             Message::User(user_message) => return Some(user_message),
+             Message::Agent(_) | Message::Resume => {}
+         }
+     }
+     None
+ }
fn pending_message(&mut self) -> &mut AgentMessage {
self.pending_message.get_or_insert_default()
@@ -2051,6 +2105,12 @@ impl ThreadEventStream {
.ok();
}
+ /// Sends a [`ThreadEvent::TokenUsageUpdate`] carrying `usage` on the event stream.
+ ///
+ /// A send failure is ignored.
+ fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
+     let event = ThreadEvent::TokenUsageUpdate(usage);
+     self.0.unbounded_send(Ok(event)).ok();
+ }
+
fn send_retry(&self, status: acp_thread::RetryStatus) {
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
}
@@ -816,7 +816,7 @@ impl AcpThreadView {
self.thread_retry_status.take();
self.thread_state = ThreadState::ServerExited { status: *status };
}
- AcpThreadEvent::TitleUpdated => {}
+ AcpThreadEvent::TitleUpdated | AcpThreadEvent::TokenUsageUpdated => {}
}
cx.notify();
}
@@ -2794,6 +2794,7 @@ impl AcpThreadView {
.child(
h_flex()
.gap_1()
+ .children(self.render_token_usage(cx))
.children(self.profile_selector.clone())
.children(self.model_selector.clone())
.child(self.render_send_button(cx)),
@@ -2816,6 +2817,44 @@ impl AcpThreadView {
.thread(acp_thread.session_id(), cx)
}
+ /// Renders the compact `used / max` token counter for the current thread.
+ ///
+ /// Returns `None` when there is no thread or the thread has no token usage. While the
+ /// thread status is anything other than `Idle`, the "used" label's opacity pulses between
+ /// 0.6 and 1.0 on a repeating two-second animation.
+ fn render_token_usage(&self, cx: &mut Context<Self>) -> Option<Div> {
+     let thread = self.thread()?.read(cx);
+     let usage = thread.token_usage()?;
+     let is_generating = thread.status() != ThreadStatus::Idle;
+
+     let used = crate::text_thread_editor::humanize_token_count(usage.used_tokens);
+     let max = crate::text_thread_editor::humanize_token_count(usage.max_tokens);
+
+     let used_label = Label::new(used).size(LabelSize::Small).color(Color::Muted);
+     let used_label = if is_generating {
+         used_label
+             .with_animation(
+                 "used-tokens-label",
+                 Animation::new(Duration::from_secs(2))
+                     .repeat()
+                     .with_easing(pulsating_between(0.6, 1.)),
+                 |label, delta| label.alpha(delta),
+             )
+             .into_any()
+     } else {
+         used_label.into_any_element()
+     };
+     let separator = Label::new("/").size(LabelSize::Small).color(Color::Muted);
+     let max_label = Label::new(max).size(LabelSize::Small).color(Color::Muted);
+
+     Some(
+         h_flex()
+             .flex_shrink_0()
+             .gap_0p5()
+             .mr_1()
+             .child(used_label)
+             .child(separator)
+             .child(max_label),
+     )
+ }
+
fn toggle_burn_mode(
&mut self,
_: &ToggleBurnMode,
@@ -1526,6 +1526,7 @@ impl AgentDiff {
self.update_reviewing_editors(workspace, window, cx);
}
AcpThreadEvent::TitleUpdated
+ | AcpThreadEvent::TokenUsageUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Retry(_) => {}