@@ -1,8 +1,8 @@
-use crate::HistoryStore;
use crate::{
ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
UserMessageContent, templates::Templates,
};
+use crate::{HistoryStore, TokenUsageUpdated};
use acp_thread::{AcpThread, AgentModelSelector};
use action_log::ActionLog;
use agent_client_protocol as acp;
@@ -253,6 +253,7 @@ impl NativeAgent {
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
+ cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
cx.observe(&thread_handle, move |this, thread, cx| {
this.save_thread(thread.clone(), cx)
}),
@@ -440,6 +441,23 @@ impl NativeAgent {
})
}
+ /// Forwards a [`TokenUsageUpdated`] event emitted by `thread` to the ACP thread of the
+ /// session registered under that thread's id.
+ ///
+ /// Does nothing when no session is registered for the thread.
+ fn handle_thread_token_usage_updated(
+ &mut self,
+ thread: Entity<Thread>,
+ usage: &TokenUsageUpdated,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(session) = self.sessions.get(thread.read(cx).id()) else {
+ return;
+ };
+ session
+ .acp_thread
+ .update(cx, |acp_thread, cx| {
+ acp_thread.update_token_usage(usage.0.clone(), cx);
+ })
+ // Any error returned by `update` is discarded.
+ // NOTE(review): presumably this covers an ACP thread that is already gone; confirm.
+ .ok();
+ }
+
fn handle_project_event(
&mut self,
_project: Entity<Project>,
@@ -695,11 +713,6 @@ impl NativeAgentConnection {
thread.update_tool_call(update, cx)
})??;
}
- ThreadEvent::TokenUsageUpdate(usage) => {
- acp_thread.update(cx, |thread, cx| {
- thread.update_token_usage(Some(usage), cx)
- })?;
- }
ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;
@@ -15,7 +15,8 @@ use agent_settings::{
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
+use client::{ModelRequestUsage, RequestUsage};
+use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::{HashMap, IndexMap};
use fs::Fs;
use futures::{
@@ -25,7 +26,9 @@ use futures::{
stream::FuturesUnordered,
};
use git::repository::DiffType;
-use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
+use gpui::{
+ App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
+};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
@@ -484,7 +487,6 @@ pub enum ThreadEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
- TokenUsageUpdate(acp_thread::TokenUsage),
TitleUpdate(SharedString),
Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
@@ -873,7 +875,12 @@ impl Thread {
}
+ /// Replaces the thread's model and notifies observers.
+ ///
+ /// Emits [`TokenUsageUpdated`] when `latest_token_usage()` returns a different value after
+ /// the change than it did before.
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
+ // Snapshot the usage before the change so the two values can be compared below.
+ let old_usage = self.latest_token_usage();
self.model = Some(model);
+ let new_usage = self.latest_token_usage();
+ if old_usage != new_usage {
+ cx.emit(TokenUsageUpdated(new_usage));
+ }
cx.notify()
}
@@ -891,7 +898,12 @@ impl Thread {
}
+ /// Sets the thread's completion mode and notifies observers.
+ ///
+ /// Emits [`TokenUsageUpdated`] when `latest_token_usage()` returns a different value after
+ /// the change than it did before.
pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
+ // Snapshot the usage before the change so the two values can be compared below.
+ let old_usage = self.latest_token_usage();
self.completion_mode = mode;
+ let new_usage = self.latest_token_usage();
+ if old_usage != new_usage {
+ cx.emit(TokenUsageUpdated(new_usage));
+ }
cx.notify()
}
@@ -953,13 +965,15 @@ impl Thread {
self.flush_pending_message(cx);
}
- pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
+ /// Stores `update` as the token usage for the last user message, then emits
+ /// [`TokenUsageUpdated`] with the thread's latest usage and notifies observers.
+ ///
+ /// Returns without recording or emitting anything when there is no last user message.
+ fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
let Some(last_user_message) = self.last_user_message() else {
return;
};
self.request_token_usage
.insert(last_user_message.id.clone(), update);
+ cx.emit(TokenUsageUpdated(self.latest_token_usage()));
+ cx.notify();
}
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
@@ -1180,20 +1194,15 @@ impl Thread {
)) => {
*tool_use_limit_reached = true;
}
+ Ok(LanguageModelCompletionEvent::StatusUpdate(
+ CompletionRequestStatus::UsageUpdated { amount, limit },
+ )) => {
+ this.update(cx, |this, cx| {
+ this.update_model_request_usage(amount, limit, cx)
+ })?;
+ }
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
- let usage = acp_thread::TokenUsage {
- max_tokens: model.max_token_count_for_mode(
- request
- .mode
- .unwrap_or(cloud_llm_client::CompletionMode::Normal),
- ),
- used_tokens: token_usage.total_tokens(),
- };
-
- this.update(cx, |this, _cx| this.update_token_usage(token_usage))
- .ok();
-
- event_stream.send_token_usage_update(usage);
+ this.update(cx, |this, cx| this.update_token_usage(token_usage, cx))?;
}
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
*refusal = true;
@@ -1214,8 +1223,7 @@ impl Thread {
event_stream,
cx,
));
- })
- .ok();
+ })?;
}
Err(error) => {
let completion_mode =
@@ -1325,8 +1333,8 @@ impl Thread {
json_parse_error,
)));
}
- UsageUpdate(_) | StatusUpdate(_) => {}
- Stop(_) => unreachable!(),
+ StatusUpdate(_) => {}
+ UsageUpdate(_) | Stop(_) => unreachable!(),
}
None
@@ -1506,6 +1514,21 @@ impl Thread {
}
}
+ /// Reports model request usage to the project's user store.
+ ///
+ /// `amount` and `limit` are wrapped in a [`ModelRequestUsage`] and passed to
+ /// `update_model_request_usage` on the user store obtained from `self.project`.
+ fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
+ self.project
+ .read(cx)
+ .user_store()
+ .update(cx, |user_store, cx| {
+ user_store.update_model_request_usage(
+ ModelRequestUsage(RequestUsage {
+ // `RequestUsage::amount` is an `i32`. Saturate instead of using an `as` cast,
+ // which would wrap an out-of-range `usize` into a wrong (possibly negative)
+ // count.
+ amount: i32::try_from(amount).unwrap_or(i32::MAX),
+ limit,
+ }),
+ cx,
+ )
+ });
+ }
+
+ /// Returns a clone of the thread's title, or `"New Thread"` when no title is set.
pub fn title(&self) -> SharedString {
self.title.clone().unwrap_or("New Thread".into())
}
@@ -1636,6 +1659,7 @@ impl Thread {
})
}))
}
+
fn last_user_message(&self) -> Option<&UserMessage> {
self.messages
.iter()
@@ -1934,6 +1958,10 @@ impl RunningTurn {
}
}
+/// Event emitted by [`Thread`] carrying the value of its `latest_token_usage()`.
+pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);
+
+// Declares `Thread` as an emitter of `TokenUsageUpdated`, so the event can be sent with
+// `cx.emit` and received through `cx.subscribe`.
+impl EventEmitter<TokenUsageUpdated> for Thread {}
+
pub trait AgentTool
where
Self: 'static + Sized,
@@ -2166,12 +2194,6 @@ impl ThreadEventStream {
.ok();
}
- fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
- self.0
- .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
- .ok();
- }
-
+ /// Sends a `ThreadEvent::Retry` carrying `status` through the inner sender.
+ ///
+ /// The result of `unbounded_send` is discarded, so a failed send is silently ignored.
fn send_retry(&self, status: acp_thread::RetryStatus) {
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
}