usage_fallback_test.go

  1package agent
  2
  3import (
  4	"errors"
  5	"testing"
  6
  7	"charm.land/catwalk/pkg/catwalk"
  8	"charm.land/fantasy"
  9	"github.com/charmbracelet/crush/internal/session"
 10	"github.com/stretchr/testify/require"
 11)
 12
 13func TestUsageIsZero(t *testing.T) {
 14	t.Parallel()
 15
 16	require.True(t, usageIsZero(fantasy.Usage{}))
 17	require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
 18	require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
 19	require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
 20	require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
 21	require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
 22	require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
 23}
 24
 25func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
 26	t.Parallel()
 27
 28	usage := fantasy.Usage{
 29		InputTokens:  10,
 30		OutputTokens: 5,
 31		TotalTokens:  15,
 32	}
 33	step := fantasy.StepResult{
 34		Response: fantasy.Response{Usage: usage},
 35	}
 36
 37	fallbackUsage, estimated := fallbackStepUsage(nil, step)
 38	require.False(t, estimated)
 39	require.Equal(t, usage, fallbackUsage)
 40}
 41
 42func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
 43	t.Parallel()
 44
 45	messages := []fantasy.Message{
 46		fantasy.NewUserMessage("please explain the implementation details"),
 47	}
 48	step := fantasy.StepResult{
 49		Response: fantasy.Response{
 50			Content: fantasy.ResponseContent{
 51				fantasy.TextContent{Text: "the implementation stores state safely"},
 52			},
 53		},
 54	}
 55
 56	usage, estimated := fallbackStepUsage(messages, step)
 57	require.True(t, estimated)
 58	require.Positive(t, usage.InputTokens)
 59	require.Positive(t, usage.OutputTokens)
 60	require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
 61}
 62
 63func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
 64	t.Parallel()
 65
 66	messages := []fantasy.Message{
 67		{
 68			Role: fantasy.MessageRoleAssistant,
 69			Content: []fantasy.MessagePart{
 70				fantasy.ReasoningPart{Text: "first reason about the request"},
 71			},
 72		},
 73	}
 74	step := fantasy.StepResult{
 75		Response: fantasy.Response{
 76			Content: fantasy.ResponseContent{
 77				fantasy.ReasoningContent{Text: "second reason about the answer"},
 78			},
 79		},
 80	}
 81
 82	usage, estimated := fallbackStepUsage(messages, step)
 83	require.True(t, estimated)
 84	require.Positive(t, usage.InputTokens)
 85	require.Positive(t, usage.OutputTokens)
 86}
 87
 88func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
 89	t.Parallel()
 90
 91	step := fantasy.StepResult{
 92		Response: fantasy.Response{
 93			Content: fantasy.ResponseContent{
 94				fantasy.ToolCallContent{
 95					ToolCallID: "tool-call-1",
 96					ToolName:   "view",
 97					Input:      `{"file_path":"/tmp/example.go"}`,
 98				},
 99			},
100		},
101	}
102
103	usage, estimated := fallbackStepUsage(nil, step)
104	require.True(t, estimated)
105	require.Zero(t, usage.InputTokens)
106	require.Positive(t, usage.OutputTokens)
107	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
108}
109
110func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
111	t.Parallel()
112
113	messages := []fantasy.Message{
114		{
115			Role: fantasy.MessageRoleTool,
116			Content: []fantasy.MessagePart{
117				fantasy.ToolResultPart{
118					ToolCallID: "tool-call-1",
119					Output: fantasy.ToolResultOutputContentText{
120						Text: "file contents returned by the tool",
121					},
122				},
123				fantasy.ToolResultPart{
124					ToolCallID: "tool-call-2",
125					Output: fantasy.ToolResultOutputContentError{
126						Error: errors.New("permission denied"),
127					},
128				},
129				fantasy.ToolResultPart{
130					ToolCallID: "tool-call-3",
131					Output: fantasy.ToolResultOutputContentMedia{
132						MediaType: "image/png",
133						Text:      "screenshot",
134						Data:      "abc123",
135					},
136				},
137			},
138		},
139	}
140
141	usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
142	require.True(t, estimated)
143	require.Positive(t, usage.InputTokens)
144	require.Zero(t, usage.OutputTokens)
145	require.Equal(t, usage.InputTokens, usage.TotalTokens)
146}
147
148func TestFallbackStepUsageSkipsClientToolResultsAsOutput(t *testing.T) {
149	t.Parallel()
150
151	step := fantasy.StepResult{
152		Response: fantasy.Response{
153			Content: fantasy.ResponseContent{
154				fantasy.ToolResultContent{
155					ToolCallID: "tool-call-1",
156					ToolName:   "bash",
157					Result: fantasy.ToolResultOutputContentText{
158						Text: "large client-executed payload that should not count as model output tokens",
159					},
160				},
161			},
162		},
163	}
164
165	usage, estimated := fallbackStepUsage(nil, step)
166	require.False(t, estimated)
167	require.Zero(t, usage.OutputTokens)
168}
169
170func TestFallbackStepUsageCountsProviderToolResultsAsOutput(t *testing.T) {
171	t.Parallel()
172
173	step := fantasy.StepResult{
174		Response: fantasy.Response{
175			Content: fantasy.ResponseContent{
176				fantasy.ToolResultContent{
177					ToolCallID:       "tool-call-1",
178					ToolName:         "web_search",
179					ProviderExecuted: true,
180					ClientMetadata:   "provider metadata",
181					Result:           fantasy.ToolResultOutputContentText{Text: "provider-executed result"},
182				},
183			},
184		},
185	}
186
187	usage, estimated := fallbackStepUsage(nil, step)
188	require.True(t, estimated)
189	require.Positive(t, usage.OutputTokens)
190	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
191}
192
193func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
194	t.Parallel()
195
196	usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
197	require.False(t, estimated)
198	require.True(t, usageIsZero(usage))
199}
200
201func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
202	t.Parallel()
203
204	agent := &sessionAgent{}
205	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
206	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
207	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
208
209	agent.updateSessionUsage(model, currentSession, usage, nil, true)
210
211	require.Equal(t, 1.25, currentSession.Cost)
212	require.Equal(t, int64(1000), currentSession.PromptTokens)
213	require.Equal(t, int64(2000), currentSession.CompletionTokens)
214}
215
216func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
217	t.Parallel()
218
219	agent := &sessionAgent{}
220	currentSession := &session.Session{
221		ID:               "session-id",
222		PromptTokens:     123,
223		CompletionTokens: 456,
224		Cost:             1.25,
225	}
226	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
227
228	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
229
230	require.Equal(t, 1.25, currentSession.Cost)
231	require.Equal(t, int64(123), currentSession.PromptTokens)
232	require.Equal(t, int64(456), currentSession.CompletionTokens)
233}
234
235func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) {
236	t.Parallel()
237
238	agent := &sessionAgent{}
239	currentSession := &session.Session{
240		ID:               "session-id",
241		PromptTokens:     123,
242		CompletionTokens: 456,
243	}
244	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
245	usage := fantasy.Usage{InputTokens: 789}
246
247	agent.updateSessionUsage(model, currentSession, usage, nil, false)
248
249	require.Equal(t, int64(789), currentSession.PromptTokens)
250	require.Zero(t, currentSession.CompletionTokens)
251}
252
253func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
254	t.Parallel()
255
256	agent := &sessionAgent{}
257	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
258	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
259	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
260
261	agent.updateSessionUsage(model, currentSession, usage, nil, false)
262
263	require.Equal(t, 1.3, currentSession.Cost)
264	require.Equal(t, int64(1000), currentSession.PromptTokens)
265	require.Equal(t, int64(2000), currentSession.CompletionTokens)
266}