From 6ed8852b621faff72fe3f791efdc41678b439b9e Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 12 May 2026 15:16:59 -0700 Subject: [PATCH] fix(agent): estimate missing streamed usage Add a fallback token estimator for streamed steps that return zero usage so session context pressure remains accurate when providers omit final usage chunks. Estimated usage updates prompt/completion counters but never contributes cost, while provider-reported usage continues to preserve normal cost accounting and OpenRouter overrides. Zero-usage updates now leave existing nonzero token counters intact. --- internal/agent/agent.go | 42 ++++-- internal/agent/usage_fallback.go | 172 ++++++++++++++++++++++ internal/agent/usage_fallback_test.go | 203 ++++++++++++++++++++++++++ 3 files changed, 405 insertions(+), 12 deletions(-) create mode 100644 internal/agent/usage_fallback.go create mode 100644 internal/agent/usage_fallback_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index e8707a3e2b3e35281e57a9d396adbe81c5b7ebbd..c54331910b13a0e5a7ef747451be32895596ab65 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -261,6 +261,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy a.eventPromptSent(call.SessionID) var currentAssistant *message.Message + var stepMessages []fantasy.Message var shouldSummarize bool // Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it var maxOutputTokens *int64 @@ -319,6 +320,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...) } + sessionLock.Lock() + stepMessages = cloneFantasyMessages(prepared.Messages) + sessionLock.Unlock() + var assistantMsg message.Message assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{ Role: message.Assistant, @@ -444,7 +449,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy if getSessionErr != nil { return getSessionErr } - a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) + usage, estimated := fallbackStepUsage(stepMessages, stepResult) + a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated) _, sessionErr := a.sessions.Save(ctx, updatedSession) if sessionErr != nil { return sessionErr @@ -749,7 +755,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan } } - a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost) + a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost, false) // Just in case, get just the last usage info. usage := resp.Response.Usage @@ -1132,28 +1138,40 @@ func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float6 return &opts.Usage.Cost } -func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) { +func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) { modelConfig := model.CatwalkCfg cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) + modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens) - a.eventTokensUsed(session.ID, model, usage, cost) - - // Use override cost if available (e.g., from OpenRouter). - if overrideCost != nil { - cost = *overrideCost + eventCost := cost + if estimated { + eventCost = 0 } + a.eventTokensUsed(session.ID, model, usage, eventCost) - // Skip cost accumulation - if model.FlatRate { + if estimated { cost = 0 + } else { + // Use override cost if available (e.g., from OpenRouter). + if overrideCost != nil { + cost = *overrideCost + } + + // Skip cost accumulation + if model.FlatRate { + cost = 0 + } } session.Cost += cost - session.CompletionTokens = usage.OutputTokens - session.PromptTokens = usage.InputTokens + usage.CacheReadTokens + if usage.OutputTokens != 0 { + session.CompletionTokens = usage.OutputTokens + } + if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 { + session.PromptTokens = promptTokens + } } func (a *sessionAgent) Cancel(sessionID string) { diff --git a/internal/agent/usage_fallback.go b/internal/agent/usage_fallback.go new file mode 100644 index 0000000000000000000000000000000000000000..78903c57a6c8c54a3598964c1c132579d0cad36d --- /dev/null +++ b/internal/agent/usage_fallback.go @@ -0,0 +1,172 @@ +package agent + +import ( + "fmt" + + "charm.land/fantasy" +) + +func usageIsZero(usage fantasy.Usage) bool { + return usage.InputTokens == 0 && + usage.OutputTokens == 0 && + usage.TotalTokens == 0 && + usage.ReasoningTokens == 0 && + usage.CacheCreationTokens == 0 && + usage.CacheReadTokens == 0 +} + +func fallbackStepUsage(messages []fantasy.Message, step fantasy.StepResult) (fantasy.Usage, bool) { + if !usageIsZero(step.Usage) { + return step.Usage, false + } + + inputTokens := estimateMessageTokens(messages) + outputTokens := estimateStepCompletionTokens(step) + if inputTokens == 0 && outputTokens == 0 { + return fantasy.Usage{}, false + } + + return fantasy.Usage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + }, true +} + +func cloneFantasyMessages(messages []fantasy.Message) []fantasy.Message { + cloned := make([]fantasy.Message, len(messages)) + for i, msg := range messages { + cloned[i] = msg + cloned[i].Content = append([]fantasy.MessagePart(nil), msg.Content...) + } + return cloned +} + +func estimateMessageTokens(messages []fantasy.Message) int64 { + var tokens int64 + for _, msg := range messages { + tokens += approxTokenCount(string(msg.Role)) + for _, part := range msg.Content { + tokens += estimateMessagePartTokens(part) + } + } + return tokens +} + +func estimateStepCompletionTokens(step fantasy.StepResult) int64 { + var tokens int64 + for _, content := range step.Content { + switch c := content.(type) { + case fantasy.TextContent: + tokens += approxTokenCount(c.Text) + case *fantasy.TextContent: + tokens += approxTokenCount(c.Text) + case fantasy.ReasoningContent: + tokens += approxTokenCount(c.Text) + case *fantasy.ReasoningContent: + tokens += approxTokenCount(c.Text) + case fantasy.FileContent: + tokens += estimateGeneratedFileTokens(c) + case *fantasy.FileContent: + tokens += estimateGeneratedFileTokens(*c) + case fantasy.SourceContent: + tokens += estimateSourceTokens(c) + case *fantasy.SourceContent: + tokens += estimateSourceTokens(*c) + case fantasy.ToolCallContent: + tokens += estimateToolCallTokens(c.ToolName, c.Input) + case *fantasy.ToolCallContent: + tokens += estimateToolCallTokens(c.ToolName, c.Input) + case fantasy.ToolResultContent: + tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result) + case *fantasy.ToolResultContent: + tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result) + } + } + return tokens +} + +func estimateMessagePartTokens(part fantasy.MessagePart) int64 { + switch p := part.(type) { + case fantasy.TextPart: + return approxTokenCount(p.Text) + case *fantasy.TextPart: + return approxTokenCount(p.Text) + case fantasy.ReasoningPart: + return approxTokenCount(p.Text) + case *fantasy.ReasoningPart: + return approxTokenCount(p.Text) + case fantasy.FilePart: + return estimateFilePartTokens(p) + case *fantasy.FilePart: + return estimateFilePartTokens(*p) + case fantasy.ToolCallPart: + return estimateToolCallTokens(p.ToolName, p.Input) + case *fantasy.ToolCallPart: + return estimateToolCallTokens(p.ToolName, p.Input) + case fantasy.ToolResultPart: + return estimateToolResultContentTokens(p.ToolCallID, "", "", p.Output) + case *fantasy.ToolResultPart: + return estimateToolResultContentTokens(p.ToolCallID, "", "", p.Output) + default: + return 0 + } +} + +func estimateToolCallTokens(toolName, input string) int64 { + return approxTokenCount(toolName) + approxTokenCount(input) +} + +func estimateToolResultContentTokens(toolCallID, toolName, metadata string, output fantasy.ToolResultOutputContent) int64 { + tokens := approxTokenCount(toolCallID) + approxTokenCount(toolName) + approxTokenCount(metadata) + switch result := output.(type) { + case fantasy.ToolResultOutputContentText: + tokens += approxTokenCount(result.Text) + case *fantasy.ToolResultOutputContentText: + tokens += approxTokenCount(result.Text) + case fantasy.ToolResultOutputContentError: + if result.Error != nil { + tokens += approxTokenCount(result.Error.Error()) + } + case *fantasy.ToolResultOutputContentError: + if result.Error != nil { + tokens += approxTokenCount(result.Error.Error()) + } + case fantasy.ToolResultOutputContentMedia: + tokens += estimateMediaTokens(result.MediaType, result.Text, len(result.Data)) + case *fantasy.ToolResultOutputContentMedia: + tokens += estimateMediaTokens(result.MediaType, result.Text, len(result.Data)) + } + return tokens +} + +func estimateFilePartTokens(file fantasy.FilePart) int64 { + return estimateMediaTokens(file.MediaType, file.Filename, len(file.Data)) +} + +func estimateGeneratedFileTokens(file fantasy.FileContent) int64 { + return estimateMediaTokens(file.MediaType, "", len(file.Data)) +} + +func estimateMediaTokens(mediaType, text string, dataBytes int) int64 { + if dataBytes == 0 { + return approxTokenCount(mediaType) + approxTokenCount(text) + } + return approxTokenCount(fmt.Sprintf("%s %s %d bytes", mediaType, text, dataBytes)) +} + +func estimateSourceTokens(source fantasy.SourceContent) int64 { + return approxTokenCount(string(source.SourceType)) + + approxTokenCount(source.ID) + + approxTokenCount(source.URL) + + approxTokenCount(source.Title) + + approxTokenCount(source.MediaType) + + approxTokenCount(source.Filename) +} + +func approxTokenCount(s string) int64 { + if s == "" { + return 0 + } + return int64((len(s) + 3) / 4) +} diff --git a/internal/agent/usage_fallback_test.go b/internal/agent/usage_fallback_test.go new file mode 100644 index 0000000000000000000000000000000000000000..96263f214a1be035dc5978827c286b81dde2e0ba --- /dev/null +++ b/internal/agent/usage_fallback_test.go @@ -0,0 +1,203 @@ +package agent + +import ( + "errors" + "testing" + + "charm.land/catwalk/pkg/catwalk" + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/session" + "github.com/stretchr/testify/require" +) + +func TestUsageIsZero(t *testing.T) { + t.Parallel() + + require.True(t, usageIsZero(fantasy.Usage{})) + require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1})) + require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1})) + require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1})) + require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1})) + require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1})) + require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1})) +} + +func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) { + t.Parallel() + + usage := fantasy.Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + } + step := fantasy.StepResult{ + Response: fantasy.Response{Usage: usage}, + } + + fallbackUsage, estimated := fallbackStepUsage(nil, step) + require.False(t, estimated) + require.Equal(t, usage, fallbackUsage) +} + +func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) { + t.Parallel() + + messages := []fantasy.Message{ + fantasy.NewUserMessage("please explain the implementation details"), + } + step := fantasy.StepResult{ + Response: fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.TextContent{Text: "the implementation stores state safely"}, + }, + }, + } + + usage, estimated := fallbackStepUsage(messages, step) + require.True(t, estimated) + require.Positive(t, usage.InputTokens) + require.Positive(t, usage.OutputTokens) + require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens) +} + +func TestFallbackStepUsageEstimatesReasoning(t *testing.T) { + t.Parallel() + + messages := []fantasy.Message{ + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ReasoningPart{Text: "first reason about the request"}, + }, + }, + } + step := fantasy.StepResult{ + Response: fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.ReasoningContent{Text: "second reason about the answer"}, + }, + }, + } + + usage, estimated := fallbackStepUsage(messages, step) + require.True(t, estimated) + require.Positive(t, usage.InputTokens) + require.Positive(t, usage.OutputTokens) +} + +func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) { + t.Parallel() + + step := fantasy.StepResult{ + Response: fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.ToolCallContent{ + ToolCallID: "tool-call-1", + ToolName: "view", + Input: `{"file_path":"/tmp/example.go"}`, + }, + }, + }, + } + + usage, estimated := fallbackStepUsage(nil, step) + require.True(t, estimated) + require.Zero(t, usage.InputTokens) + require.Positive(t, usage.OutputTokens) + require.Equal(t, usage.OutputTokens, usage.TotalTokens) +} + +func TestFallbackStepUsageEstimatesToolResults(t *testing.T) { + t.Parallel() + + messages := []fantasy.Message{ + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: "tool-call-1", + Output: fantasy.ToolResultOutputContentText{ + Text: "file contents returned by the tool", + }, + }, + fantasy.ToolResultPart{ + ToolCallID: "tool-call-2", + Output: fantasy.ToolResultOutputContentError{ + Error: errors.New("permission denied"), + }, + }, + fantasy.ToolResultPart{ + ToolCallID: "tool-call-3", + Output: fantasy.ToolResultOutputContentMedia{ + MediaType: "image/png", + Text: "screenshot", + Data: "abc123", + }, + }, + }, + }, + } + + usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{}) + require.True(t, estimated) + require.Positive(t, usage.InputTokens) + require.Zero(t, usage.OutputTokens) + require.Equal(t, usage.InputTokens, usage.TotalTokens) +} + +func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) { + t.Parallel() + + usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{}) + require.False(t, estimated) + require.True(t, usageIsZero(usage)) +} + +func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) { + t.Parallel() + + agent := &sessionAgent{} + currentSession := &session.Session{ID: "session-id", Cost: 1.25} + model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}} + usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000} + + agent.updateSessionUsage(model, currentSession, usage, nil, true) + + require.Equal(t, 1.25, currentSession.Cost) + require.Equal(t, int64(1000), currentSession.PromptTokens) + require.Equal(t, int64(2000), currentSession.CompletionTokens) +} + +func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) { + t.Parallel() + + agent := &sessionAgent{} + currentSession := &session.Session{ + ID: "session-id", + PromptTokens: 123, + CompletionTokens: 456, + Cost: 1.25, + } + model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}} + + agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false) + + require.Equal(t, 1.25, currentSession.Cost) + require.Equal(t, int64(123), currentSession.PromptTokens) + require.Equal(t, int64(456), currentSession.CompletionTokens) +} + +func TestUpdateSessionUsageAddsProviderCost(t *testing.T) { + t.Parallel() + + agent := &sessionAgent{} + currentSession := &session.Session{ID: "session-id", Cost: 1.25} + model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}} + usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000} + + agent.updateSessionUsage(model, currentSession, usage, nil, false) + + require.Equal(t, 1.3, currentSession.Cost) + require.Equal(t, int64(1000), currentSession.PromptTokens) + require.Equal(t, int64(2000), currentSession.CompletionTokens) +}