fix(openai): subtract cached tokens from input tokens to avoid double counting (#176)

OpenAI's API reports prompt_tokens (Chat Completions) and input_tokens
(Responses API) INCLUDING cached tokens, while also reporting
cached_tokens separately in prompt_tokens_details / input_tokens_details.
This caused double-counting when users summed InputTokens +
CacheReadTokens.

For example, if OpenAI reports:
  - prompt_tokens: 1000 (includes 900 cached)
  - cached_tokens: 900

Before this fix, fantasy reported:
  - InputTokens: 1000
  - CacheReadTokens: 900

After this fix, fantasy reports:
  - InputTokens: 100 (non-cached only)
  - CacheReadTokens: 900

This matches the behavior of the Vercel AI SDK and prevents billing
miscalculations when input tokens and cache-read tokens are priced
separately.
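
For illustration, a minimal, self-contained sketch of the billing math
this fix corrects. The per-token prices are invented for the example and
are not real OpenAI rates:

    package main

    import "fmt"

    func main() {
        // Hypothetical prices, in dollars per token.
        const priceInput, priceCacheRead = 2.50 / 1e6, 1.25 / 1e6
        inputTokens, cacheReadTokens := int64(100), int64(900) // values after this fix
        cost := float64(inputTokens)*priceInput + float64(cacheReadTokens)*priceCacheRead
        fmt.Printf("$%.6f\n", cost) // $0.001375
        // Before this fix inputTokens would have been 1000, charging the
        // 900 cached tokens at both rates.
    }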

See: https://platform.openai.com/docs/guides/prompt-caching#requirements

💘 Generated with Crush

Assisted-by: Kimi K2.5 via Crush <crush@charm.land>

Change summary

providers/openai/language_model_hooks.go     | 8 ++++++--
providers/openai/openai_test.go              | 6 ++++--
providers/openai/responses_language_model.go | 4 +++-
3 files changed, 13 insertions(+), 5 deletions(-)

Detailed changes

providers/openai/language_model_hooks.go

@@ -211,8 +211,10 @@ func DefaultUsageFunc(response openai.ChatCompletion) (fantasy.Usage, fantasy.Pr
 			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 		}
 	}
+	// OpenAI reports prompt_tokens INCLUDING cached tokens. Subtract to avoid double-counting.
+	inputTokens := max(response.Usage.PromptTokens-promptTokenDetails.CachedTokens, 0)
 	return fantasy.Usage{
-		InputTokens:     response.Usage.PromptTokens,
+		InputTokens:     inputTokens,
 		OutputTokens:    response.Usage.CompletionTokens,
 		TotalTokens:     response.Usage.TotalTokens,
 		ReasoningTokens: completionTokenDetails.ReasoningTokens,
@@ -237,8 +239,10 @@ func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, _ map[string]any,
 	// we do this here because the acc does not add prompt details
 	completionTokenDetails := chunk.Usage.CompletionTokensDetails
 	promptTokenDetails := chunk.Usage.PromptTokensDetails
+	// OpenAI reports prompt_tokens INCLUDING cached tokens. Subtract to avoid double-counting.
+	inputTokens := max(chunk.Usage.PromptTokens-promptTokenDetails.CachedTokens, 0)
 	usage := fantasy.Usage{
-		InputTokens:     chunk.Usage.PromptTokens,
+		InputTokens:     inputTokens,
 		OutputTokens:    chunk.Usage.CompletionTokens,
 		TotalTokens:     chunk.Usage.TotalTokens,
 		ReasoningTokens: completionTokenDetails.ReasoningTokens,
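
The subtraction is clamped with Go's built-in max (available since Go
1.21), so a payload where cached_tokens exceeds prompt_tokens, as in the
test fixtures below, can never drive InputTokens negative. A standalone
sketch with this commit's numbers:

    package main

    import "fmt"

    func main() {
        fmt.Println(max(int64(1000)-int64(900), 0)) // 100 non-cached tokens
        fmt.Println(max(int64(15)-int64(1152), 0))  // inconsistent payload clamps to 0
    }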

providers/openai/openai_test.go

@@ -1425,7 +1425,8 @@ func TestDoGenerate(t *testing.T) {
 
 		require.NoError(t, err)
 		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
-		require.Equal(t, int64(15), result.Usage.InputTokens)
+		// InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
+		require.Equal(t, int64(0), result.Usage.InputTokens)
 		require.Equal(t, int64(20), result.Usage.OutputTokens)
 		require.Equal(t, int64(35), result.Usage.TotalTokens)
 	})
@@ -2594,7 +2595,8 @@ func TestDoStream(t *testing.T) {
 
 		require.NotNil(t, finishPart)
 		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
-		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
+		// InputTokens = prompt_tokens - cached_tokens = 15 - 1152 = -1137 → clamped to 0
+		require.Equal(t, int64(0), finishPart.Usage.InputTokens)
 		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
 		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
 	})

providers/openai/responses_language_model.go

@@ -375,8 +375,10 @@ func responsesProviderMetadata(responseID string) fantasy.ProviderMetadata {
 }
 
 func responsesUsage(resp responses.Response) fantasy.Usage {
+	// OpenAI reports input_tokens INCLUDING cached tokens. Subtract to avoid double-counting.
+	inputTokens := max(resp.Usage.InputTokens-resp.Usage.InputTokensDetails.CachedTokens, 0)
 	usage := fantasy.Usage{
-		InputTokens:  resp.Usage.InputTokens,
+		InputTokens:  inputTokens,
 		OutputTokens: resp.Usage.OutputTokens,
 		TotalTokens:  resp.Usage.InputTokens + resp.Usage.OutputTokens,
 	}
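
Note that TotalTokens here is still the raw input_tokens + output_tokens,
so for a consistent payload (cached tokens being a subset of input
tokens) the reported fields now reconcile. A sketch of that invariant,
with made-up numbers:

    package main

    import "fmt"

    func main() {
        rawInput, cached, output := int64(1000), int64(900), int64(50)
        inputTokens := max(rawInput-cached, 0) // as computed in responsesUsage
        fmt.Println(inputTokens+cached+output == rawInput+output) // true
    }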