@@ -261,6 +261,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
a.eventPromptSent(call.SessionID)
var currentAssistant *message.Message
+ var stepMessages []fantasy.Message
var shouldSummarize bool
// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
var maxOutputTokens *int64
@@ -319,6 +320,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
}
+ sessionLock.Lock()
+ stepMessages = cloneFantasyMessages(prepared.Messages)
+ sessionLock.Unlock()
+
var assistantMsg message.Message
assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
Role: message.Assistant,
@@ -444,7 +449,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
if getSessionErr != nil {
return getSessionErr
}
- a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
+ usage, estimated := fallbackStepUsage(stepMessages, stepResult)
+ a.updateSessionUsage(largeModel, &updatedSession, usage, a.openrouterCost(stepResult.ProviderMetadata), estimated)
_, sessionErr := a.sessions.Save(ctx, updatedSession)
if sessionErr != nil {
return sessionErr
@@ -749,7 +755,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
}
}
- a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)
+ a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost, false)
// Just in case, get just the last usage info.
usage := resp.Response.Usage
@@ -1132,28 +1138,40 @@ func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float6
return &opts.Usage.Cost
}
-func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
+// updateSessionUsage prices usage with the model's per-million-token rates,
+// calls eventTokensUsed, and folds the result into session.
+//
+// When estimated is true, both the event and the session receive a cost of
+// zero. Otherwise overrideCost, when non-nil, replaces the computed cost, and
+// a FlatRate model contributes zero. The session token counters are only
+// overwritten by non-zero values.
+func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64, estimated bool) {
modelConfig := model.CatwalkCfg
cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
- a.eventTokensUsed(session.ID, model, usage, cost)
-
- // Use override cost if available (e.g., from OpenRouter).
- if overrideCost != nil {
- cost = *overrideCost
+ // The event carries the rate-based cost, or zero when the usage is estimated.
+ eventCost := cost
+ if estimated {
+ eventCost = 0
}
+ a.eventTokensUsed(session.ID, model, usage, eventCost)
- // Skip cost accumulation
- if model.FlatRate {
+ if estimated {
cost = 0
+ } else {
+ // Use override cost if available (e.g., from OpenRouter).
+ if overrideCost != nil {
+ cost = *overrideCost
+ }
+
+ // Skip cost accumulation
+ if model.FlatRate {
+ cost = 0
+ }
}
session.Cost += cost
- session.CompletionTokens = usage.OutputTokens
- session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
+ // Leave the stored counters untouched when this usage reports zero tokens.
+ if usage.OutputTokens != 0 {
+ session.CompletionTokens = usage.OutputTokens
+ }
+ if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
+ session.PromptTokens = promptTokens
+ }
}
func (a *sessionAgent) Cancel(sessionID string) {
@@ -0,0 +1,172 @@
+package agent
+
+import (
+ "fmt"
+
+ "charm.land/fantasy"
+)
+
+// usageIsZero reports whether every token counter in usage is zero.
+func usageIsZero(usage fantasy.Usage) bool {
+	// Any non-zero prompt, completion or total count means usage was reported.
+	if usage.InputTokens != 0 || usage.OutputTokens != 0 || usage.TotalTokens != 0 {
+		return false
+	}
+	if usage.ReasoningTokens != 0 {
+		return false
+	}
+	return usage.CacheCreationTokens == 0 && usage.CacheReadTokens == 0
+}
+
+// fallbackStepUsage picks the usage to record for a step.
+//
+// If step.Usage has any non-zero counter it is returned unchanged together
+// with false. Otherwise the input tokens are approximated from messages and
+// the output tokens from the step content, and that estimate is returned
+// together with true. When the estimate is zero as well, an empty Usage and
+// false are returned.
+func fallbackStepUsage(messages []fantasy.Message, step fantasy.StepResult) (fantasy.Usage, bool) {
+	if reported := step.Usage; !usageIsZero(reported) {
+		return reported, false
+	}
+
+	prompt := estimateMessageTokens(messages)
+	completion := estimateStepCompletionTokens(step)
+	if prompt == 0 && completion == 0 {
+		// Nothing to estimate from, so there is no estimate to flag.
+		return fantasy.Usage{}, false
+	}
+
+	estimate := fantasy.Usage{
+		InputTokens:  prompt,
+		OutputTokens: completion,
+		TotalTokens:  prompt + completion,
+	}
+	return estimate, true
+}
+
+// cloneFantasyMessages returns a new slice holding a copy of every message,
+// each with its own Content slice. The parts inside Content are copied as
+// interface values, not duplicated.
+func cloneFantasyMessages(messages []fantasy.Message) []fantasy.Message {
+	out := make([]fantasy.Message, 0, len(messages))
+	for _, m := range messages {
+		// m is already a per-iteration copy; give it a fresh Content slice.
+		m.Content = append([]fantasy.MessagePart(nil), m.Content...)
+		out = append(out, m)
+	}
+	return out
+}
+
+// estimateMessageTokens approximates the token count of messages by adding,
+// for each message, the estimate for its role string and for every content
+// part.
+func estimateMessageTokens(messages []fantasy.Message) int64 {
+	var total int64
+	for i := range messages {
+		total += approxTokenCount(string(messages[i].Role))
+		for _, p := range messages[i].Content {
+			total += estimateMessagePartTokens(p)
+		}
+	}
+	return total
+}
+
+// estimateStepCompletionTokens approximates the tokens of a step's content by
+// summing an estimate for each entry of step.Content. Text, reasoning, file,
+// source, tool-call and tool-result entries are counted, whether held by
+// value or by pointer; any other entry type contributes nothing.
+func estimateStepCompletionTokens(step fantasy.StepResult) int64 {
+	var total int64
+	for _, item := range step.Content {
+		var n int64
+		switch v := item.(type) {
+		case fantasy.TextContent:
+			n = approxTokenCount(v.Text)
+		case *fantasy.TextContent:
+			n = approxTokenCount(v.Text)
+		case fantasy.ReasoningContent:
+			n = approxTokenCount(v.Text)
+		case *fantasy.ReasoningContent:
+			n = approxTokenCount(v.Text)
+		case fantasy.FileContent:
+			n = estimateGeneratedFileTokens(v)
+		case *fantasy.FileContent:
+			n = estimateGeneratedFileTokens(*v)
+		case fantasy.SourceContent:
+			n = estimateSourceTokens(v)
+		case *fantasy.SourceContent:
+			n = estimateSourceTokens(*v)
+		case fantasy.ToolCallContent:
+			n = estimateToolCallTokens(v.ToolName, v.Input)
+		case *fantasy.ToolCallContent:
+			n = estimateToolCallTokens(v.ToolName, v.Input)
+		case fantasy.ToolResultContent:
+			n = estimateToolResultContentTokens(v.ToolCallID, v.ToolName, v.ClientMetadata, v.Result)
+		case *fantasy.ToolResultContent:
+			n = estimateToolResultContentTokens(v.ToolCallID, v.ToolName, v.ClientMetadata, v.Result)
+		}
+		total += n
+	}
+	return total
+}
+
+// estimateMessagePartTokens approximates the token count of one message
+// part. Text, reasoning, file, tool-call and tool-result parts are handled,
+// by value or by pointer; every other part type yields zero.
+func estimateMessagePartTokens(part fantasy.MessagePart) int64 {
+	var n int64
+	switch v := part.(type) {
+	case fantasy.TextPart:
+		n = approxTokenCount(v.Text)
+	case *fantasy.TextPart:
+		n = approxTokenCount(v.Text)
+	case fantasy.ReasoningPart:
+		n = approxTokenCount(v.Text)
+	case *fantasy.ReasoningPart:
+		n = approxTokenCount(v.Text)
+	case fantasy.FilePart:
+		n = estimateFilePartTokens(v)
+	case *fantasy.FilePart:
+		n = estimateFilePartTokens(*v)
+	case fantasy.ToolCallPart:
+		n = estimateToolCallTokens(v.ToolName, v.Input)
+	case *fantasy.ToolCallPart:
+		n = estimateToolCallTokens(v.ToolName, v.Input)
+	case fantasy.ToolResultPart:
+		// Tool result parts carry no tool name or metadata to count here.
+		n = estimateToolResultContentTokens(v.ToolCallID, "", "", v.Output)
+	case *fantasy.ToolResultPart:
+		n = estimateToolResultContentTokens(v.ToolCallID, "", "", v.Output)
+	}
+	return n
+}
+
+// estimateToolCallTokens approximates a tool call as the estimate for its
+// name plus the estimate for its input string.
+func estimateToolCallTokens(toolName, input string) int64 {
+	nameTokens := approxTokenCount(toolName)
+	inputTokens := approxTokenCount(input)
+	return nameTokens + inputTokens
+}
+
+// estimateToolResultContentTokens approximates a tool result as the estimate
+// for its call ID, tool name and metadata plus the estimate for its output.
+// Text, error and media outputs are counted, by value or by pointer; an error
+// output with a nil Error, or any other output type, adds nothing.
+func estimateToolResultContentTokens(toolCallID, toolName, metadata string, output fantasy.ToolResultOutputContent) int64 {
+	header := approxTokenCount(toolCallID) + approxTokenCount(toolName) + approxTokenCount(metadata)
+
+	var payload int64
+	switch out := output.(type) {
+	case fantasy.ToolResultOutputContentText:
+		payload = approxTokenCount(out.Text)
+	case *fantasy.ToolResultOutputContentText:
+		payload = approxTokenCount(out.Text)
+	case fantasy.ToolResultOutputContentError:
+		if err := out.Error; err != nil {
+			payload = approxTokenCount(err.Error())
+		}
+	case *fantasy.ToolResultOutputContentError:
+		if err := out.Error; err != nil {
+			payload = approxTokenCount(err.Error())
+		}
+	case fantasy.ToolResultOutputContentMedia:
+		payload = estimateMediaTokens(out.MediaType, out.Text, len(out.Data))
+	case *fantasy.ToolResultOutputContentMedia:
+		payload = estimateMediaTokens(out.MediaType, out.Text, len(out.Data))
+	}
+	return header + payload
+}
+
+// estimateFilePartTokens approximates a file part from its media type, its
+// filename and the length of its data.
+func estimateFilePartTokens(file fantasy.FilePart) int64 {
+	size := len(file.Data)
+	return estimateMediaTokens(file.MediaType, file.Filename, size)
+}
+
+// estimateGeneratedFileTokens approximates a file content entry from its
+// media type and the length of its data; no descriptive text is passed.
+func estimateGeneratedFileTokens(file fantasy.FileContent) int64 {
+	size := len(file.Data)
+	return estimateMediaTokens(file.MediaType, "", size)
+}
+
+// estimateMediaTokens approximates the tokens of a media item. With no data
+// it is the estimate for mediaType plus the estimate for text. With data,
+// the estimate is taken over a one-line description of the form
+// "<mediaType> <text> <dataBytes> bytes" rather than over the data itself.
+func estimateMediaTokens(mediaType, text string, dataBytes int) int64 {
+	if dataBytes != 0 {
+		description := fmt.Sprintf("%s %s %d bytes", mediaType, text, dataBytes)
+		return approxTokenCount(description)
+	}
+	return approxTokenCount(mediaType) + approxTokenCount(text)
+}
+
+// estimateSourceTokens approximates a source entry as the sum of the
+// estimates for its type, ID, URL, title, media type and filename.
+func estimateSourceTokens(source fantasy.SourceContent) int64 {
+	fields := [...]string{
+		string(source.SourceType),
+		source.ID,
+		source.URL,
+		source.Title,
+		source.MediaType,
+		source.Filename,
+	}
+
+	var total int64
+	for _, field := range fields {
+		total += approxTokenCount(field)
+	}
+	return total
+}
+
+// approxTokenCount estimates the number of tokens in s as its byte length
+// divided by four, rounded up. The empty string yields zero.
+func approxTokenCount(s string) int64 {
+	const bytesPerToken = 4
+	// Ceiling division; a zero length falls out as zero without a special case.
+	return int64((len(s) + bytesPerToken - 1) / bytesPerToken)
+}
@@ -0,0 +1,203 @@
+package agent
+
+import (
+ "errors"
+ "testing"
+
+ "charm.land/catwalk/pkg/catwalk"
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/session"
+ "github.com/stretchr/testify/require"
+)
+
+// TestUsageIsZero checks that only the empty Usage counts as zero and that
+// setting any single counter makes it non-zero.
+func TestUsageIsZero(t *testing.T) {
+	t.Parallel()
+
+	require.True(t, usageIsZero(fantasy.Usage{}))
+
+	nonZero := []fantasy.Usage{
+		{InputTokens: 1},
+		{OutputTokens: 1},
+		{TotalTokens: 1},
+		{ReasoningTokens: 1},
+		{CacheCreationTokens: 1},
+		{CacheReadTokens: 1},
+	}
+	for _, u := range nonZero {
+		require.False(t, usageIsZero(u))
+	}
+}
+
+// TestFallbackStepUsageKeepsProviderUsage checks that usage reported on the
+// step is returned unchanged and is not flagged as estimated.
+func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
+	t.Parallel()
+
+	reported := fantasy.Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}
+
+	got, estimated := fallbackStepUsage(nil, fantasy.StepResult{
+		Response: fantasy.Response{Usage: reported},
+	})
+
+	require.False(t, estimated)
+	require.Equal(t, reported, got)
+}
+
+// TestFallbackStepUsageEstimatesPromptAndAssistantText checks that, with no
+// reported usage, a user prompt and an assistant text reply both produce
+// positive estimates whose sum is the total.
+func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
+	t.Parallel()
+
+	reply := fantasy.ResponseContent{
+		fantasy.TextContent{Text: "the implementation stores state safely"},
+	}
+	prompt := []fantasy.Message{
+		fantasy.NewUserMessage("please explain the implementation details"),
+	}
+
+	got, estimated := fallbackStepUsage(prompt, fantasy.StepResult{
+		Response: fantasy.Response{Content: reply},
+	})
+
+	require.True(t, estimated)
+	require.Positive(t, got.InputTokens)
+	require.Positive(t, got.OutputTokens)
+	require.Equal(t, got.InputTokens+got.OutputTokens, got.TotalTokens)
+}
+
+// TestFallbackStepUsageEstimatesReasoning checks that reasoning text is
+// counted both in the prompt messages and in the step content.
+func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
+	t.Parallel()
+
+	history := []fantasy.Message{{
+		Role: fantasy.MessageRoleAssistant,
+		Content: []fantasy.MessagePart{
+			fantasy.ReasoningPart{Text: "first reason about the request"},
+		},
+	}}
+	result := fantasy.StepResult{
+		Response: fantasy.Response{
+			Content: fantasy.ResponseContent{
+				fantasy.ReasoningContent{Text: "second reason about the answer"},
+			},
+		},
+	}
+
+	got, estimated := fallbackStepUsage(history, result)
+
+	require.True(t, estimated)
+	require.Positive(t, got.InputTokens)
+	require.Positive(t, got.OutputTokens)
+}
+
+// TestFallbackStepUsageEstimatesToolCalls checks that a tool call in the step
+// content is estimated as output tokens when there are no prompt messages.
+func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
+	t.Parallel()
+
+	call := fantasy.ToolCallContent{
+		ToolCallID: "tool-call-1",
+		ToolName:   "view",
+		Input:      `{"file_path":"/tmp/example.go"}`,
+	}
+	result := fantasy.StepResult{
+		Response: fantasy.Response{Content: fantasy.ResponseContent{call}},
+	}
+
+	got, estimated := fallbackStepUsage(nil, result)
+
+	require.True(t, estimated)
+	require.Zero(t, got.InputTokens)
+	require.Positive(t, got.OutputTokens)
+	require.Equal(t, got.OutputTokens, got.TotalTokens)
+}
+
+// TestFallbackStepUsageEstimatesToolResults checks that text, error and media
+// tool results in the prompt are estimated as input tokens when the step
+// itself has no content.
+func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
+	t.Parallel()
+
+	textResult := fantasy.ToolResultPart{
+		ToolCallID: "tool-call-1",
+		Output: fantasy.ToolResultOutputContentText{
+			Text: "file contents returned by the tool",
+		},
+	}
+	errorResult := fantasy.ToolResultPart{
+		ToolCallID: "tool-call-2",
+		Output: fantasy.ToolResultOutputContentError{
+			Error: errors.New("permission denied"),
+		},
+	}
+	mediaResult := fantasy.ToolResultPart{
+		ToolCallID: "tool-call-3",
+		Output: fantasy.ToolResultOutputContentMedia{
+			MediaType: "image/png",
+			Text:      "screenshot",
+			Data:      "abc123",
+		},
+	}
+	history := []fantasy.Message{{
+		Role:    fantasy.MessageRoleTool,
+		Content: []fantasy.MessagePart{textResult, errorResult, mediaResult},
+	}}
+
+	got, estimated := fallbackStepUsage(history, fantasy.StepResult{})
+
+	require.True(t, estimated)
+	require.Positive(t, got.InputTokens)
+	require.Zero(t, got.OutputTokens)
+	require.Equal(t, got.InputTokens, got.TotalTokens)
+}
+
+// TestFallbackStepUsageReturnsZeroWithoutContent checks that with no messages
+// and an empty step the result is a zero Usage that is not flagged as
+// estimated.
+func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
+	t.Parallel()
+
+	var emptyStep fantasy.StepResult
+	got, estimated := fallbackStepUsage(nil, emptyStep)
+
+	require.False(t, estimated)
+	require.True(t, usageIsZero(got))
+}
+
+// TestUpdateSessionUsageSkipsEstimatedCost checks that estimated usage
+// updates the session token counters but leaves the session cost untouched.
+func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
+	t.Parallel()
+
+	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
+	sess := &session.Session{ID: "session-id", Cost: 1.25}
+	sa := &sessionAgent{}
+
+	sa.updateSessionUsage(
+		Model{CatwalkCfg: pricing},
+		sess,
+		fantasy.Usage{InputTokens: 1000, OutputTokens: 2000},
+		nil,
+		true,
+	)
+
+	require.Equal(t, 1.25, sess.Cost)
+	require.Equal(t, int64(1000), sess.PromptTokens)
+	require.Equal(t, int64(2000), sess.CompletionTokens)
+}
+
+// TestUpdateSessionUsageKeepsCountersForZeroUsage checks that an all-zero
+// usage changes neither the session cost nor its stored token counters.
+func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
+	t.Parallel()
+
+	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
+	sess := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     123,
+		CompletionTokens: 456,
+		Cost:             1.25,
+	}
+	sa := &sessionAgent{}
+
+	var noUsage fantasy.Usage
+	sa.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, noUsage, nil, false)
+
+	require.Equal(t, 1.25, sess.Cost)
+	require.Equal(t, int64(123), sess.PromptTokens)
+	require.Equal(t, int64(456), sess.CompletionTokens)
+}
+
+// TestUpdateSessionUsageAddsProviderCost checks that non-estimated usage is
+// priced with the model's rates and added to the existing session cost, and
+// that the session token counters are updated.
+func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
+
+	agent.updateSessionUsage(model, currentSession, usage, nil, false)
+
+	// 1000 input tokens at 10 per million plus 2000 output tokens at 20 per
+	// million is 0.05 on top of the starting 1.25. The cost is built from
+	// float64 arithmetic, so compare with a tolerance instead of exact
+	// equality.
+	require.InDelta(t, 1.3, currentSession.Cost, 1e-9)
+	require.Equal(t, int64(1000), currentSession.PromptTokens)
+	require.Equal(t, int64(2000), currentSession.CompletionTokens)
+}