usage_fallback_test.go

  1package agent
  2
  3import (
  4	"errors"
  5	"testing"
  6
  7	"charm.land/catwalk/pkg/catwalk"
  8	"charm.land/fantasy"
  9	"github.com/charmbracelet/crush/internal/message"
 10	"github.com/charmbracelet/crush/internal/session"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14func TestUsageIsZero(t *testing.T) {
 15	t.Parallel()
 16
 17	require.True(t, usageIsZero(fantasy.Usage{}))
 18	require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
 19	require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
 20	require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
 21	require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
 22	require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
 23	require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
 24}
 25
 26func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
 27	t.Parallel()
 28
 29	usage := fantasy.Usage{
 30		InputTokens:  10,
 31		OutputTokens: 5,
 32		TotalTokens:  15,
 33	}
 34	step := fantasy.StepResult{
 35		Response: fantasy.Response{Usage: usage},
 36	}
 37
 38	fallbackUsage, estimated := fallbackStepUsage(nil, step)
 39	require.False(t, estimated)
 40	require.Equal(t, usage, fallbackUsage)
 41}
 42
 43func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
 44	t.Parallel()
 45
 46	messages := []fantasy.Message{
 47		fantasy.NewUserMessage("please explain the implementation details"),
 48	}
 49	step := fantasy.StepResult{
 50		Response: fantasy.Response{
 51			Content: fantasy.ResponseContent{
 52				fantasy.TextContent{Text: "the implementation stores state safely"},
 53			},
 54		},
 55	}
 56
 57	usage, estimated := fallbackStepUsage(messages, step)
 58	require.True(t, estimated)
 59	require.Positive(t, usage.InputTokens)
 60	require.Positive(t, usage.OutputTokens)
 61	require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
 62}
 63
 64func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
 65	t.Parallel()
 66
 67	messages := []fantasy.Message{
 68		{
 69			Role: fantasy.MessageRoleAssistant,
 70			Content: []fantasy.MessagePart{
 71				fantasy.ReasoningPart{Text: "first reason about the request"},
 72			},
 73		},
 74	}
 75	step := fantasy.StepResult{
 76		Response: fantasy.Response{
 77			Content: fantasy.ResponseContent{
 78				fantasy.ReasoningContent{Text: "second reason about the answer"},
 79			},
 80		},
 81	}
 82
 83	usage, estimated := fallbackStepUsage(messages, step)
 84	require.True(t, estimated)
 85	require.Positive(t, usage.InputTokens)
 86	require.Positive(t, usage.OutputTokens)
 87}
 88
 89func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
 90	t.Parallel()
 91
 92	step := fantasy.StepResult{
 93		Response: fantasy.Response{
 94			Content: fantasy.ResponseContent{
 95				fantasy.ToolCallContent{
 96					ToolCallID: "tool-call-1",
 97					ToolName:   "view",
 98					Input:      `{"file_path":"/tmp/example.go"}`,
 99				},
100			},
101		},
102	}
103
104	usage, estimated := fallbackStepUsage(nil, step)
105	require.True(t, estimated)
106	require.Zero(t, usage.InputTokens)
107	require.Positive(t, usage.OutputTokens)
108	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
109}
110
111func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
112	t.Parallel()
113
114	messages := []fantasy.Message{
115		{
116			Role: fantasy.MessageRoleTool,
117			Content: []fantasy.MessagePart{
118				fantasy.ToolResultPart{
119					ToolCallID: "tool-call-1",
120					Output: fantasy.ToolResultOutputContentText{
121						Text: "file contents returned by the tool",
122					},
123				},
124				fantasy.ToolResultPart{
125					ToolCallID: "tool-call-2",
126					Output: fantasy.ToolResultOutputContentError{
127						Error: errors.New("permission denied"),
128					},
129				},
130				fantasy.ToolResultPart{
131					ToolCallID: "tool-call-3",
132					Output: fantasy.ToolResultOutputContentMedia{
133						MediaType: "image/png",
134						Text:      "screenshot",
135						Data:      "abc123",
136					},
137				},
138			},
139		},
140	}
141
142	usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
143	require.True(t, estimated)
144	require.Positive(t, usage.InputTokens)
145	require.Zero(t, usage.OutputTokens)
146	require.Equal(t, usage.InputTokens, usage.TotalTokens)
147}
148
149func TestFallbackStepUsageSkipsClientToolResultsAsOutput(t *testing.T) {
150	t.Parallel()
151
152	step := fantasy.StepResult{
153		Response: fantasy.Response{
154			Content: fantasy.ResponseContent{
155				fantasy.ToolResultContent{
156					ToolCallID: "tool-call-1",
157					ToolName:   "bash",
158					Result: fantasy.ToolResultOutputContentText{
159						Text: "large client-executed payload that should not count as model output tokens",
160					},
161				},
162			},
163		},
164	}
165
166	usage, estimated := fallbackStepUsage(nil, step)
167	require.False(t, estimated)
168	require.Zero(t, usage.OutputTokens)
169}
170
171func TestFallbackStepUsageCountsProviderToolResultsAsOutput(t *testing.T) {
172	t.Parallel()
173
174	step := fantasy.StepResult{
175		Response: fantasy.Response{
176			Content: fantasy.ResponseContent{
177				fantasy.ToolResultContent{
178					ToolCallID:       "tool-call-1",
179					ToolName:         "web_search",
180					ProviderExecuted: true,
181					ClientMetadata:   "provider metadata",
182					Result:           fantasy.ToolResultOutputContentText{Text: "provider-executed result"},
183				},
184			},
185		},
186	}
187
188	usage, estimated := fallbackStepUsage(nil, step)
189	require.True(t, estimated)
190	require.Positive(t, usage.OutputTokens)
191	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
192}
193
194func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
195	t.Parallel()
196
197	usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
198	require.False(t, estimated)
199	require.True(t, usageIsZero(usage))
200}
201
202func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
203	t.Parallel()
204
205	agent := &sessionAgent{}
206	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
207	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
208	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
209
210	agent.updateSessionUsage(model, currentSession, usage, nil, true)
211
212	require.Equal(t, 1.25, currentSession.Cost)
213	require.Equal(t, int64(1000), currentSession.PromptTokens)
214	require.Equal(t, int64(2000), currentSession.CompletionTokens)
215	require.True(t, currentSession.EstimatedUsage)
216}
217
218func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
219	t.Parallel()
220
221	agent := &sessionAgent{}
222	currentSession := &session.Session{
223		ID:               "session-id",
224		PromptTokens:     123,
225		CompletionTokens: 456,
226		Cost:             1.25,
227	}
228	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
229
230	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
231
232	require.Equal(t, 1.25, currentSession.Cost)
233	require.Equal(t, int64(123), currentSession.PromptTokens)
234	require.Equal(t, int64(456), currentSession.CompletionTokens)
235}
236
237func TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage(t *testing.T) {
238	t.Parallel()
239
240	agent := &sessionAgent{}
241	currentSession := &session.Session{
242		ID:               "session-id",
243		PromptTokens:     123,
244		CompletionTokens: 456,
245	}
246	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
247	usage := fantasy.Usage{InputTokens: 789}
248
249	agent.updateSessionUsage(model, currentSession, usage, nil, false)
250
251	require.Equal(t, int64(789), currentSession.PromptTokens)
252	require.Equal(t, int64(456), currentSession.CompletionTokens)
253}
254
255func TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage(t *testing.T) {
256	t.Parallel()
257
258	agent := &sessionAgent{}
259	currentSession := &session.Session{
260		ID:               "session-id",
261		PromptTokens:     123,
262		CompletionTokens: 456,
263	}
264	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
265	usage := fantasy.Usage{TotalTokens: 100}
266
267	agent.updateSessionUsage(model, currentSession, usage, nil, false)
268
269	require.Equal(t, int64(123), currentSession.PromptTokens)
270	require.Equal(t, int64(456), currentSession.CompletionTokens)
271}
272
273func TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage(t *testing.T) {
274	t.Parallel()
275
276	agent := &sessionAgent{}
277	currentSession := &session.Session{
278		ID:               "session-id",
279		PromptTokens:     123,
280		CompletionTokens: 456,
281	}
282	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
283	usage := fantasy.Usage{OutputTokens: 50}
284
285	agent.updateSessionUsage(model, currentSession, usage, nil, false)
286
287	require.Equal(t, int64(123), currentSession.PromptTokens)
288	require.Equal(t, int64(50), currentSession.CompletionTokens)
289}
290
291func TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage(t *testing.T) {
292	t.Parallel()
293
294	agent := &sessionAgent{}
295	currentSession := &session.Session{
296		ID:               "session-id",
297		PromptTokens:     123,
298		CompletionTokens: 456,
299		Cost:             1.25,
300	}
301	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
302
303	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, true)
304
305	require.Equal(t, 1.25, currentSession.Cost)
306	require.Equal(t, int64(123), currentSession.PromptTokens)
307	require.Equal(t, int64(456), currentSession.CompletionTokens)
308}
309
310func TestSummaryCompletionTokens(t *testing.T) {
311	t.Parallel()
312
313	summaryMessage := message.Message{
314		Parts: []message.ContentPart{
315			message.TextContent{Text: "summary text"},
316			message.ReasoningContent{Thinking: "reasoning text"},
317		},
318	}
319
320	require.Equal(t, int64(42), summaryCompletionTokens(fantasy.Usage{OutputTokens: 42}, summaryMessage))
321	require.Equal(t, approxTokenCount("summary text")+approxTokenCount("reasoning text"), summaryCompletionTokens(fantasy.Usage{}, summaryMessage))
322	require.Zero(t, summaryCompletionTokens(fantasy.Usage{}, message.Message{}))
323}
324
325func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
326	t.Parallel()
327
328	agent := &sessionAgent{}
329	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
330	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
331	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
332
333	agent.updateSessionUsage(model, currentSession, usage, nil, false)
334
335	require.Equal(t, 1.3, currentSession.Cost)
336	require.Equal(t, int64(1000), currentSession.PromptTokens)
337	require.Equal(t, int64(2000), currentSession.CompletionTokens)
338	require.False(t, currentSession.EstimatedUsage)
339}