From 0286b8ab3eeae8bb8ca89e770a39ff524488b298 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 16 Apr 2025 13:27:36 -0600 Subject: [PATCH] agent: Fix conversation token usage and estimate unsent message (#28878) The UI was mistakenly using the cumulative token usage for the token counter. It will now display the last request token count, plus an estimation of the tokens in the message editor and context entries that haven't been sent yet. https://github.com/user-attachments/assets/0438c501-b850-4397-9135-57214ca3c07a Additionally, when the user edits a message, we'll display the actual token count up to it and estimate the tokens in the new message. Note: We don't currently estimate the delta when switching profiles. In the future, we want to use the count tokens API to measure every part of the request and display a breakdown. Release Notes: - agent: Made the token count more accurate and added back estimation of used tokens as you type and add context. --------- Co-authored-by: Bennet Bo Fenner Co-authored-by: Danilo Leal --- crates/agent/src/active_thread.rs | 117 +++++++- crates/agent/src/assistant_panel.rs | 267 ++++++++++++------ crates/agent/src/message_editor.rs | 120 ++++++-- crates/agent/src/thread.rs | 115 ++++++-- crates/agent/src/thread_store.rs | 3 + crates/language_model/src/language_model.rs | 5 +- .../language_models/src/provider/anthropic.rs | 6 +- crates/ui/src/components/tab.rs | 4 +- 8 files changed, 507 insertions(+), 130 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 0cd90ba79686cee5cfb6b441af34a2bf2c272f47..fea66c3dea355f01a8a909e713551a2a3d064a0c 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1,4 +1,4 @@ -use crate::context::{AssistantContext, ContextId}; +use crate::context::{AssistantContext, ContextId, format_context_as_string}; use crate::context_picker::MentionLink; use crate::thread::{ LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError, @@ -13,16 +13,18 @@ use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting}; use assistant_tool::ToolUseStatus; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; -use editor::{Editor, EditorElement, EditorStyle, MultiBuffer}; +use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer}; use gpui::{ AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem, - DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState, - MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task, - TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, + DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment, + ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, + Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage, pulsating_between, }; use language::{Buffer, LanguageRegistry}; -use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, Role, StopReason, +}; use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; use project::ProjectItem as _; @@ -682,6 +684,9 @@ fn open_markdown_link( struct EditMessageState { editor: Entity, + last_estimated_token_count: Option, + _subscription: Subscription, + _update_token_count_task: Option>>, } impl ActiveThread { @@ -781,6 +786,13 @@ impl ActiveThread { self.last_error.take(); } + /// Returns the editing message id and the estimated token count in the content + pub fn editing_message_id(&self) -> Option<(MessageId, usize)> { + self.editing_message + .as_ref() + .map(|(id, state)| (*id, state.last_estimated_token_count.unwrap_or(0))) + } + fn push_message( &mut self, id: &MessageId, @@ -1126,15 +1138,91 @@ impl ActiveThread { editor.move_to_end(&editor::actions::MoveToEnd, window, cx); editor }); + let subscription = cx.subscribe(&editor, |this, _, event, cx| match event { + EditorEvent::BufferEdited => { + this.update_editing_message_token_count(true, cx); + } + _ => {} + }); self.editing_message = Some(( message_id, EditMessageState { editor: editor.clone(), + last_estimated_token_count: None, + _subscription: subscription, + _update_token_count_task: None, }, )); + self.update_editing_message_token_count(false, cx); cx.notify(); } + fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context) { + let Some((message_id, state)) = self.editing_message.as_mut() else { + return; + }; + + cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged); + state._update_token_count_task.take(); + + let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else { + state.last_estimated_token_count.take(); + return; + }; + + let editor = state.editor.clone(); + let thread = self.thread.clone(); + let message_id = *message_id; + + state._update_token_count_task = Some(cx.spawn(async move |this, cx| { + if debounce { + cx.background_executor() + .timer(Duration::from_millis(200)) + .await; + } + + let token_count = if let Some(task) = cx.update(|cx| { + let context = thread.read(cx).context_for_message(message_id); + let new_context = thread.read(cx).filter_new_context(context); + let context_text = + format_context_as_string(new_context, cx).unwrap_or(String::new()); + let message_text = editor.read(cx).text(cx); + + let content = context_text + &message_text; + + if content.is_empty() { + return None; + } + + let request = language_model::LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: language_model::Role::User, + content: vec![content.into()], + cache: false, + }], + tools: vec![], + stop: vec![], + temperature: None, + }; + + Some(default_model.model.count_tokens(request, cx)) + })? { + task.await? + } else { + 0 + }; + + this.update(cx, |this, cx| { + let Some((_message_id, state)) = this.editing_message.as_mut() else { + return; + }; + + state.last_estimated_token_count = Some(token_count); + cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged); + }) + })); + } + fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { self.editing_message.take(); cx.notify(); @@ -1676,6 +1764,9 @@ impl ActiveThread { "confirm-edit-message", "Regenerate", ) + .disabled( + edit_message_editor.read(cx).is_empty(cx), + ) .label_size(LabelSize::Small) .key_binding( KeyBinding::for_action_in( @@ -1738,8 +1829,16 @@ impl ActiveThread { ), }; + let after_editing_message = self + .editing_message + .as_ref() + .map_or(false, |(editing_message_id, _)| { + message_id > *editing_message_id + }); + v_flex() .w_full() + .when(after_editing_message, |parent| parent.opacity(0.2)) .when_some(checkpoint, |parent, checkpoint| { let mut is_pending = false; let mut error = None; @@ -2965,6 +3064,12 @@ impl ActiveThread { } } +pub enum ActiveThreadEvent { + EditingMessageTokenCountChanged, +} + +impl EventEmitter for ActiveThread {} + impl Render for ActiveThread { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 3090f4682d3ebf23ffbbbdc4aaa08569a4e0b5ce..6b7ac9e849bf6fc6ad4770be0219de09cc762832 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -5,7 +5,7 @@ use std::time::Duration; use anyhow::{Result, anyhow}; use assistant_context_editor::{ AssistantPanelDelegate, ConfigurationError, ContextEditor, SlashCommandCompletionProvider, - make_lsp_adapter_delegate, render_remaining_tokens, + humanize_token_count, make_lsp_adapter_delegate, render_remaining_tokens, }; use assistant_settings::{AssistantDockPosition, AssistantSettings}; use assistant_slash_command::SlashCommandWorkingSet; @@ -37,10 +37,10 @@ use workspace::dock::{DockPosition, Panel, PanelEvent}; use zed_actions::agent::OpenConfiguration; use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus}; -use crate::active_thread::ActiveThread; +use crate::active_thread::{ActiveThread, ActiveThreadEvent}; use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent}; use crate::history_store::{HistoryEntry, HistoryStore}; -use crate::message_editor::MessageEditor; +use crate::message_editor::{MessageEditor, MessageEditorEvent}; use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio}; use crate::thread_history::{PastContext, PastThread, ThreadHistory}; use crate::thread_store::ThreadStore; @@ -181,8 +181,8 @@ pub struct AssistantPanel { language_registry: Arc, thread_store: Entity, thread: Entity, - _thread_subscription: Subscription, message_editor: Entity, + _active_thread_subscriptions: Vec, context_store: Entity, context_editor: Option>, configuration: Option>, @@ -264,6 +264,13 @@ impl AssistantPanel { ) }); + let message_editor_subscription = + cx.subscribe(&message_editor, |_, _, event, cx| match event { + MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => { + cx.notify(); + } + }); + let history_store = cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx)); @@ -288,6 +295,12 @@ impl AssistantPanel { ) }); + let active_thread_subscription = cx.subscribe(&thread, |_, _, event, cx| match &event { + ActiveThreadEvent::EditingMessageTokenCountChanged => { + cx.notify(); + } + }); + Self { active_view, workspace, @@ -296,8 +309,12 @@ impl AssistantPanel { language_registry, thread_store: thread_store.clone(), thread, - _thread_subscription: thread_subscription, message_editor, + _active_thread_subscriptions: vec![ + thread_subscription, + active_thread_subscription, + message_editor_subscription, + ], context_store, context_editor: None, configuration: None, @@ -382,6 +399,13 @@ impl AssistantPanel { .detach_and_log_err(cx); } + let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| { + if let ThreadEvent::MessageAdded(_) = &event { + // needed to leave empty state + cx.notify(); + } + }); + self.thread = cx.new(|cx| { ActiveThread::new( thread.clone(), @@ -394,12 +418,12 @@ impl AssistantPanel { ) }); - self._thread_subscription = cx.subscribe(&thread, |_, _, event, cx| { - if let ThreadEvent::MessageAdded(_) = &event { - // needed to leave empty state - cx.notify(); - } - }); + let active_thread_subscription = + cx.subscribe(&self.thread, |_, _, event, cx| match &event { + ActiveThreadEvent::EditingMessageTokenCountChanged => { + cx.notify(); + } + }); self.message_editor = cx.new(|cx| { MessageEditor::new( @@ -413,6 +437,19 @@ impl AssistantPanel { ) }); self.message_editor.focus_handle(cx).focus(window); + + let message_editor_subscription = + cx.subscribe(&self.message_editor, |_, _, event, cx| match event { + MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => { + cx.notify(); + } + }); + + self._active_thread_subscriptions = vec![ + thread_subscription, + active_thread_subscription, + message_editor_subscription, + ]; } fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { @@ -538,6 +575,13 @@ impl AssistantPanel { Some(this.thread_store.downgrade()), ) }); + let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| { + if let ThreadEvent::MessageAdded(_) = &event { + // needed to leave empty state + cx.notify(); + } + }); + this.thread = cx.new(|cx| { ActiveThread::new( thread.clone(), @@ -549,6 +593,14 @@ impl AssistantPanel { cx, ) }); + + let active_thread_subscription = + cx.subscribe(&this.thread, |_, _, event, cx| match &event { + ActiveThreadEvent::EditingMessageTokenCountChanged => { + cx.notify(); + } + }); + this.message_editor = cx.new(|cx| { MessageEditor::new( this.fs.clone(), @@ -561,6 +613,19 @@ impl AssistantPanel { ) }); this.message_editor.focus_handle(cx).focus(window); + + let message_editor_subscription = + cx.subscribe(&this.message_editor, |_, _, event, cx| match event { + MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => { + cx.notify(); + } + }); + + this._active_thread_subscriptions = vec![ + thread_subscription, + active_thread_subscription, + message_editor_subscription, + ]; }) }) } @@ -853,7 +918,7 @@ impl Panel for AssistantPanel { } impl AssistantPanel { - fn render_title_view(&self, _window: &mut Window, cx: &mut Context) -> AnyElement { + fn render_title_view(&self, _window: &mut Window, cx: &Context) -> AnyElement { const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…"; let content = match &self.active_view { @@ -913,13 +978,8 @@ impl AssistantPanel { fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let active_thread = self.thread.read(cx); let thread = active_thread.thread().read(cx); - let token_usage = thread.total_token_usage(cx); let thread_id = thread.id().clone(); - - let is_generating = thread.is_generating(); let is_empty = active_thread.is_empty(); - let focus_handle = self.focus_handle(cx); - let is_history = matches!(self.active_view, ActiveView::History); let show_token_count = match &self.active_view { @@ -928,6 +988,8 @@ impl AssistantPanel { _ => false, }; + let focus_handle = self.focus_handle(cx); + let go_back_button = match &self.active_view { ActiveView::History | ActiveView::Configuration => Some( div().pl_1().child( @@ -974,69 +1036,9 @@ impl AssistantPanel { h_flex() .h_full() .gap_2() - .when(show_token_count, |parent| match self.active_view { - ActiveView::Thread { .. } => { - if token_usage.total == 0 { - return parent; - } - - let token_color = match token_usage.ratio { - TokenUsageRatio::Normal => Color::Muted, - TokenUsageRatio::Warning => Color::Warning, - TokenUsageRatio::Exceeded => Color::Error, - }; - - parent.child( - h_flex() - .flex_shrink_0() - .gap_0p5() - .child( - Label::new(assistant_context_editor::humanize_token_count( - token_usage.total, - )) - .size(LabelSize::Small) - .color(token_color) - .map(|label| { - if is_generating { - label - .with_animation( - "used-tokens-label", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between( - 0.6, 1., - )), - |label, delta| label.alpha(delta), - ) - .into_any() - } else { - label.into_any_element() - } - }), - ) - .child( - Label::new("/").size(LabelSize::Small).color(Color::Muted), - ) - .child( - Label::new(assistant_context_editor::humanize_token_count( - token_usage.max, - )) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ) - } - ActiveView::PromptEditor => { - let Some(editor) = self.context_editor.as_ref() else { - return parent; - }; - let Some(element) = render_remaining_tokens(editor, cx) else { - return parent; - }; - parent.child(element) - } - _ => parent, - }) + .when(show_token_count, |parent| + parent.children(self.render_token_count(&thread, cx)) + ) .child( h_flex() .h_full() @@ -1132,6 +1134,111 @@ impl AssistantPanel { ) } + fn render_token_count(&self, thread: &Thread, cx: &App) -> Option { + let is_generating = thread.is_generating(); + let message_editor = self.message_editor.read(cx); + + let conversation_token_usage = thread.total_token_usage(cx); + let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) = + self.thread.read(cx).editing_message_id() + { + let combined = thread + .token_usage_up_to_message(editing_message_id, cx) + .add(unsent_tokens); + + (combined, unsent_tokens > 0) + } else { + let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0); + let combined = conversation_token_usage.add(unsent_tokens); + + (combined, unsent_tokens > 0) + }; + + let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count(); + + match self.active_view { + ActiveView::Thread { .. } => { + if total_token_usage.total == 0 { + return None; + } + + let token_color = match total_token_usage.ratio() { + TokenUsageRatio::Normal if is_estimating => Color::Default, + TokenUsageRatio::Normal => Color::Muted, + TokenUsageRatio::Warning => Color::Warning, + TokenUsageRatio::Exceeded => Color::Error, + }; + + let token_count = h_flex() + .id("token-count") + .flex_shrink_0() + .gap_0p5() + .when(!is_generating && is_estimating, |parent| { + parent + .child( + h_flex() + .mr_0p5() + .size_2() + .justify_center() + .rounded_full() + .bg(cx.theme().colors().text.opacity(0.1)) + .child( + div().size_1().rounded_full().bg(cx.theme().colors().text), + ), + ) + .tooltip(move |window, cx| { + Tooltip::with_meta( + "Estimated New Token Count", + None, + format!( + "Current Conversation Tokens: {}", + humanize_token_count(conversation_token_usage.total) + ), + window, + cx, + ) + }) + }) + .child( + Label::new(humanize_token_count(total_token_usage.total)) + .size(LabelSize::Small) + .color(token_color) + .map(|label| { + if is_generating || is_waiting_to_update_token_count { + label + .with_animation( + "used-tokens-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.6, 1.)), + |label, delta| label.alpha(delta), + ) + .into_any() + } else { + label.into_any_element() + } + }), + ) + .child(Label::new("/").size(LabelSize::Small).color(Color::Muted)) + .child( + Label::new(humanize_token_count(total_token_usage.max)) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any(); + + Some(token_count) + } + ActiveView::PromptEditor => { + let editor = self.context_editor.as_ref()?; + let element = render_remaining_tokens(editor, cx)?; + + Some(element.into_any_element()) + } + _ => None, + } + } + fn render_active_thread_or_empty_state( &self, window: &mut Window, diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 59c01d16dd373d55caadd2373874959d3fafdca1..df998b510261169a4b0956d178e101b93119d6a2 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -2,22 +2,23 @@ use std::collections::BTreeMap; use std::sync::Arc; use crate::assistant_model_selector::ModelType; +use crate::context::format_context_as_string; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use buffer_diff::BufferDiff; use collections::HashSet; use editor::actions::MoveUp; use editor::{ - ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, EditorStyle, - MultiBuffer, + ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode, + EditorStyle, MultiBuffer, }; use file_icons::FileIcons; use fs::Fs; use gpui::{ - Animation, AnimationExt, App, Entity, Focusable, Subscription, TextStyle, WeakEntity, - linear_color_stop, linear_gradient, point, pulsating_between, + Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, + WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, }; use language::{Buffer, Language}; -use language_model::{ConfiguredModel, LanguageModelRegistry}; +use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage}; use language_model_selector::ToggleModelSelector; use multi_buffer; use project::Project; @@ -55,6 +56,8 @@ pub struct MessageEditor { edits_expanded: bool, editor_is_expanded: bool, waiting_for_summaries_to_send: bool, + last_estimated_token_count: Option, + update_token_count_task: Option>>, _subscriptions: Vec, } @@ -129,8 +132,18 @@ impl MessageEditor { let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx)); - let subscriptions = - vec![cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event)]; + let subscriptions = vec![ + cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), + cx.subscribe(&editor, |this, _, event, cx| match event { + EditorEvent::BufferEdited => { + this.message_or_context_changed(true, cx); + } + _ => {} + }), + cx.observe(&context_store, |this, _, cx| { + this.message_or_context_changed(false, cx); + }), + ]; Self { editor: editor.clone(), @@ -156,6 +169,8 @@ impl MessageEditor { waiting_for_summaries_to_send: false, profile_selector: cx .new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)), + last_estimated_token_count: None, + update_token_count_task: None, _subscriptions: subscriptions, } } @@ -256,6 +271,9 @@ impl MessageEditor { text }); + self.last_estimated_token_count.take(); + cx.emit(MessageEditorEvent::EstimatedTokenCount); + let refresh_task = refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx); @@ -937,6 +955,80 @@ impl MessageEditor { .label_size(LabelSize::Small), ) } + + pub fn last_estimated_token_count(&self) -> Option { + self.last_estimated_token_count + } + + pub fn is_waiting_to_update_token_count(&self) -> bool { + self.update_token_count_task.is_some() + } + + fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context) { + cx.emit(MessageEditorEvent::Changed); + self.update_token_count_task.take(); + + let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else { + self.last_estimated_token_count.take(); + return; + }; + + let context_store = self.context_store.clone(); + let editor = self.editor.clone(); + let thread = self.thread.clone(); + + self.update_token_count_task = Some(cx.spawn(async move |this, cx| { + if debounce { + cx.background_executor() + .timer(Duration::from_millis(200)) + .await; + } + + let token_count = if let Some(task) = cx.update(|cx| { + let context = context_store.read(cx).context().iter(); + let new_context = thread.read(cx).filter_new_context(context); + let context_text = + format_context_as_string(new_context, cx).unwrap_or(String::new()); + let message_text = editor.read(cx).text(cx); + + let content = context_text + &message_text; + + if content.is_empty() { + return None; + } + + let request = language_model::LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: language_model::Role::User, + content: vec![content.into()], + cache: false, + }], + tools: vec![], + stop: vec![], + temperature: None, + }; + + Some(default_model.model.count_tokens(request, cx)) + })? { + task.await? + } else { + 0 + }; + + this.update(cx, |this, cx| { + this.last_estimated_token_count = Some(token_count); + cx.emit(MessageEditorEvent::EstimatedTokenCount); + this.update_token_count_task.take(); + }) + })); + } +} + +impl EventEmitter for MessageEditor {} + +pub enum MessageEditorEvent { + EstimatedTokenCount, + Changed, } impl Focusable for MessageEditor { @@ -949,6 +1041,7 @@ impl Render for MessageEditor { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let thread = self.thread.read(cx); let total_token_usage = thread.total_token_usage(cx); + let token_usage_ratio = total_token_usage.ratio(); let action_log = self.thread.read(cx).action_log(); let changed_buffers = action_log.read(cx).changed_buffers(cx); @@ -997,15 +1090,8 @@ impl Render for MessageEditor { parent.child(self.render_changed_buffers(&changed_buffers, window, cx)) }) .child(self.render_editor(font_size, line_height, window, cx)) - .when( - total_token_usage.ratio != TokenUsageRatio::Normal, - |parent| { - parent.child(self.render_token_limit_callout( - line_height, - total_token_usage.ratio, - cx, - )) - }, - ) + .when(token_usage_ratio != TokenUsageRatio::Normal, |parent| { + parent.child(self.render_token_limit_callout(line_height, token_usage_ratio, cx)) + }) } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 694b212e31536abf38a7c57639bd8797afca27de..a0d6f99ea0bc2da3a9b373909a0002e75d5669ce 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -227,7 +227,33 @@ pub enum DetailedSummaryState { pub struct TotalTokenUsage { pub total: usize, pub max: usize, - pub ratio: TokenUsageRatio, +} + +impl TotalTokenUsage { + pub fn ratio(&self) -> TokenUsageRatio { + #[cfg(debug_assertions)] + let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") + .unwrap_or("0.8".to_string()) + .parse() + .unwrap(); + #[cfg(not(debug_assertions))] + let warning_threshold: f32 = 0.8; + + if self.total >= self.max { + TokenUsageRatio::Exceeded + } else if self.total as f32 / self.max as f32 >= warning_threshold { + TokenUsageRatio::Warning + } else { + TokenUsageRatio::Normal + } + } + + pub fn add(&self, tokens: usize) -> TotalTokenUsage { + TotalTokenUsage { + total: self.total + tokens, + max: self.max, + } + } } #[derive(Debug, Default, PartialEq, Eq)] @@ -261,6 +287,7 @@ pub struct Thread { last_restore_checkpoint: Option, pending_checkpoint: Option, initial_project_snapshot: Shared>>>, + request_token_usage: Vec, cumulative_token_usage: TokenUsage, exceeded_window_error: Option, feedback: Option, @@ -311,6 +338,7 @@ impl Thread { .spawn(async move { Some(project_snapshot.await) }) .shared() }, + request_token_usage: Vec::new(), cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, feedback: None, @@ -378,6 +406,7 @@ impl Thread { tool_use, action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), + request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, feedback: None, @@ -643,6 +672,18 @@ impl Thread { self.tool_use.message_has_tool_results(message_id) } + /// Filter out contexts that have already been included in previous messages + pub fn filter_new_context<'a>( + &self, + context: impl Iterator, + ) -> impl Iterator { + context.filter(|ctx| self.is_context_new(ctx)) + } + + fn is_context_new(&self, context: &AssistantContext) -> bool { + !self.context.contains_key(&context.id()) + } + pub fn insert_user_message( &mut self, text: impl Into, @@ -654,10 +695,9 @@ impl Thread { let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx); - // Filter out contexts that have already been included in previous messages let new_context: Vec<_> = context .into_iter() - .filter(|ctx| !self.context.contains_key(&ctx.id())) + .filter(|ctx| self.is_context_new(ctx)) .collect(); if !new_context.is_empty() { @@ -837,6 +877,7 @@ impl Thread { .collect(), initial_project_snapshot, cumulative_token_usage: this.cumulative_token_usage, + request_token_usage: this.request_token_usage.clone(), detailed_summary_state: this.detailed_summary_state.clone(), exceeded_window_error: this.exceeded_window_error.clone(), }) @@ -1022,7 +1063,6 @@ impl Thread { cx: &mut Context, ) { let pending_completion_id = post_inc(&mut self.completion_count); - let task = cx.spawn(async move |thread, cx| { let stream = model.stream_completion(request, &cx); let initial_token_usage = @@ -1048,6 +1088,7 @@ impl Thread { stop_reason = reason; } LanguageModelCompletionEvent::UsageUpdate(token_usage) => { + thread.update_token_usage_at_last_message(token_usage); thread.cumulative_token_usage = thread.cumulative_token_usage + token_usage - current_token_usage; @@ -1889,6 +1930,35 @@ impl Thread { self.cumulative_token_usage } + pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage { + let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { + return TotalTokenUsage::default(); + }; + + let max = model.model.max_token_count(); + + let index = self + .messages + .iter() + .position(|msg| msg.id == message_id) + .unwrap_or(0); + + if index == 0 { + return TotalTokenUsage { total: 0, max }; + } + + let token_usage = &self + .request_token_usage + .get(index - 1) + .cloned() + .unwrap_or_default(); + + TotalTokenUsage { + total: token_usage.total_tokens() as usize, + max, + } + } + pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage { let model_registry = LanguageModelRegistry::read_global(cx); let Some(model) = model_registry.default_model() else { @@ -1902,30 +1972,33 @@ impl Thread { return TotalTokenUsage { total: exceeded_error.token_count, max, - ratio: TokenUsageRatio::Exceeded, }; } } - #[cfg(debug_assertions)] - let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") - .unwrap_or("0.8".to_string()) - .parse() - .unwrap(); - #[cfg(not(debug_assertions))] - let warning_threshold: f32 = 0.8; + let total = self + .token_usage_at_last_message() + .unwrap_or_default() + .total_tokens() as usize; - let total = self.cumulative_token_usage.total_tokens() as usize; + TotalTokenUsage { total, max } + } - let ratio = if total >= max { - TokenUsageRatio::Exceeded - } else if total as f32 / max as f32 >= warning_threshold { - TokenUsageRatio::Warning - } else { - TokenUsageRatio::Normal - }; + fn token_usage_at_last_message(&self) -> Option { + self.request_token_usage + .get(self.messages.len().saturating_sub(1)) + .or_else(|| self.request_token_usage.last()) + .cloned() + } + + fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) { + let placeholder = self.token_usage_at_last_message().unwrap_or_default(); + self.request_token_usage + .resize(self.messages.len(), placeholder); - TotalTokenUsage { total, max, ratio } + if let Some(last) = self.request_token_usage.last_mut() { + *last = token_usage; + } } pub fn deny_tool_use( diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 6fb0f6c7a253bdbc808d5209005d598ee0fcf99b..a72313061c6a4073485d5bd6e9cbecb6abb0f0b2 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -509,6 +509,8 @@ pub struct SerializedThread { #[serde(default)] pub cumulative_token_usage: TokenUsage, #[serde(default)] + pub request_token_usage: Vec, + #[serde(default)] pub detailed_summary_state: DetailedSummaryState, #[serde(default)] pub exceeded_window_error: Option, @@ -597,6 +599,7 @@ impl LegacySerializedThread { messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(), initial_project_snapshot: self.initial_project_snapshot, cumulative_token_usage: TokenUsage::default(), + request_token_usage: Vec::new(), detailed_summary_state: DetailedSummaryState::default(), exceeded_window_error: None, } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index cf695023c8b2e32b826f2be917a8d4c2c841e7eb..35bf5d60940e8ce6dbe68cda15db8ef96d7b5aa2 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -97,7 +97,10 @@ pub struct TokenUsage { impl TokenUsage { pub fn total_tokens(&self) -> u32 { - self.input_tokens + self.output_tokens + self.input_tokens + + self.output_tokens + + self.cache_read_input_tokens + + self.cache_creation_input_tokens } } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 7746d214b4f4893bc27e1c0cb438e46c00f9dd90..ee0f941afacfc1c2bf432e985521590e4265e8a1 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -705,12 +705,12 @@ pub fn map_to_language_model_completion_events( update_usage(&mut state.usage, &message.usage); return Some(( vec![ - Ok(LanguageModelCompletionEvent::StartMessage { - message_id: message.id, - }), Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( &state.usage, ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), ], state, )); diff --git a/crates/ui/src/components/tab.rs b/crates/ui/src/components/tab.rs index b9b4cb43ce550254e0732942688e821b8c7074c2..8b7d5bbdd4b72f500e29851c1bc4dd5aee4afcfc 100644 --- a/crates/ui/src/components/tab.rs +++ b/crates/ui/src/components/tab.rs @@ -73,11 +73,11 @@ impl Tab { self } - pub fn content_height(cx: &mut App) -> Pixels { + pub fn content_height(cx: &App) -> Pixels { DynamicSpacing::Base32.px(cx) - px(1.) } - pub fn container_height(cx: &mut App) -> Pixels { + pub fn container_height(cx: &App) -> Pixels { DynamicSpacing::Base32.px(cx) } }