fix(agent): correct fallback usage accounting

Greg Slepak created

Change summary

internal/agent/agent.go               |  6 -
internal/agent/usage_fallback.go      |  8 ++
internal/agent/usage_fallback_test.go | 63 +++++++++++++++++++++++++++++
3 files changed, 71 insertions(+), 6 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -1166,11 +1166,9 @@ func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session,
 	}
 
 	session.Cost += cost
-	if usage.OutputTokens != 0 {
+	if !usageIsZero(usage) {
 		session.CompletionTokens = usage.OutputTokens
-	}
-	if promptTokens := usage.InputTokens + usage.CacheReadTokens; promptTokens != 0 {
-		session.PromptTokens = promptTokens
+		session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
 	}
 }
 

internal/agent/usage_fallback.go 🔗

@@ -78,9 +78,13 @@ func estimateStepCompletionTokens(step fantasy.StepResult) int64 {
 		case *fantasy.ToolCallContent:
 			tokens += estimateToolCallTokens(c.ToolName, c.Input)
 		case fantasy.ToolResultContent:
-			tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result)
+			if c.ProviderExecuted {
+				tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result)
+			}
 		case *fantasy.ToolResultContent:
-			tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result)
+			if c.ProviderExecuted {
+				tokens += estimateToolResultContentTokens(c.ToolCallID, c.ToolName, c.ClientMetadata, c.Result)
+			}
 		}
 	}
 	return tokens

internal/agent/usage_fallback_test.go 🔗

@@ -145,6 +145,51 @@ func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
 	require.Equal(t, usage.InputTokens, usage.TotalTokens)
 }
 
+func TestFallbackStepUsageSkipsClientToolResultsAsOutput(t *testing.T) {
+	t.Parallel()
+
+	step := fantasy.StepResult{
+		Response: fantasy.Response{
+			Content: fantasy.ResponseContent{
+				fantasy.ToolResultContent{
+					ToolCallID: "tool-call-1",
+					ToolName:   "bash",
+					Result: fantasy.ToolResultOutputContentText{
+						Text: "large client-executed payload that should not count as model output tokens",
+					},
+				},
+			},
+		},
+	}
+
+	usage, estimated := fallbackStepUsage(nil, step)
+	require.False(t, estimated)
+	require.Zero(t, usage.OutputTokens)
+}
+
+func TestFallbackStepUsageCountsProviderToolResultsAsOutput(t *testing.T) {
+	t.Parallel()
+
+	step := fantasy.StepResult{
+		Response: fantasy.Response{
+			Content: fantasy.ResponseContent{
+				fantasy.ToolResultContent{
+					ToolCallID:       "tool-call-1",
+					ToolName:         "web_search",
+					ProviderExecuted: true,
+					ClientMetadata:   "provider metadata",
+					Result:           fantasy.ToolResultOutputContentText{Text: "provider-executed result"},
+				},
+			},
+		},
+	}
+
+	usage, estimated := fallbackStepUsage(nil, step)
+	require.True(t, estimated)
+	require.Positive(t, usage.OutputTokens)
+	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
+}
+
 func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
 	t.Parallel()
 
@@ -187,6 +232,24 @@ func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
 	require.Equal(t, int64(456), currentSession.CompletionTokens)
 }
 
+func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) {
+	t.Parallel()
+
+	agent := &sessionAgent{}
+	currentSession := &session.Session{
+		ID:               "session-id",
+		PromptTokens:     123,
+		CompletionTokens: 456,
+	}
+	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
+	usage := fantasy.Usage{InputTokens: 789}
+
+	agent.updateSessionUsage(model, currentSession, usage, nil, false)
+
+	require.Equal(t, int64(789), currentSession.PromptTokens)
+	require.Zero(t, currentSession.CompletionTokens)
+}
+
 func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
 	t.Parallel()