fix(agent): estimate missing streamed usage

Greg Slepak created

Add a fallback token estimator for streamed steps that return zero usage so session context pressure remains accurate when providers omit final usage chunks.

Estimated usage updates prompt/completion counters but never contributes cost, while provider-reported usage continues to preserve normal cost accounting and OpenRouter overrides. Zero-usage updates now leave existing nonzero token counters intact.

Change summary

internal/agent/agent.go               |  42 ++++-
internal/agent/usage_fallback.go      | 172 ++++++++++++++++++++++++
internal/agent/usage_fallback_test.go | 203 +++++++++++++++++++++++++++++
3 files changed, 405 insertions(+), 12 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -261,6 +261,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 	a.eventPromptSent(call.SessionID)
 
 	var currentAssistant *message.Message
+	var stepMessages []fantasy.Message
 	var shouldSummarize bool
 	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
 	var maxOutputTokens *int64
@@ -319,6 +320,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
 			}
 
+			sessionLock.Lock()
+			stepMessages = cloneFantasyMessages(prepared.Messages)
+			sessionLock.Unlock()
+
 			var assistantMsg message.Message
 			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
 				Role:     message.Assistant,
@@ -444,7 +449,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 			if getSessionErr != nil {
 				return getSessionErr
 			}
-			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
+			usage, estimated := fallbackStepUsage(stepMessages, stepResult)
+			a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated)
 			_, sessionErr := a.sessions.Save(ctx, updatedSession)
 			if sessionErr != nil {
 				return sessionErr
@@ -749,7 +755,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 		}
 	}
 
-	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
+	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost, false)
 
 	// Just in case, get just the last usage info.
 	usage := resp.Response.Usage
@@ -1132,28 +1138,40 @@ func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float6
 	return &opts.Usage.Cost
 }
 
