1package agent
2
3import (
4 "errors"
5 "testing"
6
7 "charm.land/catwalk/pkg/catwalk"
8 "charm.land/fantasy"
9 "github.com/charmbracelet/crush/internal/message"
10 "github.com/charmbracelet/crush/internal/session"
11 "github.com/stretchr/testify/require"
12)
13
14func TestUsageIsZero(t *testing.T) {
15 t.Parallel()
16
17 require.True(t, usageIsZero(fantasy.Usage{}))
18 require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
19 require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
20 require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
21 require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
22 require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
23 require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
24}
25
26func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
27 t.Parallel()
28
29 usage := fantasy.Usage{
30 InputTokens: 10,
31 OutputTokens: 5,
32 TotalTokens: 15,
33 }
34 step := fantasy.StepResult{
35 Response: fantasy.Response{Usage: usage},
36 }
37
38 fallbackUsage, estimated := fallbackStepUsage(nil, step)
39 require.False(t, estimated)
40 require.Equal(t, usage, fallbackUsage)
41}
42
43func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
44 t.Parallel()
45
46 messages := []fantasy.Message{
47 fantasy.NewUserMessage("please explain the implementation details"),
48 }
49 step := fantasy.StepResult{
50 Response: fantasy.Response{
51 Content: fantasy.ResponseContent{
52 fantasy.TextContent{Text: "the implementation stores state safely"},
53 },
54 },
55 }
56
57 usage, estimated := fallbackStepUsage(messages, step)
58 require.True(t, estimated)
59 require.Positive(t, usage.InputTokens)
60 require.Positive(t, usage.OutputTokens)
61 require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
62}
63
64func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
65 t.Parallel()
66
67 messages := []fantasy.Message{
68 {
69 Role: fantasy.MessageRoleAssistant,
70 Content: []fantasy.MessagePart{
71 fantasy.ReasoningPart{Text: "first reason about the request"},
72 },
73 },
74 }
75 step := fantasy.StepResult{
76 Response: fantasy.Response{
77 Content: fantasy.ResponseContent{
78 fantasy.ReasoningContent{Text: "second reason about the answer"},
79 },
80 },
81 }
82
83 usage, estimated := fallbackStepUsage(messages, step)
84 require.True(t, estimated)
85 require.Positive(t, usage.InputTokens)
86 require.Positive(t, usage.OutputTokens)
87}
88
89func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
90 t.Parallel()
91
92 step := fantasy.StepResult{
93 Response: fantasy.Response{
94 Content: fantasy.ResponseContent{
95 fantasy.ToolCallContent{
96 ToolCallID: "tool-call-1",
97 ToolName: "view",
98 Input: `{"file_path":"/tmp/example.go"}`,
99 },
100 },
101 },
102 }
103
104 usage, estimated := fallbackStepUsage(nil, step)
105 require.True(t, estimated)
106 require.Zero(t, usage.InputTokens)
107 require.Positive(t, usage.OutputTokens)
108 require.Equal(t, usage.OutputTokens, usage.TotalTokens)
109}
110
111func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
112 t.Parallel()
113
114 messages := []fantasy.Message{
115 {
116 Role: fantasy.MessageRoleTool,
117 Content: []fantasy.MessagePart{
118 fantasy.ToolResultPart{
119 ToolCallID: "tool-call-1",
120 Output: fantasy.ToolResultOutputContentText{
121 Text: "file contents returned by the tool",
122 },
123 },
124 fantasy.ToolResultPart{
125 ToolCallID: "tool-call-2",
126 Output: fantasy.ToolResultOutputContentError{
127 Error: errors.New("permission denied"),
128 },
129 },
130 fantasy.ToolResultPart{
131 ToolCallID: "tool-call-3",
132 Output: fantasy.ToolResultOutputContentMedia{
133 MediaType: "image/png",
134 Text: "screenshot",
135 Data: "abc123",
136 },
137 },
138 },
139 },
140 }
141
142 usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
143 require.True(t, estimated)
144 require.Positive(t, usage.InputTokens)
145 require.Zero(t, usage.OutputTokens)
146 require.Equal(t, usage.InputTokens, usage.TotalTokens)
147}
148
149func TestFallbackStepUsageSkipsClientToolResultsAsOutput(t *testing.T) {
150 t.Parallel()
151
152 step := fantasy.StepResult{
153 Response: fantasy.Response{
154 Content: fantasy.ResponseContent{
155 fantasy.ToolResultContent{
156 ToolCallID: "tool-call-1",
157 ToolName: "bash",
158 Result: fantasy.ToolResultOutputContentText{
159 Text: "large client-executed payload that should not count as model output tokens",
160 },
161 },
162 },
163 },
164 }
165
166 usage, estimated := fallbackStepUsage(nil, step)
167 require.False(t, estimated)
168 require.Zero(t, usage.OutputTokens)
169}
170
171func TestFallbackStepUsageCountsProviderToolResultsAsOutput(t *testing.T) {
172 t.Parallel()
173
174 step := fantasy.StepResult{
175 Response: fantasy.Response{
176 Content: fantasy.ResponseContent{
177 fantasy.ToolResultContent{
178 ToolCallID: "tool-call-1",
179 ToolName: "web_search",
180 ProviderExecuted: true,
181 ClientMetadata: "provider metadata",
182 Result: fantasy.ToolResultOutputContentText{Text: "provider-executed result"},
183 },
184 },
185 },
186 }
187
188 usage, estimated := fallbackStepUsage(nil, step)
189 require.True(t, estimated)
190 require.Positive(t, usage.OutputTokens)
191 require.Equal(t, usage.OutputTokens, usage.TotalTokens)
192}
193
194func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
195 t.Parallel()
196
197 usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
198 require.False(t, estimated)
199 require.True(t, usageIsZero(usage))
200}
201
202func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
203 t.Parallel()
204
205 agent := &sessionAgent{}
206 currentSession := &session.Session{ID: "session-id", Cost: 1.25}
207 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
208 usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
209
210 agent.updateSessionUsage(model, currentSession, usage, nil, true)
211
212 require.Equal(t, 1.25, currentSession.Cost)
213 require.Equal(t, int64(1000), currentSession.PromptTokens)
214 require.Equal(t, int64(2000), currentSession.CompletionTokens)
215}
216
217func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
218 t.Parallel()
219
220 agent := &sessionAgent{}
221 currentSession := &session.Session{
222 ID: "session-id",
223 PromptTokens: 123,
224 CompletionTokens: 456,
225 Cost: 1.25,
226 }
227 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
228
229 agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
230
231 require.Equal(t, 1.25, currentSession.Cost)
232 require.Equal(t, int64(123), currentSession.PromptTokens)
233 require.Equal(t, int64(456), currentSession.CompletionTokens)
234}
235
236func TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage(t *testing.T) {
237 t.Parallel()
238
239 agent := &sessionAgent{}
240 currentSession := &session.Session{
241 ID: "session-id",
242 PromptTokens: 123,
243 CompletionTokens: 456,
244 }
245 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
246 usage := fantasy.Usage{InputTokens: 789}
247
248 agent.updateSessionUsage(model, currentSession, usage, nil, false)
249
250 require.Equal(t, int64(789), currentSession.PromptTokens)
251 require.Equal(t, int64(456), currentSession.CompletionTokens)
252}
253
254func TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage(t *testing.T) {
255 t.Parallel()
256
257 agent := &sessionAgent{}
258 currentSession := &session.Session{
259 ID: "session-id",
260 PromptTokens: 123,
261 CompletionTokens: 456,
262 }
263 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
264 usage := fantasy.Usage{TotalTokens: 100}
265
266 agent.updateSessionUsage(model, currentSession, usage, nil, false)
267
268 require.Equal(t, int64(123), currentSession.PromptTokens)
269 require.Equal(t, int64(456), currentSession.CompletionTokens)
270}
271
272func TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage(t *testing.T) {
273 t.Parallel()
274
275 agent := &sessionAgent{}
276 currentSession := &session.Session{
277 ID: "session-id",
278 PromptTokens: 123,
279 CompletionTokens: 456,
280 }
281 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
282 usage := fantasy.Usage{OutputTokens: 50}
283
284 agent.updateSessionUsage(model, currentSession, usage, nil, false)
285
286 require.Equal(t, int64(123), currentSession.PromptTokens)
287 require.Equal(t, int64(50), currentSession.CompletionTokens)
288}
289
290func TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage(t *testing.T) {
291 t.Parallel()
292
293 agent := &sessionAgent{}
294 currentSession := &session.Session{
295 ID: "session-id",
296 PromptTokens: 123,
297 CompletionTokens: 456,
298 Cost: 1.25,
299 }
300 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
301
302 agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, true)
303
304 require.Equal(t, 1.25, currentSession.Cost)
305 require.Equal(t, int64(123), currentSession.PromptTokens)
306 require.Equal(t, int64(456), currentSession.CompletionTokens)
307}
308
309func TestSummaryCompletionTokens(t *testing.T) {
310 t.Parallel()
311
312 summaryMessage := message.Message{
313 Parts: []message.ContentPart{
314 message.TextContent{Text: "summary text"},
315 message.ReasoningContent{Thinking: "reasoning text"},
316 },
317 }
318
319 require.Equal(t, int64(42), summaryCompletionTokens(fantasy.Usage{OutputTokens: 42}, summaryMessage))
320 require.Equal(t, approxTokenCount("summary text")+approxTokenCount("reasoning text"), summaryCompletionTokens(fantasy.Usage{}, summaryMessage))
321 require.Zero(t, summaryCompletionTokens(fantasy.Usage{}, message.Message{}))
322}
323
324func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
325 t.Parallel()
326
327 agent := &sessionAgent{}
328 currentSession := &session.Session{ID: "session-id", Cost: 1.25}
329 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
330 usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
331
332 agent.updateSessionUsage(model, currentSession, usage, nil, false)
333
334 require.Equal(t, 1.3, currentSession.Cost)
335 require.Equal(t, int64(1000), currentSession.PromptTokens)
336 require.Equal(t, int64(2000), currentSession.CompletionTokens)
337}