fix(agent): harden fallback usage accounting

Greg Slepak created

Change summary

internal/agent/agent.go               | 27 +++++++--
internal/agent/usage_fallback_test.go | 75 ++++++++++++++++++++++++++++
2 files changed, 93 insertions(+), 9 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -760,7 +760,9 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 	// Just in case, get just the last usage info.
 	usage := resp.Response.Usage
 	currentSession.SummaryMessageID = summaryMessage.ID
-	currentSession.CompletionTokens = usage.OutputTokens
+	if completionTokens := summaryCompletionTokens(usage, summaryMessage); completionTokens != 0 {
+		currentSession.CompletionTokens = completionTokens
+	}
 	currentSession.PromptTokens = 0
 	_, err = a.sessions.Save(genCtx, currentSession)
 	if err != nil {
@@ -1145,11 +1147,9 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
 		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 
-	eventCost := cost
-	if estimated {
-		eventCost = 0
+	if !estimated {
+		a.eventTokensUsed(session.ID, model, usage, cost)
 	}
-	a.eventTokensUsed(session.ID, model, usage, eventCost)
 
 	if estimated {
 		cost = 0
@@ -1166,10 +1166,23 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
 	}
 
 	session.Cost += cost
-	if !usageIsZero(usage) {
+	updateSessionTokenCounters(session, usage)
+}
+
+func updateSessionTokenCounters(session *session.Session, usage fantasy.Usage) {
+	if usage.OutputTokens != 0 {
 		session.CompletionTokens = usage.OutputTokens
-		session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
 	}
+	if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
+		session.PromptTokens = promptTokens
+	}
+}
+
+func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message) int64 {
+	if usage.OutputTokens != 0 {
+		return usage.OutputTokens
+	}
+	return approxTokenCount(summaryMessage.Content().Text) + approxTokenCount(summaryMessage.ReasoningContent().String())
 }
 
 func (a *sessionAgent) Cancel(sessionID string) {

internal/agent/usage_fallback_test.go 🔗

@@ -6,6 +6,7 @@ import (
 
 	"charm.land/catwalk/pkg/catwalk"
 	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/session"
 	"github.com/stretchr/testify/require"
 )
@@ -232,7 +233,7 @@ func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
 	require.Equal(t, int64(456), currentSession.CompletionTokens)
 }
 
-func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) {
+func TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage(t *testing.T) {
 	t.Parallel()
 
 	agent := &sessionAgent{}
@@ -247,7 +248,77 @@ func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) {
 	agent.updateSessionUsage(model, currentSession, usage, nil, false)
 
 	require.Equal(t, int64(789), currentSession.PromptTokens)
-	require.Zero(t, currentSession.CompletionTokens)
+	require.Equal(t, int64(456), currentSession.CompletionTokens)
+}
+
+func TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     123,
+		CompletionTokens: 456,
+	}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+	usage := fantasy.Usage{TotalTokens: 100}
+
+	agent.updateSessionUsage(model, currentSession, usage, nil, false)
+
+	require.Equal(t, int64(123), currentSession.PromptTokens)
+	require.Equal(t, int64(456), currentSession.CompletionTokens)
+}
+
+func TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     123,
+		CompletionTokens: 456,
+	}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+	usage := fantasy.Usage{OutputTokens: 50}
+
+	agent.updateSessionUsage(model, currentSession, usage, nil, false)
+
+	require.Equal(t, int64(123), currentSession.PromptTokens)
+	require.Equal(t, int64(50), currentSession.CompletionTokens)
+}
+
+func TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     123,
+		CompletionTokens: 456,
+		Cost:             1.25,
+	}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+
+	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, true)
+
+	require.Equal(t, 1.25, currentSession.Cost)
+	require.Equal(t, int64(123), currentSession.PromptTokens)
+	require.Equal(t, int64(456), currentSession.CompletionTokens)
+}
+
+func TestSummaryCompletionTokens(t *testing.T) {
+	t.Parallel()
+
+	summaryMessage := message.Message{
+		Parts: []message.ContentPart{
+			message.TextContent{Text: "summary text"},
+			message.ReasoningContent{Thinking: "reasoning text"},
+		},
+	}
+
+	require.Equal(t, int64(42), summaryCompletionTokens(fantasy.Usage{OutputTokens: 42}, summaryMessage))
+	require.Equal(t, approxTokenCount("summary text")+approxTokenCount("reasoning text"), summaryCompletionTokens(fantasy.Usage{}, summaryMessage))
+	require.Zero(t, summaryCompletionTokens(fantasy.Usage{}, message.Message{}))
 }
 
 func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {