1package agent
2
3import (
4 "errors"
5 "testing"
6
7 "charm.land/catwalk/pkg/catwalk"
8 "charm.land/fantasy"
9 "github.com/charmbracelet/crush/internal/session"
10 "github.com/stretchr/testify/require"
11)
12
13func TestUsageIsZero(t *testing.T) {
14 t.Parallel()
15
16 require.True(t, usageIsZero(fantasy.Usage{}))
17 require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
18 require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
19 require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
20 require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
21 require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
22 require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
23}
24
25func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
26 t.Parallel()
27
28 usage := fantasy.Usage{
29 InputTokens: 10,
30 OutputTokens: 5,
31 TotalTokens: 15,
32 }
33 step := fantasy.StepResult{
34 Response: fantasy.Response{Usage: usage},
35 }
36
37 fallbackUsage, estimated := fallbackStepUsage(nil, step)
38 require.False(t, estimated)
39 require.Equal(t, usage, fallbackUsage)
40}
41
42func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
43 t.Parallel()
44
45 messages := []fantasy.Message{
46 fantasy.NewUserMessage("please explain the implementation details"),
47 }
48 step := fantasy.StepResult{
49 Response: fantasy.Response{
50 Content: fantasy.ResponseContent{
51 fantasy.TextContent{Text: "the implementation stores state safely"},
52 },
53 },
54 }
55
56 usage, estimated := fallbackStepUsage(messages, step)
57 require.True(t, estimated)
58 require.Positive(t, usage.InputTokens)
59 require.Positive(t, usage.OutputTokens)
60 require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
61}
62
63func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
64 t.Parallel()
65
66 messages := []fantasy.Message{
67 {
68 Role: fantasy.MessageRoleAssistant,
69 Content: []fantasy.MessagePart{
70 fantasy.ReasoningPart{Text: "first reason about the request"},
71 },
72 },
73 }
74 step := fantasy.StepResult{
75 Response: fantasy.Response{
76 Content: fantasy.ResponseContent{
77 fantasy.ReasoningContent{Text: "second reason about the answer"},
78 },
79 },
80 }
81
82 usage, estimated := fallbackStepUsage(messages, step)
83 require.True(t, estimated)
84 require.Positive(t, usage.InputTokens)
85 require.Positive(t, usage.OutputTokens)
86}
87
88func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
89 t.Parallel()
90
91 step := fantasy.StepResult{
92 Response: fantasy.Response{
93 Content: fantasy.ResponseContent{
94 fantasy.ToolCallContent{
95 ToolCallID: "tool-call-1",
96 ToolName: "view",
97 Input: `{"file_path":"/tmp/example.go"}`,
98 },
99 },
100 },
101 }
102
103 usage, estimated := fallbackStepUsage(nil, step)
104 require.True(t, estimated)
105 require.Zero(t, usage.InputTokens)
106 require.Positive(t, usage.OutputTokens)
107 require.Equal(t, usage.OutputTokens, usage.TotalTokens)
108}
109
110func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
111 t.Parallel()
112
113 messages := []fantasy.Message{
114 {
115 Role: fantasy.MessageRoleTool,
116 Content: []fantasy.MessagePart{
117 fantasy.ToolResultPart{
118 ToolCallID: "tool-call-1",
119 Output: fantasy.ToolResultOutputContentText{
120 Text: "file contents returned by the tool",
121 },
122 },
123 fantasy.ToolResultPart{
124 ToolCallID: "tool-call-2",
125 Output: fantasy.ToolResultOutputContentError{
126 Error: errors.New("permission denied"),
127 },
128 },
129 fantasy.ToolResultPart{
130 ToolCallID: "tool-call-3",
131 Output: fantasy.ToolResultOutputContentMedia{
132 MediaType: "image/png",
133 Text: "screenshot",
134 Data: "abc123",
135 },
136 },
137 },
138 },
139 }
140
141 usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
142 require.True(t, estimated)
143 require.Positive(t, usage.InputTokens)
144 require.Zero(t, usage.OutputTokens)
145 require.Equal(t, usage.InputTokens, usage.TotalTokens)
146}
147
148func TestFallbackStepUsageSkipsClientToolResultsAsOutput(t *testing.T) {
149 t.Parallel()
150
151 step := fantasy.StepResult{
152 Response: fantasy.Response{
153 Content: fantasy.ResponseContent{
154 fantasy.ToolResultContent{
155 ToolCallID: "tool-call-1",
156 ToolName: "bash",
157 Result: fantasy.ToolResultOutputContentText{
158 Text: "large client-executed payload that should not count as model output tokens",
159 },
160 },
161 },
162 },
163 }
164
165 usage, estimated := fallbackStepUsage(nil, step)
166 require.False(t, estimated)
167 require.Zero(t, usage.OutputTokens)
168}
169
170func TestFallbackStepUsageCountsProviderToolResultsAsOutput(t *testing.T) {
171 t.Parallel()
172
173 step := fantasy.StepResult{
174 Response: fantasy.Response{
175 Content: fantasy.ResponseContent{
176 fantasy.ToolResultContent{
177 ToolCallID: "tool-call-1",
178 ToolName: "web_search",
179 ProviderExecuted: true,
180 ClientMetadata: "provider metadata",
181 Result: fantasy.ToolResultOutputContentText{Text: "provider-executed result"},
182 },
183 },
184 },
185 }
186
187 usage, estimated := fallbackStepUsage(nil, step)
188 require.True(t, estimated)
189 require.Positive(t, usage.OutputTokens)
190 require.Equal(t, usage.OutputTokens, usage.TotalTokens)
191}
192
193func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
194 t.Parallel()
195
196 usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
197 require.False(t, estimated)
198 require.True(t, usageIsZero(usage))
199}
200
201func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
202 t.Parallel()
203
204 agent := &sessionAgent{}
205 currentSession := &session.Session{ID: "session-id", Cost: 1.25}
206 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
207 usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
208
209 agent.updateSessionUsage(model, currentSession, usage, nil, true)
210
211 require.Equal(t, 1.25, currentSession.Cost)
212 require.Equal(t, int64(1000), currentSession.PromptTokens)
213 require.Equal(t, int64(2000), currentSession.CompletionTokens)
214}
215
216func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
217 t.Parallel()
218
219 agent := &sessionAgent{}
220 currentSession := &session.Session{
221 ID: "session-id",
222 PromptTokens: 123,
223 CompletionTokens: 456,
224 Cost: 1.25,
225 }
226 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
227
228 agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
229
230 require.Equal(t, 1.25, currentSession.Cost)
231 require.Equal(t, int64(123), currentSession.PromptTokens)
232 require.Equal(t, int64(456), currentSession.CompletionTokens)
233}
234
235func TestUpdateSessionUsageReplacesCountersForPartialUsage(t *testing.T) {
236 t.Parallel()
237
238 agent := &sessionAgent{}
239 currentSession := &session.Session{
240 ID: "session-id",
241 PromptTokens: 123,
242 CompletionTokens: 456,
243 }
244 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
245 usage := fantasy.Usage{InputTokens: 789}
246
247 agent.updateSessionUsage(model, currentSession, usage, nil, false)
248
249 require.Equal(t, int64(789), currentSession.PromptTokens)
250 require.Zero(t, currentSession.CompletionTokens)
251}
252
253func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
254 t.Parallel()
255
256 agent := &sessionAgent{}
257 currentSession := &session.Session{ID: "session-id", Cost: 1.25}
258 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
259 usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
260
261 agent.updateSessionUsage(model, currentSession, usage, nil, false)
262
263 require.Equal(t, 1.3, currentSession.Cost)
264 require.Equal(t, int64(1000), currentSession.PromptTokens)
265 require.Equal(t, int64(2000), currentSession.CompletionTokens)
266}