usage_fallback_test.go

  1package agent
  2
  3import (
  4	"errors"
  5	"testing"
  6
  7	"charm.land/catwalk/pkg/catwalk"
  8	"charm.land/fantasy"
  9	"github.com/charmbracelet/crush/internal/message"
 10	"github.com/charmbracelet/crush/internal/session"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14func TestUsageIsZero(t *testing.T) {
 15	t.Parallel()
 16
 17	require.True(t, usageIsZero(fantasy.Usage{}))
 18	require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
 19	require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
 20	require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
 21	require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
 22	require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
 23	require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
 24}
 25
 26func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
 27	t.Parallel()
 28
 29	usage := fantasy.Usage{
 30		InputTokens:  10,
 31		OutputTokens: 5,
 32		TotalTokens:  15,
 33	}
 34	step := fantasy.StepResult{
 35		Response: fantasy.Response{Usage: usage},
 36	}
 37
 38	fallbackUsage, estimated := fallbackStepUsage(nil, step)
 39	require.False(t, estimated)
 40	require.Equal(t, usage, fallbackUsage)
 41}
 42
 43func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
 44	t.Parallel()
 45
 46	messages := []fantasy.Message{
 47		fantasy.NewUserMessage("please explain the implementation details"),
 48	}
 49	step := fantasy.StepResult{
 50		Response: fantasy.Response{
 51			Content: fantasy.ResponseContent{
 52				fantasy.TextContent{Text: "the implementation stores state safely"},
 53			},
 54		},
 55	}
 56
 57	usage, estimated := fallbackStepUsage(messages, step)
 58	require.True(t, estimated)
 59	require.Positive(t, usage.InputTokens)
 60	require.Positive(t, usage.OutputTokens)
 61	require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
 62}
 63
 64func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
 65	t.Parallel()
 66
 67	messages := []fantasy.Message{
 68		{
 69			Role: fantasy.MessageRoleAssistant,
 70			Content: []fantasy.MessagePart{
 71				fantasy.ReasoningPart{Text: "first reason about the request"},
 72			},
 73		},
 74	}
 75	step := fantasy.StepResult{
 76		Response: fantasy.Response{
 77			Content: fantasy.ResponseContent{
 78				fantasy.ReasoningContent{Text: "second reason about the answer"},
 79			},
 80		},
 81	}
 82
 83	usage, estimated := fallbackStepUsage(messages, step)
 84	require.True(t, estimated)
 85	require.Positive(t, usage.InputTokens)
 86	require.Positive(t, usage.OutputTokens)
 87}
 88
 89func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
 90	t.Parallel()
 91
 92	step := fantasy.StepResult{
 93		Response: fantasy.Response{
 94			Content: fantasy.ResponseContent{
 95				fantasy.ToolCallContent{
 96					ToolCallID: "tool-call-1",
 97					ToolName:   "view",
 98					Input:      `{"file_path":"/tmp/example.go"}`,
 99				},
100			},
101		},
102	}
103
104	usage, estimated := fallbackStepUsage(nil, step)
105	require.True(t, estimated)
106	require.Zero(t, usage.InputTokens)
107	require.Positive(t, usage.OutputTokens)
108	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
109}
110
111func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
112	t.Parallel()
113
114	messages := []fantasy.Message{
115		{
116			Role: fantasy.MessageRoleTool,
117			Content: []fantasy.MessagePart{
118				fantasy.ToolResultPart{
119					ToolCallID: "tool-call-1",
120					Output: fantasy.ToolResultOutputContentText{
121						Text: "file contents returned by the tool",
122					},
123				},
124				fantasy.ToolResultPart{
125					ToolCallID: "tool-call-2",
126					Output: fantasy.ToolResultOutputContentError{
127						Error: errors.New("permission denied"),
128					},
129				},
130				fantasy.ToolResultPart{
131					ToolCallID: "tool-call-3",
132					Output: fantasy.ToolResultOutputContentMedia{
133						MediaType: "image/png",
134						Text:      "screenshot",
135						Data:      "abc123",
136					},
137				},
138			},
139		},
140	}
141
142	usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
143	require.True(t, estimated)
144	require.Positive(t, usage.InputTokens)
145	require.Zero(t, usage.OutputTokens)
146	require.Equal(t, usage.InputTokens, usage.TotalTokens)
147}
148
149func TestFallbackStepUsageSkipsClientToolResultsAsOutput(t *testing.T) {
150	t.Parallel()
151
152	step := fantasy.StepResult{
153		Response: fantasy.Response{
154			Content: fantasy.ResponseContent{
155				fantasy.ToolResultContent{
156					ToolCallID: "tool-call-1",
157					ToolName:   "bash",
158					Result: fantasy.ToolResultOutputContentText{
159						Text: "large client-executed payload that should not count as model output tokens",
160					},
161				},
162			},
163		},
164	}
165
166	usage, estimated := fallbackStepUsage(nil, step)
167	require.False(t, estimated)
168	require.Zero(t, usage.OutputTokens)
169}
170
171func TestFallbackStepUsageCountsProviderToolResultsAsOutput(t *testing.T) {
172	t.Parallel()
173
174	step := fantasy.StepResult{
175		Response: fantasy.Response{
176			Content: fantasy.ResponseContent{
177				fantasy.ToolResultContent{
178					ToolCallID:       "tool-call-1",
179					ToolName:         "web_search",
180					ProviderExecuted: true,
181					ClientMetadata:   "provider metadata",
182					Result:           fantasy.ToolResultOutputContentText{Text: "provider-executed result"},
183				},
184			},
185		},
186	}
187
188	usage, estimated := fallbackStepUsage(nil, step)
189	require.True(t, estimated)
190	require.Positive(t, usage.OutputTokens)
191	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
192}
193
194func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
195	t.Parallel()
196
197	usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
198	require.False(t, estimated)
199	require.True(t, usageIsZero(usage))
200}
201
202func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
203	t.Parallel()
204
205	agent := &sessionAgent{}
206	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
207	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
208	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
209
210	agent.updateSessionUsage(model, currentSession, usage, nil, true)
211
212	require.Equal(t, 1.25, currentSession.Cost)
213	require.Equal(t, int64(1000), currentSession.PromptTokens)
214	require.Equal(t, int64(2000), currentSession.CompletionTokens)
215}
216
217func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
218	t.Parallel()
219
220	agent := &sessionAgent{}
221	currentSession := &session.Session{
222		ID:               "session-id",
223		PromptTokens:     123,
224		CompletionTokens: 456,
225		Cost:             1.25,
226	}
227	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
228
229	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
230
231	require.Equal(t, 1.25, currentSession.Cost)
232	require.Equal(t, int64(123), currentSession.PromptTokens)
233	require.Equal(t, int64(456), currentSession.CompletionTokens)
234}
235
236func TestUpdateSessionUsagePreservesOmittedCountersForPartialUsage(t *testing.T) {
237	t.Parallel()
238
239	agent := &sessionAgent{}
240	currentSession := &session.Session{
241		ID:               "session-id",
242		PromptTokens:     123,
243		CompletionTokens: 456,
244	}
245	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
246	usage := fantasy.Usage{InputTokens: 789}
247
248	agent.updateSessionUsage(model, currentSession, usage, nil, false)
249
250	require.Equal(t, int64(789), currentSession.PromptTokens)
251	require.Equal(t, int64(456), currentSession.CompletionTokens)
252}
253
254func TestUpdateSessionUsagePreservesCountersForTotalOnlyUsage(t *testing.T) {
255	t.Parallel()
256
257	agent := &sessionAgent{}
258	currentSession := &session.Session{
259		ID:               "session-id",
260		PromptTokens:     123,
261		CompletionTokens: 456,
262	}
263	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
264	usage := fantasy.Usage{TotalTokens: 100}
265
266	agent.updateSessionUsage(model, currentSession, usage, nil, false)
267
268	require.Equal(t, int64(123), currentSession.PromptTokens)
269	require.Equal(t, int64(456), currentSession.CompletionTokens)
270}
271
272func TestUpdateSessionUsagePreservesPromptForOutputOnlyUsage(t *testing.T) {
273	t.Parallel()
274
275	agent := &sessionAgent{}
276	currentSession := &session.Session{
277		ID:               "session-id",
278		PromptTokens:     123,
279		CompletionTokens: 456,
280	}
281	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
282	usage := fantasy.Usage{OutputTokens: 50}
283
284	agent.updateSessionUsage(model, currentSession, usage, nil, false)
285
286	require.Equal(t, int64(123), currentSession.PromptTokens)
287	require.Equal(t, int64(50), currentSession.CompletionTokens)
288}
289
290func TestUpdateSessionUsageKeepsCountersForEstimatedZeroUsage(t *testing.T) {
291	t.Parallel()
292
293	agent := &sessionAgent{}
294	currentSession := &session.Session{
295		ID:               "session-id",
296		PromptTokens:     123,
297		CompletionTokens: 456,
298		Cost:             1.25,
299	}
300	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
301
302	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, true)
303
304	require.Equal(t, 1.25, currentSession.Cost)
305	require.Equal(t, int64(123), currentSession.PromptTokens)
306	require.Equal(t, int64(456), currentSession.CompletionTokens)
307}
308
309func TestSummaryCompletionTokens(t *testing.T) {
310	t.Parallel()
311
312	summaryMessage := message.Message{
313		Parts: []message.ContentPart{
314			message.TextContent{Text: "summary text"},
315			message.ReasoningContent{Thinking: "reasoning text"},
316		},
317	}
318
319	require.Equal(t, int64(42), summaryCompletionTokens(fantasy.Usage{OutputTokens: 42}, summaryMessage))
320	require.Equal(t, approxTokenCount("summary text")+approxTokenCount("reasoning text"), summaryCompletionTokens(fantasy.Usage{}, summaryMessage))
321	require.Zero(t, summaryCompletionTokens(fantasy.Usage{}, message.Message{}))
322}
323
324func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
325	t.Parallel()
326
327	agent := &sessionAgent{}
328	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
329	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
330	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
331
332	agent.updateSessionUsage(model, currentSession, usage, nil, false)
333
334	require.Equal(t, 1.3, currentSession.Cost)
335	require.Equal(t, int64(1000), currentSession.PromptTokens)
336	require.Equal(t, int64(2000), currentSession.CompletionTokens)
337}