1package agent
2
3import (
4 "errors"
5 "testing"
6
7 "charm.land/catwalk/pkg/catwalk"
8 "charm.land/fantasy"
9 "github.com/charmbracelet/crush/internal/message"
10 "github.com/charmbracelet/crush/internal/session"
11 "github.com/stretchr/testify/require"
12)
13
// TestUsageIsZero checks that usageIsZero reports true for an empty Usage and
// false whenever any single token counter is set.
func TestUsageIsZero(t *testing.T) {
	t.Parallel()

	require.True(t, usageIsZero(fantasy.Usage{}))

	nonZero := []struct {
		field string
		usage fantasy.Usage
	}{
		{"InputTokens", fantasy.Usage{InputTokens: 1}},
		{"OutputTokens", fantasy.Usage{OutputTokens: 1}},
		{"TotalTokens", fantasy.Usage{TotalTokens: 1}},
		{"ReasoningTokens", fantasy.Usage{ReasoningTokens: 1}},
		{"CacheCreationTokens", fantasy.Usage{CacheCreationTokens: 1}},
		{"CacheReadTokens", fantasy.Usage{CacheReadTokens: 1}},
	}
	for _, tc := range nonZero {
		require.False(t, usageIsZero(tc.usage), tc.field)
	}
}
25
// TestFallbackStepUsageKeepsProviderUsage verifies that usage already reported
// on the step response is returned unchanged and is not flagged as estimated.
func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
	t.Parallel()

	reported := fantasy.Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}

	got, isEstimate := fallbackStepUsage(nil, fantasy.StepResult{
		Response: fantasy.Response{Usage: reported},
	})

	require.False(t, isEstimate)
	require.Equal(t, reported, got)
}
42
43func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
44 t.Parallel()
45
46 messages := []fantasy.Message{
47 fantasy.NewUserMessage("please explain the implementation details"),
48 }
49 step := fantasy.StepResult{
50 Response: fantasy.Response{
51 Content: fantasy.ResponseContent{
52 fantasy.TextContent{Text: "the implementation stores state safely"},
53 },
54 },
55 }
56
57 usage, estimated := fallbackStepUsage(messages, step)
58 require.True(t, estimated)
59 require.Positive(t, usage.InputTokens)
60 require.Positive(t, usage.OutputTokens)
61 require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
62}
63
64func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
65 t.Parallel()
66
67 messages := []fantasy.Message{
68 {
69 Role: fantasy.MessageRoleAssistant,
70 Content: []fantasy.MessagePart{
71 fantasy.ReasoningPart{Text: "first reason about the request"},
72 },
73 },
74 }
75 step := fantasy.StepResult{
76 Response: fantasy.Response{
77 Content: fantasy.ResponseContent{
78 fantasy.ReasoningContent{Text: "second reason about the answer"},
79 },
80 },
81 }
82
83 usage, estimated := fallbackStepUsage(messages, step)
84 require.True(t, estimated)
85 require.Positive(t, usage.InputTokens)
86 require.Positive(t, usage.OutputTokens)
87}
88
89func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
90 t.Parallel()
91
92 step := fantasy.StepResult{
93 Response: fantasy.Response{
94 Content: fantasy.ResponseContent{
95 fantasy.ToolCallContent{
96 ToolCallID: "tool-call-1",
97 ToolName: "view",
98 Input: `{"file_path":"/tmp/example.go"}`,
99 },
100 },
101 },
102 }
103
104 usage, estimated := fallbackStepUsage(nil, step)
105 require.True(t, estimated)
106 require.Zero(t, usage.InputTokens)
107 require.Positive(t, usage.OutputTokens)
108 require.Equal(t, usage.OutputTokens, usage.TotalTokens)
109}
110
// TestFallbackStepUsageEstimatesToolResults verifies that tool results in the
// prompt history are counted as input. The results cover the text, error and
// media output variants. With an empty step, output stays at zero and the
// total equals the input count.
func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
	t.Parallel()

	textResult := fantasy.ToolResultPart{
		ToolCallID: "tool-call-1",
		Output:     fantasy.ToolResultOutputContentText{Text: "file contents returned by the tool"},
	}
	errorResult := fantasy.ToolResultPart{
		ToolCallID: "tool-call-2",
		Output:     fantasy.ToolResultOutputContentError{Error: errors.New("permission denied")},
	}
	mediaResult := fantasy.ToolResultPart{
		ToolCallID: "tool-call-3",
		Output: fantasy.ToolResultOutputContentMedia{
			MediaType: "image/png",
			Text:      "screenshot",
			Data:      "abc123",
		},
	}
	toolMessage := fantasy.Message{
		Role:    fantasy.MessageRoleTool,
		Content: []fantasy.MessagePart{textResult, errorResult, mediaResult},
	}

	got, isEstimate := fallbackStepUsage([]fantasy.Message{toolMessage}, fantasy.StepResult{})

	require.True(t, isEstimate)
	require.Positive(t, got.InputTokens)
	require.Zero(t, got.OutputTokens)
	require.Equal(t, got.InputTokens, got.TotalTokens)
}
148
149func TestFallbackStepUsageSkipsClientToolResultsAsOutput(t *testing.T) {
150 t.Parallel()
151
152 step := fantasy.StepResult{
153 Response: fantasy.Response{
154 Content: fantasy.ResponseContent{
155 fantasy.ToolResultContent{
156 ToolCallID: "tool-call-1",
157 ToolName: "bash",
158 Result: fantasy.ToolResultOutputContentText{
159 Text: "large client-executed payload that should not count as model output tokens",
160 },
161 },
162 },
163 },
164 }
165
166 usage, estimated := fallbackStepUsage(nil, step)
167 require.False(t, estimated)
168 require.Zero(t, usage.OutputTokens)
169}
170
171func TestFallbackStepUsageCountsProviderToolResultsAsOutput(t *testing.T) {
172 t.Parallel()
173
174 step := fantasy.StepResult{
175 Response: fantasy.Response{
176 Content: fantasy.ResponseContent{
177 fantasy.ToolResultContent{
178 ToolCallID: "tool-call-1",
179 ToolName: "web_search",
180 ProviderExecuted: true,
181 ClientMetadata: "provider metadata",
182 Result: fantasy.ToolResultOutputContentText{Text: "provider-executed result"},
183 },
184 },
185 },
186 }
187
188 usage, estimated := fallbackStepUsage(nil, step)
189 require.True(t, estimated)
190 require.Positive(t, usage.OutputTokens)
191 require.Equal(t, usage.OutputTokens, usage.TotalTokens)
192}
193
194func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
195 t.Parallel()
196
197 usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
198 require.False(t, estimated)
199 require.True(t, usageIsZero(usage))
200}
201
// TestUpdateSessionUsageSkipsEstimatedCost verifies that estimated usage
// updates the token counters and sets EstimatedUsage, but leaves the session
// cost unchanged.
func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
	t.Parallel()

	sa := &sessionAgent{}
	sess := &session.Session{ID: "session-id", Cost: 1.25}
	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
	estimatedUsage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}

	sa.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, estimatedUsage, nil, true)

	require.Equal(t, 1.25, sess.Cost)
	require.Equal(t, int64(1000), sess.PromptTokens)
	require.Equal(t, int64(2000), sess.CompletionTokens)
	require.True(t, sess.EstimatedUsage)
}
217
218func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
219 t.Parallel()
220
221 agent := &sessionAgent{}
222 currentSession := &session.Session{
223 ID: "session-id",
224 PromptTokens: 123,
225 CompletionTokens: 456,
226 Cost: 1.25,
227 }
228 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
229
230 agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
231
232 require.Equal(t, 1.25, currentSession.Cost)
233 require.Equal(t, int64(123), currentSession.PromptTokens)
234 require.Equal(t, int64(456), currentSession.CompletionTokens)
235}
236
// TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage verifies that
// usage reporting only input tokens overwrites PromptTokens and keeps the
// previous CompletionTokens.
func TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage(t *testing.T) {
	t.Parallel()

	sa := &sessionAgent{}
	sess := &session.Session{
		ID:               "session-id",
		PromptTokens:     123,
		CompletionTokens: 456,
	}
	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
	inputOnly := fantasy.Usage{InputTokens: 789}

	sa.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, inputOnly, nil, false)

	require.Equal(t, int64(789), sess.PromptTokens)
	require.Equal(t, int64(456), sess.CompletionTokens)
}
254
// TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage verifies that usage
// reporting only TotalTokens leaves both PromptTokens and CompletionTokens at
// their previous values.
func TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage(t *testing.T) {
	t.Parallel()

	sa := &sessionAgent{}
	sess := &session.Session{
		ID:               "session-id",
		PromptTokens:     123,
		CompletionTokens: 456,
	}
	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
	totalOnly := fantasy.Usage{TotalTokens: 100}

	sa.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, totalOnly, nil, false)

	require.Equal(t, int64(123), sess.PromptTokens)
	require.Equal(t, int64(456), sess.CompletionTokens)
}
272
// TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage verifies that usage
// reporting only output tokens overwrites CompletionTokens and keeps the
// previous PromptTokens.
func TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage(t *testing.T) {
	t.Parallel()

	sa := &sessionAgent{}
	sess := &session.Session{
		ID:               "session-id",
		PromptTokens:     123,
		CompletionTokens: 456,
	}
	pricing := catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}
	outputOnly := fantasy.Usage{OutputTokens: 50}

	sa.updateSessionUsage(Model{CatwalkCfg: pricing}, sess, outputOnly, nil, false)

	require.Equal(t, int64(123), sess.PromptTokens)
	require.Equal(t, int64(50), sess.CompletionTokens)
}
290
291func TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage(t *testing.T) {
292 t.Parallel()
293
294 agent := &sessionAgent{}
295 currentSession := &session.Session{
296 ID: "session-id",
297 PromptTokens: 123,
298 CompletionTokens: 456,
299 Cost: 1.25,
300 }
301 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
302
303 agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, true)
304
305 require.Equal(t, 1.25, currentSession.Cost)
306 require.Equal(t, int64(123), currentSession.PromptTokens)
307 require.Equal(t, int64(456), currentSession.CompletionTokens)
308}
309
310func TestSummaryCompletionTokens(t *testing.T) {
311 t.Parallel()
312
313 summaryMessage := message.Message{
314 Parts: []message.ContentPart{
315 message.TextContent{Text: "summary text"},
316 message.ReasoningContent{Thinking: "reasoning text"},
317 },
318 }
319
320 require.Equal(t, int64(42), summaryCompletionTokens(fantasy.Usage{OutputTokens: 42}, summaryMessage))
321 require.Equal(t, approxTokenCount("summary text")+approxTokenCount("reasoning text"), summaryCompletionTokens(fantasy.Usage{}, summaryMessage))
322 require.Zero(t, summaryCompletionTokens(fantasy.Usage{}, message.Message{}))
323}
324
// TestUpdateSessionUsageAddsProviderCost verifies that non-estimated usage:
//   - adds a cost derived from the model's CostPer1M* rates to the session's
//     existing cost;
//   - records the token counters;
//   - leaves EstimatedUsage false.
func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
	t.Parallel()

	agent := &sessionAgent{}
	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}

	agent.updateSessionUsage(model, currentSession, usage, nil, false)

	// Expected: 1.25 existing + 0.01 (1000 input tokens at 10 per 1M) +
	// 0.04 (2000 output tokens at 20 per 1M) = 1.30. This assumes per-million
	// pricing, as the CostPer1M* field names suggest. The value comes from
	// floating-point arithmetic, so compare with a tolerance: exact equality
	// would break on any harmless reordering of the cost computation.
	require.InDelta(t, 1.3, currentSession.Cost, 1e-9)
	require.Equal(t, int64(1000), currentSession.PromptTokens)
	require.Equal(t, int64(2000), currentSession.CompletionTokens)
	require.False(t, currentSession.EstimatedUsage)
}