Detailed changes
@@ -1,4 +1,4 @@
-use crate::context::{AssistantContext, ContextId};
+use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
@@ -13,16 +13,18 @@ use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
-use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
+use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
- DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState,
- MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task,
- TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
+ DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment,
+ ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription,
+ Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
-use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
+use language_model::{
+ LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, Role, StopReason,
+};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use project::ProjectItem as _;
@@ -682,6 +684,9 @@ fn open_markdown_link(
+ /// State kept while a message of the thread is being edited in place.
struct EditMessageState {
+ /// Editor whose text is read as the edited message's content when estimating tokens.
editor: Entity<Editor>,
+ /// Latest token estimate for the edited message. `None` until an estimate has been
+ /// stored, and reset to `None` when no default model is available.
+ last_estimated_token_count: Option<usize>,
+ /// Subscription to the editor's events, created together with this state.
+ _subscription: Subscription,
+ /// Task that computes the token estimate; replaced on every new estimate request.
+ _update_token_count_task: Option<Task<anyhow::Result<()>>>,
}
impl ActiveThread {
@@ -781,6 +786,13 @@ impl ActiveThread {
self.last_error.take();
}
+ /// Reports which message is currently being edited, paired with the estimated token
+ /// count of its content.
+ ///
+ /// Yields `None` when no message is being edited. An estimate that has not been stored
+ /// yet is reported as `0`.
+ pub fn editing_message_id(&self) -> Option<(MessageId, usize)> {
+     let (message_id, state) = self.editing_message.as_ref()?;
+     let estimated_tokens = state.last_estimated_token_count.unwrap_or(0);
+     Some((*message_id, estimated_tokens))
+ }
+
fn push_message(
&mut self,
id: &MessageId,
@@ -1126,15 +1138,91 @@ impl ActiveThread {
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
editor
});
+ let subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
+ EditorEvent::BufferEdited => {
+ this.update_editing_message_token_count(true, cx);
+ }
+ _ => {}
+ });
self.editing_message = Some((
message_id,
EditMessageState {
editor: editor.clone(),
+ last_estimated_token_count: None,
+ _subscription: subscription,
+ _update_token_count_task: None,
},
));
+ self.update_editing_message_token_count(false, cx);
cx.notify();
}
+ /// Refreshes the token estimate of the message that is currently being edited.
+ ///
+ /// Does nothing when no message is being edited. Otherwise an
+ /// `EditingMessageTokenCountChanged` event is emitted immediately, the previously stored
+ /// estimation task is discarded, and a new task is spawned that stores the fresh estimate
+ /// and emits the event again. With `debounce` set, the task waits 200ms before counting.
+ /// When no default model is configured the stored estimate is cleared and no task is
+ /// spawned.
+ fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) {
+     let Some((message_id, state)) = self.editing_message.as_mut() else {
+         return;
+     };
+
+     cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
+     // NOTE(review): dropping the old task presumably cancels an estimate still in
+     // flight; confirm against gpui's `Task` semantics.
+     state._update_token_count_task = None;
+
+     let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+         state.last_estimated_token_count = None;
+         return;
+     };
+
+     let message_id = *message_id;
+     let editor = state.editor.clone();
+     let thread = self.thread.clone();
+
+     let estimate_task = cx.spawn(async move |this, cx| {
+         if debounce {
+             let pause = Duration::from_millis(200);
+             cx.background_executor().timer(pause).await;
+         }
+
+         let pending_count = cx.update(|cx| {
+             let context = thread.read(cx).context_for_message(message_id);
+             let new_context = thread.read(cx).filter_new_context(context);
+
+             // The counted content is the not-yet-sent context followed by the edited text.
+             let mut content = format_context_as_string(new_context, cx).unwrap_or_default();
+             content.push_str(&editor.read(cx).text(cx));
+
+             if content.is_empty() {
+                 return None;
+             }
+
+             let user_message = LanguageModelRequestMessage {
+                 role: language_model::Role::User,
+                 content: vec![content.into()],
+                 cache: false,
+             };
+             let request = language_model::LanguageModelRequest {
+                 messages: vec![user_message],
+                 tools: vec![],
+                 stop: vec![],
+                 temperature: None,
+             };
+
+             Some(default_model.model.count_tokens(request, cx))
+         })?;
+
+         // Empty content needs no model round trip and counts as zero tokens.
+         let token_count = match pending_count {
+             Some(count) => count.await?,
+             None => 0,
+         };
+
+         this.update(cx, |this, cx| {
+             if let Some((_, state)) = this.editing_message.as_mut() {
+                 state.last_estimated_token_count = Some(token_count);
+                 cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
+             }
+         })
+     });
+
+     state._update_token_count_task = Some(estimate_task);
+ }
+
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take();
cx.notify();
@@ -1676,6 +1764,9 @@ impl ActiveThread {
"confirm-edit-message",
"Regenerate",
)
+ .disabled(
+ edit_message_editor.read(cx).is_empty(cx),
+ )
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
@@ -1738,8 +1829,16 @@ impl ActiveThread {
),
};
+ let after_editing_message = self
+ .editing_message
+ .as_ref()
+ .map_or(false, |(editing_message_id, _)| {
+ message_id > *editing_message_id
+ });
+
v_flex()
.w_full()
+ .when(after_editing_message, |parent| parent.opacity(0.2))
.when_some(checkpoint, |parent, checkpoint| {
let mut is_pending = false;
let mut error = None;
@@ -2965,6 +3064,12 @@ impl ActiveThread {
}
}
+/// Events emitted by `ActiveThread`.
+pub enum ActiveThreadEvent {
+ /// Emitted when a token-count update for the message being edited starts, and again
+ /// once a new estimate has been stored.
+ EditingMessageTokenCountChanged,
+}
+
+impl EventEmitter<ActiveThreadEvent> for ActiveThread {}
+
impl Render for ActiveThread {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
@@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::{Result, anyhow};
use assistant_context_editor::{
AssistantPanelDelegate, ConfigurationError, ContextEditor, SlashCommandCompletionProvider,
- make_lsp_adapter_delegate, render_remaining_tokens,
+ humanize_token_count, make_lsp_adapter_delegate, render_remaining_tokens,
};
use assistant_settings::{AssistantDockPosition, AssistantSettings};
use assistant_slash_command::SlashCommandWorkingSet;
@@ -37,10 +37,10 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use zed_actions::agent::OpenConfiguration;
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
-use crate::active_thread::ActiveThread;
+use crate::active_thread::{ActiveThread, ActiveThreadEvent};
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
use crate::history_store::{HistoryEntry, HistoryStore};
-use crate::message_editor::MessageEditor;
+use crate::message_editor::{MessageEditor, MessageEditorEvent};
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
@@ -181,8 +181,8 @@ pub struct AssistantPanel {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<ActiveThread>,
- _thread_subscription: Subscription,
message_editor: Entity<MessageEditor>,
+ _active_thread_subscriptions: Vec<Subscription>,
context_store: Entity<assistant_context_editor::ContextStore>,
context_editor: Option<Entity<ContextEditor>>,
configuration: Option<Entity<AssistantConfiguration>>,
@@ -264,6 +264,13 @@ impl AssistantPanel {
)
});
+ let message_editor_subscription =
+ cx.subscribe(&message_editor, |_, _, event, cx| match event {
+ MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
+ cx.notify();
+ }
+ });
+
let history_store =
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
@@ -288,6 +295,12 @@ impl AssistantPanel {
)
});
+ let active_thread_subscription = cx.subscribe(&thread, |_, _, event, cx| match &event {
+ ActiveThreadEvent::EditingMessageTokenCountChanged => {
+ cx.notify();
+ }
+ });
+
Self {
active_view,
workspace,
@@ -296,8 +309,12 @@ impl AssistantPanel {
language_registry,
thread_store: thread_store.clone(),
thread,
- _thread_subscription: thread_subscription,
message_editor,
+ _active_thread_subscriptions: vec![
+ thread_subscription,
+ active_thread_subscription,
+ message_editor_subscription,
+ ],
context_store,
context_editor: None,
configuration: None,
@@ -382,6 +399,13 @@ impl AssistantPanel {
.detach_and_log_err(cx);
}
+ let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
+ if let ThreadEvent::MessageAdded(_) = &event {
+ // needed to leave empty state
+ cx.notify();
+ }
+ });
+
self.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
@@ -394,12 +418,12 @@ impl AssistantPanel {
)
});
- self._thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
- if let ThreadEvent::MessageAdded(_) = &event {
- // needed to leave empty state
- cx.notify();
- }
- });
+ let active_thread_subscription =
+ cx.subscribe(&self.thread, |_, _, event, cx| match &event {
+ ActiveThreadEvent::EditingMessageTokenCountChanged => {
+ cx.notify();
+ }
+ });
self.message_editor = cx.new(|cx| {
MessageEditor::new(
@@ -413,6 +437,19 @@ impl AssistantPanel {
)
});
self.message_editor.focus_handle(cx).focus(window);
+
+ let message_editor_subscription =
+ cx.subscribe(&self.message_editor, |_, _, event, cx| match event {
+ MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
+ cx.notify();
+ }
+ });
+
+ self._active_thread_subscriptions = vec![
+ thread_subscription,
+ active_thread_subscription,
+ message_editor_subscription,
+ ];
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -538,6 +575,13 @@ impl AssistantPanel {
Some(this.thread_store.downgrade()),
)
});
+ let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
+ if let ThreadEvent::MessageAdded(_) = &event {
+ // needed to leave empty state
+ cx.notify();
+ }
+ });
+
this.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
@@ -549,6 +593,14 @@ impl AssistantPanel {
cx,
)
});
+
+ let active_thread_subscription =
+ cx.subscribe(&this.thread, |_, _, event, cx| match &event {
+ ActiveThreadEvent::EditingMessageTokenCountChanged => {
+ cx.notify();
+ }
+ });
+
this.message_editor = cx.new(|cx| {
MessageEditor::new(
this.fs.clone(),
@@ -561,6 +613,19 @@ impl AssistantPanel {
)
});
this.message_editor.focus_handle(cx).focus(window);
+
+ let message_editor_subscription =
+ cx.subscribe(&this.message_editor, |_, _, event, cx| match event {
+ MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
+ cx.notify();
+ }
+ });
+
+ this._active_thread_subscriptions = vec![
+ thread_subscription,
+ active_thread_subscription,
+ message_editor_subscription,
+ ];
})
})
}
@@ -853,7 +918,7 @@ impl Panel for AssistantPanel {
}
impl AssistantPanel {
- fn render_title_view(&self, _window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
+ fn render_title_view(&self, _window: &mut Window, cx: &Context<Self>) -> AnyElement {
const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…";
let content = match &self.active_view {
@@ -913,13 +978,8 @@ impl AssistantPanel {
fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let active_thread = self.thread.read(cx);
let thread = active_thread.thread().read(cx);
- let token_usage = thread.total_token_usage(cx);
let thread_id = thread.id().clone();
-
- let is_generating = thread.is_generating();
let is_empty = active_thread.is_empty();
- let focus_handle = self.focus_handle(cx);
-
let is_history = matches!(self.active_view, ActiveView::History);
let show_token_count = match &self.active_view {
@@ -928,6 +988,8 @@ impl AssistantPanel {
_ => false,
};
+ let focus_handle = self.focus_handle(cx);
+
let go_back_button = match &self.active_view {
ActiveView::History | ActiveView::Configuration => Some(
div().pl_1().child(
@@ -974,69 +1036,9 @@ impl AssistantPanel {
h_flex()
.h_full()
.gap_2()
- .when(show_token_count, |parent| match self.active_view {
- ActiveView::Thread { .. } => {
- if token_usage.total == 0 {
- return parent;
- }
-
- let token_color = match token_usage.ratio {
- TokenUsageRatio::Normal => Color::Muted,
- TokenUsageRatio::Warning => Color::Warning,
- TokenUsageRatio::Exceeded => Color::Error,
- };
-
- parent.child(
- h_flex()
- .flex_shrink_0()
- .gap_0p5()
- .child(
- Label::new(assistant_context_editor::humanize_token_count(
- token_usage.total,
- ))
- .size(LabelSize::Small)
- .color(token_color)
- .map(|label| {
- if is_generating {
- label
- .with_animation(
- "used-tokens-label",
- Animation::new(Duration::from_secs(2))
- .repeat()
- .with_easing(pulsating_between(
- 0.6, 1.,
- )),
- |label, delta| label.alpha(delta),
- )
- .into_any()
- } else {
- label.into_any_element()
- }
- }),
- )
- .child(
- Label::new("/").size(LabelSize::Small).color(Color::Muted),
- )
- .child(
- Label::new(assistant_context_editor::humanize_token_count(
- token_usage.max,
- ))
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
- )
- }
- ActiveView::PromptEditor => {
- let Some(editor) = self.context_editor.as_ref() else {
- return parent;
- };
- let Some(element) = render_remaining_tokens(editor, cx) else {
- return parent;
- };
- parent.child(element)
- }
- _ => parent,
- })
+ .when(show_token_count, |parent|
+ parent.children(self.render_token_count(&thread, cx))
+ )
.child(
h_flex()
.h_full()
@@ -1132,6 +1134,111 @@ impl AssistantPanel {
)
}
+ /// Builds the token-count indicator shown in the toolbar.
+ ///
+ /// For the thread view this is "used / max", where "used" includes the estimate for
+ /// unsent text; `None` is returned while the combined total is zero. For the prompt
+ /// editor view the element comes from `render_remaining_tokens`. Every other view
+ /// yields `None`.
+ fn render_token_count(&self, thread: &Thread, cx: &App) -> Option<AnyElement> {
+ let is_generating = thread.is_generating();
+ let message_editor = self.message_editor.read(cx);
+
+ let conversation_token_usage = thread.total_token_usage(cx);
+ // While a message is being edited, its estimate is added to the usage recorded up to
+ // that message; otherwise the message editor's estimate is added to the usage of the
+ // whole conversation.
+ let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
+ self.thread.read(cx).editing_message_id()
+ {
+ let combined = thread
+ .token_usage_up_to_message(editing_message_id, cx)
+ .add(unsent_tokens);
+
+ (combined, unsent_tokens > 0)
+ } else {
+ let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0);
+ let combined = conversation_token_usage.add(unsent_tokens);
+
+ (combined, unsent_tokens > 0)
+ };
+
+ let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count();
+
+ match self.active_view {
+ ActiveView::Thread { .. } => {
+ // Nothing to display until there is recorded usage or an estimate.
+ if total_token_usage.total == 0 {
+ return None;
+ }
+
+ // An estimate within normal limits uses the default color rather than the muted one.
+ let token_color = match total_token_usage.ratio() {
+ TokenUsageRatio::Normal if is_estimating => Color::Default,
+ TokenUsageRatio::Normal => Color::Muted,
+ TokenUsageRatio::Warning => Color::Warning,
+ TokenUsageRatio::Exceeded => Color::Error,
+ };
+
+ let token_count = h_flex()
+ .id("token-count")
+ .flex_shrink_0()
+ .gap_0p5()
+ // The dot marker and tooltip appear only when an estimate is included and no
+ // response is being generated.
+ .when(!is_generating && is_estimating, |parent| {
+ parent
+ .child(
+ h_flex()
+ .mr_0p5()
+ .size_2()
+ .justify_center()
+ .rounded_full()
+ .bg(cx.theme().colors().text.opacity(0.1))
+ .child(
+ div().size_1().rounded_full().bg(cx.theme().colors().text),
+ ),
+ )
+ .tooltip(move |window, cx| {
+ Tooltip::with_meta(
+ "Estimated New Token Count",
+ None,
+ format!(
+ "Current Conversation Tokens: {}",
+ humanize_token_count(conversation_token_usage.total)
+ ),
+ window,
+ cx,
+ )
+ })
+ })
+ .child(
+ Label::new(humanize_token_count(total_token_usage.total))
+ .size(LabelSize::Small)
+ .color(token_color)
+ .map(|label| {
+ // Pulse the label while a response is generating or while the message
+ // editor's estimate is being refreshed.
+ if is_generating || is_waiting_to_update_token_count {
+ label
+ .with_animation(
+ "used-tokens-label",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.6, 1.)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any()
+ } else {
+ label.into_any_element()
+ }
+ }),
+ )
+ .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
+ .child(
+ Label::new(humanize_token_count(total_token_usage.max))
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any();
+
+ Some(token_count)
+ }
+ ActiveView::PromptEditor => {
+ let editor = self.context_editor.as_ref()?;
+ let element = render_remaining_tokens(editor, cx)?;
+
+ Some(element.into_any_element())
+ }
+ _ => None,
+ }
+ }
+
fn render_active_thread_or_empty_state(
&self,
window: &mut Window,
@@ -2,22 +2,23 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
+use crate::context::format_context_as_string;
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{
- ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, EditorStyle,
- MultiBuffer,
+ ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode,
+ EditorStyle, MultiBuffer,
};
use file_icons::FileIcons;
use fs::Fs;
use gpui::{
- Animation, AnimationExt, App, Entity, Focusable, Subscription, TextStyle, WeakEntity,
- linear_color_stop, linear_gradient, point, pulsating_between,
+ Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
+ WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language};
-use language_model::{ConfiguredModel, LanguageModelRegistry};
+use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
@@ -55,6 +56,8 @@ pub struct MessageEditor {
edits_expanded: bool,
editor_is_expanded: bool,
waiting_for_summaries_to_send: bool,
+ last_estimated_token_count: Option<usize>,
+ update_token_count_task: Option<Task<anyhow::Result<()>>>,
_subscriptions: Vec<Subscription>,
}
@@ -129,8 +132,18 @@ impl MessageEditor {
let incompatible_tools =
cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx));
- let subscriptions =
- vec![cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event)];
+ let subscriptions = vec![
+ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
+ cx.subscribe(&editor, |this, _, event, cx| match event {
+ EditorEvent::BufferEdited => {
+ this.message_or_context_changed(true, cx);
+ }
+ _ => {}
+ }),
+ cx.observe(&context_store, |this, _, cx| {
+ this.message_or_context_changed(false, cx);
+ }),
+ ];
Self {
editor: editor.clone(),
@@ -156,6 +169,8 @@ impl MessageEditor {
waiting_for_summaries_to_send: false,
profile_selector: cx
.new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)),
+ last_estimated_token_count: None,
+ update_token_count_task: None,
_subscriptions: subscriptions,
}
}
@@ -256,6 +271,9 @@ impl MessageEditor {
text
});
+ self.last_estimated_token_count.take();
+ cx.emit(MessageEditorEvent::EstimatedTokenCount);
+
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
@@ -937,6 +955,80 @@ impl MessageEditor {
.label_size(LabelSize::Small),
)
}
+
+ /// Returns the most recently stored token estimate for the unsent message and its new
+ /// context, or `None` when no estimate is stored.
+ pub fn last_estimated_token_count(&self) -> Option<usize> {
+ self.last_estimated_token_count
+ }
+
+ /// Returns `true` while a token-estimate task is stored, i.e. an estimate has been
+ /// requested and its task has not cleared itself yet.
+ pub fn is_waiting_to_update_token_count(&self) -> bool {
+ self.update_token_count_task.is_some()
+ }
+
+ /// Reacts to a change of the message text or the attached context by refreshing the
+ /// token estimate for the unsent message.
+ ///
+ /// Emits `MessageEditorEvent::Changed`, discards the previously stored estimation task
+ /// and spawns a new one; with `debounce` set, the task waits 200ms before counting.
+ /// When no default model is configured, the stored estimate is cleared and no task is
+ /// spawned.
+ ///
+ /// The stored task doubles as the "estimate in progress" flag reported by
+ /// `is_waiting_to_update_token_count`, so the task clears it on success *and* on
+ /// failure. When counting fails the previous estimate is kept, and
+ /// `MessageEditorEvent::EstimatedTokenCount` is still emitted so subscribers re-render
+ /// and stop showing the in-progress state.
+ fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
+     cx.emit(MessageEditorEvent::Changed);
+     self.update_token_count_task.take();
+
+     let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+         self.last_estimated_token_count.take();
+         return;
+     };
+
+     let context_store = self.context_store.clone();
+     let editor = self.editor.clone();
+     let thread = self.thread.clone();
+
+     self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
+         if debounce {
+             cx.background_executor()
+                 .timer(Duration::from_millis(200))
+                 .await;
+         }
+
+         let pending_count = cx.update(|cx| {
+             let context = context_store.read(cx).context().iter();
+             let new_context = thread.read(cx).filter_new_context(context);
+             let context_text = format_context_as_string(new_context, cx).unwrap_or_default();
+             let message_text = editor.read(cx).text(cx);
+
+             let content = context_text + &message_text;
+
+             if content.is_empty() {
+                 return None;
+             }
+
+             let request = language_model::LanguageModelRequest {
+                 messages: vec![LanguageModelRequestMessage {
+                     role: language_model::Role::User,
+                     content: vec![content.into()],
+                     cache: false,
+                 }],
+                 tools: vec![],
+                 stop: vec![],
+                 temperature: None,
+             };
+
+             Some(default_model.model.count_tokens(request, cx))
+         })?;
+
+         // Hold the outcome as a `Result` instead of returning early with `?`: an early
+         // return would leave `update_token_count_task` set and the editor would report
+         // a pending estimate until the next edit.
+         let count_result = match pending_count {
+             Some(task) => task.await,
+             None => Ok(0),
+         };
+
+         this.update(cx, |this, cx| {
+             this.update_token_count_task.take();
+             if let Ok(token_count) = &count_result {
+                 this.last_estimated_token_count = Some(*token_count);
+             }
+             cx.emit(MessageEditorEvent::EstimatedTokenCount);
+         })?;
+
+         count_result.map(|_| ())
+     }));
+ }
+}
+
+impl EventEmitter<MessageEditorEvent> for MessageEditor {}
+
+/// Events emitted by `MessageEditor`.
+pub enum MessageEditorEvent {
+ /// Emitted when the stored token estimate has been updated or cleared.
+ EstimatedTokenCount,
+ /// Emitted when the message text or the context changed, before a new estimate is
+ /// computed.
+ Changed,
}
impl Focusable for MessageEditor {
@@ -949,6 +1041,7 @@ impl Render for MessageEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx);
let total_token_usage = thread.total_token_usage(cx);
+ let token_usage_ratio = total_token_usage.ratio();
let action_log = self.thread.read(cx).action_log();
let changed_buffers = action_log.read(cx).changed_buffers(cx);
@@ -997,15 +1090,8 @@ impl Render for MessageEditor {
parent.child(self.render_changed_buffers(&changed_buffers, window, cx))
})
.child(self.render_editor(font_size, line_height, window, cx))
- .when(
- total_token_usage.ratio != TokenUsageRatio::Normal,
- |parent| {
- parent.child(self.render_token_limit_callout(
- line_height,
- total_token_usage.ratio,
- cx,
- ))
- },
- )
+ .when(token_usage_ratio != TokenUsageRatio::Normal, |parent| {
+ parent.child(self.render_token_limit_callout(line_height, token_usage_ratio, cx))
+ })
}
}
@@ -227,7 +227,33 @@ pub enum DetailedSummaryState {
pub struct TotalTokenUsage {
pub total: usize,
pub max: usize,
- pub ratio: TokenUsageRatio,
+}
+
+impl TotalTokenUsage {
+    /// Share of `max` at which usage starts to be reported as `TokenUsageRatio::Warning`.
+    const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;
+
+    /// Classifies `total` relative to `max`.
+    ///
+    /// Returns `Exceeded` when `total >= max` (this also covers `max == 0`, so the
+    /// division below never sees a zero denominator), `Warning` when the used share
+    /// reaches the warning threshold, and `Normal` otherwise.
+    ///
+    /// In debug builds the threshold can be overridden through the
+    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A missing or unparsable value
+    /// falls back to the default instead of panicking.
+    pub fn ratio(&self) -> TokenUsageRatio {
+        #[cfg(debug_assertions)]
+        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
+            .ok()
+            .and_then(|value| value.parse().ok())
+            .unwrap_or(Self::DEFAULT_WARNING_THRESHOLD);
+        #[cfg(not(debug_assertions))]
+        let warning_threshold: f32 = Self::DEFAULT_WARNING_THRESHOLD;
+
+        if self.total >= self.max {
+            TokenUsageRatio::Exceeded
+        } else if self.total as f32 / self.max as f32 >= warning_threshold {
+            TokenUsageRatio::Warning
+        } else {
+            TokenUsageRatio::Normal
+        }
+    }
+
+    /// Returns a copy of this usage with `tokens` added to `total`; `max` is unchanged.
+    ///
+    /// The addition saturates at `usize::MAX` rather than overflowing.
+    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
+        TotalTokenUsage {
+            total: self.total.saturating_add(tokens),
+            max: self.max,
+        }
+    }
}
#[derive(Debug, Default, PartialEq, Eq)]
@@ -261,6 +287,7 @@ pub struct Thread {
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
pending_checkpoint: Option<ThreadCheckpoint>,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
+ request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
feedback: Option<ThreadFeedback>,
@@ -311,6 +338,7 @@ impl Thread {
.spawn(async move { Some(project_snapshot.await) })
.shared()
},
+ request_token_usage: Vec::new(),
cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None,
feedback: None,
@@ -378,6 +406,7 @@ impl Thread {
tool_use,
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
+ request_token_usage: serialized.request_token_usage,
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
feedback: None,
@@ -643,6 +672,18 @@ impl Thread {
self.tool_use.message_has_tool_results(message_id)
}
+ /// Filters out contexts that have already been included in previous messages.
+ ///
+ /// Lazily yields only the items of `context` whose id is not yet a key of
+ /// `self.context`.
+ pub fn filter_new_context<'a>(
+ &self,
+ context: impl Iterator<Item = &'a AssistantContext>,
+ ) -> impl Iterator<Item = &'a AssistantContext> {
+ context.filter(|ctx| self.is_context_new(ctx))
+ }
+
+ /// Returns `true` when the id of `context` is not yet a key of `self.context`.
+ fn is_context_new(&self, context: &AssistantContext) -> bool {
+ !self.context.contains_key(&context.id())
+ }
+
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
@@ -654,10 +695,9 @@ impl Thread {
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
- // Filter out contexts that have already been included in previous messages
let new_context: Vec<_> = context
.into_iter()
- .filter(|ctx| !self.context.contains_key(&ctx.id()))
+ .filter(|ctx| self.is_context_new(ctx))
.collect();
if !new_context.is_empty() {
@@ -837,6 +877,7 @@ impl Thread {
.collect(),
initial_project_snapshot,
cumulative_token_usage: this.cumulative_token_usage,
+ request_token_usage: this.request_token_usage.clone(),
detailed_summary_state: this.detailed_summary_state.clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
})
@@ -1022,7 +1063,6 @@ impl Thread {
cx: &mut Context<Self>,
) {
let pending_completion_id = post_inc(&mut self.completion_count);
-
let task = cx.spawn(async move |thread, cx| {
let stream = model.stream_completion(request, &cx);
let initial_token_usage =
@@ -1048,6 +1088,7 @@ impl Thread {
stop_reason = reason;
}
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
+ thread.update_token_usage_at_last_message(token_usage);
thread.cumulative_token_usage = thread.cumulative_token_usage
+ token_usage
- current_token_usage;
@@ -1889,6 +1930,35 @@ impl Thread {
self.cumulative_token_usage
}
+ /// Computes the token usage recorded just before the message with `message_id`.
+ ///
+ /// `max` is the default model's maximum token count; without a default model the
+ /// default `TotalTokenUsage` is returned. `total` is the usage stored for the message
+ /// preceding `message_id`, or `0` when `message_id` is the first message, is not found,
+ /// or has no stored usage before it.
+ pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
+     let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+         return TotalTokenUsage::default();
+     };
+     let max = model.model.max_token_count();
+
+     let position = self.messages.iter().position(|msg| msg.id == message_id);
+     let total = match position {
+         // An unknown id is handled like the first message: nothing precedes it.
+         None | Some(0) => 0,
+         Some(index) => {
+             let usage_before = self
+                 .request_token_usage
+                 .get(index - 1)
+                 .cloned()
+                 .unwrap_or_default();
+             usage_before.total_tokens() as usize
+         }
+     };
+
+     TotalTokenUsage { total, max }
+ }
+
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.default_model() else {
@@ -1902,30 +1972,33 @@ impl Thread {
return TotalTokenUsage {
total: exceeded_error.token_count,
max,
- ratio: TokenUsageRatio::Exceeded,
};
}
}
- #[cfg(debug_assertions)]
- let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
- .unwrap_or("0.8".to_string())
- .parse()
- .unwrap();
- #[cfg(not(debug_assertions))]
- let warning_threshold: f32 = 0.8;
+ let total = self
+ .token_usage_at_last_message()
+ .unwrap_or_default()
+ .total_tokens() as usize;
- let total = self.cumulative_token_usage.total_tokens() as usize;
+ TotalTokenUsage { total, max }
+ }
- let ratio = if total >= max {
- TokenUsageRatio::Exceeded
- } else if total as f32 / max as f32 >= warning_threshold {
- TokenUsageRatio::Warning
- } else {
- TokenUsageRatio::Normal
- };
+ /// Looks up the token usage recorded for the most recent message.
+ ///
+ /// Reads the entry at index `messages.len() - 1` (index `0` when there are no
+ /// messages). If that entry does not exist, the last recorded entry is used instead.
+ /// Returns `None` when nothing has been recorded.
+ fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
+     let last_message_index = self.messages.len().saturating_sub(1);
+     let recorded = &self.request_token_usage;
+     recorded
+         .get(last_message_index)
+         .or(recorded.last())
+         .cloned()
+ }
+
+ /// Stores `token_usage` as the usage for the last message.
+ ///
+ /// `request_token_usage` is first resized to `messages.len()`; slots created by the
+ /// resize are filled with the previously recorded last usage (or the default usage).
+ /// The final entry is then overwritten. Nothing is written when there are no messages.
+ fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
+ let placeholder = self.token_usage_at_last_message().unwrap_or_default();
+ self.request_token_usage
+ .resize(self.messages.len(), placeholder);
- TotalTokenUsage { total, max, ratio }
+ if let Some(last) = self.request_token_usage.last_mut() {
+ *last = token_usage;
+ }
}
pub fn deny_tool_use(
@@ -509,6 +509,8 @@ pub struct SerializedThread {
#[serde(default)]
pub cumulative_token_usage: TokenUsage,
#[serde(default)]
+ pub request_token_usage: Vec<TokenUsage>,
+ #[serde(default)]
pub detailed_summary_state: DetailedSummaryState,
#[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>,
@@ -597,6 +599,7 @@ impl LegacySerializedThread {
messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
initial_project_snapshot: self.initial_project_snapshot,
cumulative_token_usage: TokenUsage::default(),
+ request_token_usage: Vec::new(),
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
}
@@ -97,7 +97,10 @@ pub struct TokenUsage {
impl TokenUsage {
+ /// Returns the sum of input, output, cache-read input and cache-creation input tokens.
pub fn total_tokens(&self) -> u32 {
- self.input_tokens + self.output_tokens
+ self.input_tokens
+ + self.output_tokens
+ + self.cache_read_input_tokens
+ + self.cache_creation_input_tokens
}
}
@@ -705,12 +705,12 @@ pub fn map_to_language_model_completion_events(
update_usage(&mut state.usage, &message.usage);
return Some((
vec![
- Ok(LanguageModelCompletionEvent::StartMessage {
- message_id: message.id,
- }),
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
&state.usage,
))),
+ Ok(LanguageModelCompletionEvent::StartMessage {
+ message_id: message.id,
+ }),
],
state,
));
@@ -73,11 +73,11 @@ impl Tab {
self
}
+ /// Returns the `Base32` dynamic spacing in pixels, minus one pixel.
- pub fn content_height(cx: &mut App) -> Pixels {
+ pub fn content_height(cx: &App) -> Pixels {
DynamicSpacing::Base32.px(cx) - px(1.)
}
}
+ /// Returns the `Base32` dynamic spacing in pixels.
- pub fn container_height(cx: &mut App) -> Pixels {
+ pub fn container_height(cx: &App) -> Pixels {
DynamicSpacing::Base32.px(cx)
}
}
}