-func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
+func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) {
 	modelConfig := model.CatwalkCfg
 	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
 		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
 
-	a.eventTokensUsed(session.ID, model, usage, cost)
-
-	// Use override cost if available (e.g., from OpenRouter).
-	if overrideCost != nil {
-		cost = *overrideCost
+	eventCost := cost
+	if estimated {
+		eventCost = 0
 	}
+	a.eventTokensUsed(session.ID, model, usage, eventCost)
 
-	// Skip cost accumulation
-	if model.FlatRate {
+	if estimated {
 		cost = 0
+	} else {
+		// Use override cost if available (e.g., from OpenRouter).
+		if overrideCost != nil {
+			cost = *overrideCost
+		}
+
+		// Skip cost accumulation
+		if model.FlatRate {
+			cost = 0
+		}
 	}
 
 	session.Cost += cost
-	session.CompletionTokens = usage.OutputTokens
-	session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
+	if usage.OutputTokens != 0 {
+		session.CompletionTokens = usage.OutputTokens
+	}
+	if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
+		session.PromptTokens = promptTokens
+	}
 }
 
 func (a *sessionAgent) Cancel(sessionID string) {

internal/agent/usage_fallback.go 🔗

@@ -0,0 +1,172 @@
+package agent
+
+import (
+	"fmt"
+
+	"charm.land/fantasy"
+)
+
+func usageIsZero(usage fantasy.Usage) bool {
+	return usage.InputTokens == 0 &&
+		usage.OutputTokens == 0 &&
+		usage.TotalTokens == 0 &&
+		usage.ReasoningTokens == 0 &&
+		usage.CacheCreationTokens == 0 &&
+		usage.CacheReadTokens == 0
+}
+
+func fallbackStepUsage(messages []fantasy.Message, step fantasy.StepResult) (fantasy.Usage, bool) {
+	if !usageIsZero(step.Usage) {
+		return step.Usage, false
+	}
+
+	inputTokens := estimateMessageTokens(messages)
+	outputTokens := estimateStepCompletionTokens(step)
+	if inputTokens == 0 && outputTokens == 0 {
+		return fantasy.Usage{}, false
+	}
+
+	return fantasy.Usage{
+		InputTokens:  inputTokens,
+		OutputTokens: outputTokens,
+		TotalTokens:  inputTokens + outputTokens,
+	}, true
+}
+
+func cloneFantasyMessages(messages []fantasy.Message) []fantasy.Message {
+	cloned := make([]fantasy.Message, len(messages))
+	for i, msg := range messages {
+		cloned[i] = msg
+		cloned[i].Content = append([]fantasy.MessagePart(nil), msg.Content...)
+	}
+	return cloned
+}
+
+func estimateMessageTokens(messages []fantasy.Message) int64 {
+	var tokens int64
+	for _, msg := range messages {
+		tokens += approxTokenCount(string(msg.Role))
+		for _, part := range msg.Content {
+			tokens += estimateMessagePartTokens(part)
+		}
+	}
+	return tokens
+}
+
+func estimateStepCompletionTokens(step fantasy.StepResult) int64 {
+	var tokens int64
+	for _, content := range step.Content {
+		switch c := content.(type) {
+		case fantasy.TextContent:
+			tokens += approxTokenCount(c.Text)
+		case *fantasy.TextContent:
+			tokens += approxTokenCount(c.Text)
+		case fantasy.ReasoningContent:
+			tokens += approxTokenCount(c.Text)
+		case *fantasy.ReasoningContent:
+			tokens += approxTokenCount(c.Text)
+		case fantasy.FileContent:
+			tokens += estimateGeneratedFileTokens(c)
+		case *fantasy.FileContent:
+			tokens += estimateGeneratedFileTokens(*c)
+		case fantasy.SourceContent:
+			tokens += estimateSourceTokens(c)
+		case *fantasy.SourceContent:
+			tokens += estimateSourceTokens(*c)
+		case fantasy.ToolCallContent:
+			tokens += estimateToolCallTokens(c.ToolName, c.Input)
+		case *fantasy.ToolCallContent:
+			tokens += estimateToolCallTokens(c.ToolName, c.Input)
+		case fantasy.ToolResultContent:
+			tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result)
+		case *fantasy.ToolResultContent:
+			tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result)
+		}
+	}
+	return tokens
+}
+
+func estimateMessagePartTokens(part fantasy.MessagePart) int64 {
+	switch p := part.(type) {
+	case fantasy.TextPart:
+		return approxTokenCount(p.Text)
+	case *fantasy.TextPart:
+		return approxTokenCount(p.Text)
+	case fantasy.ReasoningPart:
+		return approxTokenCount(p.Text)
+	case *fantasy.ReasoningPart:
+		return approxTokenCount(p.Text)
+	case fantasy.FilePart:
+		return estimateFilePartTokens(p)
+	case *fantasy.FilePart:
+		return estimateFilePartTokens(*p)
+	case fantasy.ToolCallPart:
+		return estimateToolCallTokens(p.ToolName, p.Input)
+	case *fantasy.ToolCallPart:
+		return estimateToolCallTokens(p.ToolName, p.Input)
+	case fantasy.ToolResultPart:
+		return estimateToolResultContentTokens(p.ToolCallID, "", "", p.Output)
+	case *fantasy.ToolResultPart:
+		return estimateToolResultContentTokens(p.ToolCallID, "", "", p.Output)
+	default:
+		return 0
+	}
+}
+
+func estimateToolCallTokens(toolName, input string) int64 {
+	return approxTokenCount(toolName) + approxTokenCount(input)
+}
+
+func estimateToolResultContentTokens(toolCallID, toolName, metadata string, output fantasy.ToolResultOutputContent) int64 {
+	tokens := approxTokenCount(toolCallID) + approxTokenCount(toolName) + approxTokenCount(metadata)
+	switch result := output.(type) {
+	case fantasy.ToolResultOutputContentText:
+		tokens += approxTokenCount(result.Text)
+	case *fantasy.ToolResultOutputContentText:
+		tokens += approxTokenCount(result.Text)
+	case fantasy.ToolResultOutputContentError:
+		if result.Error != nil {
+			tokens += approxTokenCount(result.Error.Error())
+		}
+	case *fantasy.ToolResultOutputContentError:
+		if result.Error != nil {
+			tokens += approxTokenCount(result.Error.Error())
+		}
+	case fantasy.ToolResultOutputContentMedia:
+		tokens += estimateMediaTokens(result.MediaType, result.Text, len(result.Data))
+	case *fantasy.ToolResultOutputContentMedia:
+		tokens += estimateMediaTokens(result.MediaType, result.Text, len(result.Data))
+	}
+	return tokens
+}
+
+func estimateFilePartTokens(file fantasy.FilePart) int64 {
+	return estimateMediaTokens(file.MediaType, file.Filename, len(file.Data))
+}
+
+func estimateGeneratedFileTokens(file fantasy.FileContent) int64 {
+	return estimateMediaTokens(file.MediaType, "", len(file.Data))
+}
+
+func estimateMediaTokens(mediaType, text string, dataBytes int) int64 {
+	if dataBytes == 0 {
+		return approxTokenCount(mediaType) + approxTokenCount(text)
+	}
+	return approxTokenCount(fmt.Sprintf("%s %s %d bytes", mediaType, text, dataBytes))
+}
+
+func estimateSourceTokens(source fantasy.SourceContent) int64 {
+	return approxTokenCount(string(source.SourceType)) +
+		approxTokenCount(source.ID) +
+		approxTokenCount(source.URL) +
+		approxTokenCount(source.Title) +
+		approxTokenCount(source.MediaType) +
+		approxTokenCount(source.Filename)
+}
+
+func approxTokenCount(s string) int64 {
+	if s == "" {
+		return 0
+	}
+	return int64((len(s) + 3) / 4)
+}

internal/agent/usage_fallback_test.go 🔗

@@ -0,0 +1,203 @@
+package agent
+
+import (
+	"errors"
+	"testing"
+
+	"charm.land/catwalk/pkg/catwalk"
+	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/session"
+	"github.com/stretchr/testify/require"
+)
+
+func TestUsageIsZero(t *testing.T) {
+	t.Parallel()
+
+	require.True(t, usageIsZero(fantasy.Usage{}))
+	require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
+	require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
+	require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
+	require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
+	require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
+	require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
+}
+
+func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
+	t.Parallel()
+
+	usage := fantasy.Usage{
+		InputTokens:  10,
+		OutputTokens: 5,
+		TotalTokens:  15,
+	}
+	step := fantasy.StepResult{
+		Response: fantasy.Response{Usage: usage},
+	}
+
+	fallbackUsage, estimated := fallbackStepUsage(nil, step)
+	require.False(t, estimated)
+	require.Equal(t, usage, fallbackUsage)
+}
+
+func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
+	t.Parallel()
+
+	messages := []fantasy.Message{
+		fantasy.NewUserMessage("please explain the implementation details"),
+	}
+	step := fantasy.StepResult{
+		Response: fantasy.Response{
+			Content: fantasy.ResponseContent{
+				fantasy.TextContent{Text: "the implementation stores state safely"},
+			},
+		},
+	}
+
+	usage, estimated := fallbackStepUsage(messages, step)
+	require.True(t, estimated)
+	require.Positive(t, usage.InputTokens)
+	require.Positive(t, usage.OutputTokens)
+	require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
+}
+
+func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
+	t.Parallel()
+
+	messages := []fantasy.Message{
+		{
+			Role: fantasy.MessageRoleAssistant,
+			Content: []fantasy.MessagePart{
+				fantasy.ReasoningPart{Text: "first reason about the request"},
+			},
+		},
+	}
+	step := fantasy.StepResult{
+		Response: fantasy.Response{
+			Content: fantasy.ResponseContent{
+				fantasy.ReasoningContent{Text: "second reason about the answer"},
+			},
+		},
+	}
+
+	usage, estimated := fallbackStepUsage(messages, step)
+	require.True(t, estimated)
+	require.Positive(t, usage.InputTokens)
+	require.Positive(t, usage.OutputTokens)
+}
+
+func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
+	t.Parallel()
+
+	step := fantasy.StepResult{
+		Response: fantasy.Response{
+			Content: fantasy.ResponseContent{
+				fantasy.ToolCallContent{
+					ToolCallID: "tool-call-1",
+					ToolName:   "view",
+					Input:      `{"file_path":"/tmp/example.go"}`,
+				},
+			},
+		},
+	}
+
+	usage, estimated := fallbackStepUsage(nil, step)
+	require.True(t, estimated)
+	require.Zero(t, usage.InputTokens)
+	require.Positive(t, usage.OutputTokens)
+	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
+}
+
+func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
+	t.Parallel()
+
+	messages := []fantasy.Message{
+		{
+			Role: fantasy.MessageRoleTool,
+			Content: []fantasy.MessagePart{
+				fantasy.ToolResultPart{
+					ToolCallID: "tool-call-1",
+					Output: fantasy.ToolResultOutputContentText{
+						Text: "file contents returned by the tool",
+					},
+				},
+				fantasy.ToolResultPart{
+					ToolCallID: "tool-call-2",
+					Output: fantasy.ToolResultOutputContentError{
+						Error: errors.New("permission denied"),
+					},
+				},
+				fantasy.ToolResultPart{
+					ToolCallID: "tool-call-3",
+					Output: fantasy.ToolResultOutputContentMedia{
+						MediaType: "image/png",
+						Text:      "screenshot",
+						Data:      "abc123",
+					},
+				},
+			},
+		},
+	}
+
+	usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
+	require.True(t, estimated)
+	require.Positive(t, usage.InputTokens)
+	require.Zero(t, usage.OutputTokens)
+	require.Equal(t, usage.InputTokens, usage.TotalTokens)
+}
+
+func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
+	t.Parallel()
+
+	usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
+	require.False(t, estimated)
+	require.True(t, usageIsZero(usage))
+}
+
+func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
+
+	agent.updateSessionUsage(model, currentSession, usage, nil, true)
+
+	require.Equal(t, 1.25, currentSession.Cost)
+	require.Equal(t, int64(1000), currentSession.PromptTokens)
+	require.Equal(t, int64(2000), currentSession.CompletionTokens)
+}
+
+func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     123,
+		CompletionTokens: 456,
+		Cost:             1.25,
+	}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+
+	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
+
+	require.Equal(t, 1.25, currentSession.Cost)
+	require.Equal(t, int64(123), currentSession.PromptTokens)
+	require.Equal(t, int64(456), currentSession.CompletionTokens)
+}
+
+func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
+
+	agent.updateSessionUsage(model, currentSession, usage, nil, false)
+
+	require.Equal(t, 1.3, currentSession.Cost)
+	require.Equal(t, int64(1000), currentSession.PromptTokens)
+	require.Equal(t, int64(2000), currentSession.CompletionTokens)
+}