From 74e6e378e37493440cb01c956ef6de7ebcde38e1 Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 12 May 2026 16:27:01 -0700 Subject: [PATCH] fix(agent): harden fallback usage accounting --- internal/agent/agent.go | 27 +++++++--- internal/agent/usage_fallback_test.go | 75 ++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 9 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index fcedcc55615db8afd29f109231cae627b67ac98f..80d8ae64a359fec27b0b712fc5ecef5514d579b4 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -760,7 +760,9 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan // Just in case, get just the last usage info. usage := resp.Response.Usage currentSession.SummaryMessageID = summaryMessage.ID - currentSession.CompletionTokens = usage.OutputTokens + if completionTokens := summaryCompletionTokens(usage, summaryMessage); completionTokens != 0 { + currentSession.CompletionTokens = completionTokens + } currentSession.PromptTokens = 0 _, err = a.sessions.Save(genCtx, currentSession) if err != nil { @@ -1145,11 +1147,9 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) + modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens) - eventCost := cost - if estimated { - eventCost = 0 + if !estimated { + a.eventTokensUsed(session.ID, model, usage, cost) } - a.eventTokensUsed(session.ID, model, usage, eventCost) if estimated { cost = 0 @@ -1166,10 +1166,23 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, } session.Cost += cost - if !usageIsZero(usage) { + updateSessionTokenCounters(session, usage) +} + +func updateSessionTokenCounters(session *session.Session, usage fantasy.Usage) { + if usage.OutputTokens != 0 { session.CompletionTokens = usage.OutputTokens - session.PromptTokens = usage.InputTokens + usage.CacheReadTokens } + if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 { + session.PromptTokens = promptTokens + } +} + +func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message) int64 { + if usage.OutputTokens != 0 { + return usage.OutputTokens + } + return approxTokenCount(summaryMessage.Content().Text) + approxTokenCount(summaryMessage.ReasoningContent().String()) } func (a *sessionAgent) Cancel(sessionID string) { diff --git a/internal/agent/usage_fallback_test.go b/internal/agent/usage_fallback_test.go index ec925c4148da922edec666a5d8bfec350c9295bf..1f2442aef7d4a1623419f70c81cd494ccc88ef69 100644 --- a/internal/agent/usage_fallback_test.go +++ b/internal/agent/usage_fallback_test.go @@ -6,6 +6,7 @@ import ( "charm.land/catwalk/pkg/catwalk" "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/session" "github.com/stretchr/testify/require" ) @@ -232,7 +233,7 @@ func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) { require.Equal(t, int64(456), currentSession.CompletionTokens) } -func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) { +func TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage(t *testing.T) { t.Parallel() agent := &sessionAgent{} @@ -247,7 +248,77 @@ func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) { agent.updateSessionUsage(model, currentSession, usage, nil, false) require.Equal(t, int64(789), currentSession.PromptTokens) - require.Zero(t, currentSession.CompletionTokens) + require.Equal(t, int64(456), currentSession.CompletionTokens) +} + +func TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage(t *testing.T) { + t.Parallel() + + agent := &sessionAgent{} + currentSession := &session.Session{ + ID: "session-id", + PromptTokens: 123, + CompletionTokens: 456, + } + model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}} + usage := fantasy.Usage{TotalTokens: 100} + + agent.updateSessionUsage(model, currentSession, usage, nil, false) + + require.Equal(t, int64(123), currentSession.PromptTokens) + require.Equal(t, int64(456), currentSession.CompletionTokens) +} + +func TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage(t *testing.T) { + t.Parallel() + + agent := &sessionAgent{} + currentSession := &session.Session{ + ID: "session-id", + PromptTokens: 123, + CompletionTokens: 456, + } + model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}} + usage := fantasy.Usage{OutputTokens: 50} + + agent.updateSessionUsage(model, currentSession, usage, nil, false) + + require.Equal(t, int64(123), currentSession.PromptTokens) + require.Equal(t, int64(50), currentSession.CompletionTokens) +} + +func TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage(t *testing.T) { + t.Parallel() + + agent := &sessionAgent{} + currentSession := &session.Session{ + ID: "session-id", + PromptTokens: 123, + CompletionTokens: 456, + Cost: 1.25, + } + model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}} + + agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, true) + + require.Equal(t, 1.25, currentSession.Cost) + require.Equal(t, int64(123), currentSession.PromptTokens) + require.Equal(t, int64(456), currentSession.CompletionTokens) +} + +func TestSummaryCompletionTokens(t *testing.T) { + t.Parallel() + + summaryMessage := message.Message{ + Parts: []message.ContentPart{ + message.TextContent{Text: "summary text"}, + message.ReasoningContent{Thinking: "reasoning text"}, + }, + } + + require.Equal(t, int64(42), summaryCompletionTokens(fantasy.Usage{OutputTokens: 42}, summaryMessage)) + require.Equal(t, approxTokenCount("summary text")+approxTokenCount("reasoning text"), summaryCompletionTokens(fantasy.Usage{}, summaryMessage)) + require.Zero(t, summaryCompletionTokens(fantasy.Usage{}, message.Message{})) } func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {