usage_fallback_test.go

  1package agent
  2
  3import (
  4	"errors"
  5	"testing"
  6
  7	"charm.land/catwalk/pkg/catwalk"
  8	"charm.land/fantasy"
  9	"github.com/charmbracelet/crush/internal/session"
 10	"github.com/stretchr/testify/require"
 11)
 12
 13func TestUsageIsZero(t *testing.T) {
 14	t.Parallel()
 15
 16	require.True(t, usageIsZero(fantasy.Usage{}))
 17	require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
 18	require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
 19	require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
 20	require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
 21	require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
 22	require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
 23}
 24
 25func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
 26	t.Parallel()
 27
 28	usage := fantasy.Usage{
 29		InputTokens:  10,
 30		OutputTokens: 5,
 31		TotalTokens:  15,
 32	}
 33	step := fantasy.StepResult{
 34		Response: fantasy.Response{Usage: usage},
 35	}
 36
 37	fallbackUsage, estimated := fallbackStepUsage(nil, step)
 38	require.False(t, estimated)
 39	require.Equal(t, usage, fallbackUsage)
 40}
 41
 42func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
 43	t.Parallel()
 44
 45	messages := []fantasy.Message{
 46		fantasy.NewUserMessage("please explain the implementation details"),
 47	}
 48	step := fantasy.StepResult{
 49		Response: fantasy.Response{
 50			Content: fantasy.ResponseContent{
 51				fantasy.TextContent{Text: "the implementation stores state safely"},
 52			},
 53		},
 54	}
 55
 56	usage, estimated := fallbackStepUsage(messages, step)
 57	require.True(t, estimated)
 58	require.Positive(t, usage.InputTokens)
 59	require.Positive(t, usage.OutputTokens)
 60	require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
 61}
 62
 63func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
 64	t.Parallel()
 65
 66	messages := []fantasy.Message{
 67		{
 68			Role: fantasy.MessageRoleAssistant,
 69			Content: []fantasy.MessagePart{
 70				fantasy.ReasoningPart{Text: "first reason about the request"},
 71			},
 72		},
 73	}
 74	step := fantasy.StepResult{
 75		Response: fantasy.Response{
 76			Content: fantasy.ResponseContent{
 77				fantasy.ReasoningContent{Text: "second reason about the answer"},
 78			},
 79		},
 80	}
 81
 82	usage, estimated := fallbackStepUsage(messages, step)
 83	require.True(t, estimated)
 84	require.Positive(t, usage.InputTokens)
 85	require.Positive(t, usage.OutputTokens)
 86}
 87
 88func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
 89	t.Parallel()
 90
 91	step := fantasy.StepResult{
 92		Response: fantasy.Response{
 93			Content: fantasy.ResponseContent{
 94				fantasy.ToolCallContent{
 95					ToolCallID: "tool-call-1",
 96					ToolName:   "view",
 97					Input:      `{"file_path":"/tmp/example.go"}`,
 98				},
 99			},
100		},
101	}
102
103	usage, estimated := fallbackStepUsage(nil, step)
104	require.True(t, estimated)
105	require.Zero(t, usage.InputTokens)
106	require.Positive(t, usage.OutputTokens)
107	require.Equal(t, usage.OutputTokens, usage.TotalTokens)
108}
109
110func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
111	t.Parallel()
112
113	messages := []fantasy.Message{
114		{
115			Role: fantasy.MessageRoleTool,
116			Content: []fantasy.MessagePart{
117				fantasy.ToolResultPart{
118					ToolCallID: "tool-call-1",
119					Output: fantasy.ToolResultOutputContentText{
120						Text: "file contents returned by the tool",
121					},
122				},
123				fantasy.ToolResultPart{
124					ToolCallID: "tool-call-2",
125					Output: fantasy.ToolResultOutputContentError{
126						Error: errors.New("permission denied"),
127					},
128				},
129				fantasy.ToolResultPart{
130					ToolCallID: "tool-call-3",
131					Output: fantasy.ToolResultOutputContentMedia{
132						MediaType: "image/png",
133						Text:      "screenshot",
134						Data:      "abc123",
135					},
136				},
137			},
138		},
139	}
140
141	usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
142	require.True(t, estimated)
143	require.Positive(t, usage.InputTokens)
144	require.Zero(t, usage.OutputTokens)
145	require.Equal(t, usage.InputTokens, usage.TotalTokens)
146}
147
148func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
149	t.Parallel()
150
151	usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
152	require.False(t, estimated)
153	require.True(t, usageIsZero(usage))
154}
155
156func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
157	t.Parallel()
158
159	agent := &sessionAgent{}
160	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
161	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
162	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
163
164	agent.updateSessionUsage(model, currentSession, usage, nil, true)
165
166	require.Equal(t, 1.25, currentSession.Cost)
167	require.Equal(t, int64(1000), currentSession.PromptTokens)
168	require.Equal(t, int64(2000), currentSession.CompletionTokens)
169}
170
171func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
172	t.Parallel()
173
174	agent := &sessionAgent{}
175	currentSession := &session.Session{
176		ID:               "session-id",
177		PromptTokens:     123,
178		CompletionTokens: 456,
179		Cost:             1.25,
180	}
181	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
182
183	agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
184
185	require.Equal(t, 1.25, currentSession.Cost)
186	require.Equal(t, int64(123), currentSession.PromptTokens)
187	require.Equal(t, int64(456), currentSession.CompletionTokens)
188}
189
190func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
191	t.Parallel()
192
193	agent := &sessionAgent{}
194	currentSession := &session.Session{ID: "session-id", Cost: 1.25}
195	model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
196	usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
197
198	agent.updateSessionUsage(model, currentSession, usage, nil, false)
199
200	require.Equal(t, 1.3, currentSession.Cost)
201	require.Equal(t, int64(1000), currentSession.PromptTokens)
202	require.Equal(t, int64(2000), currentSession.CompletionTokens)
203}