1package agent
2
3import (
4 "errors"
5 "testing"
6
7 "charm.land/catwalk/pkg/catwalk"
8 "charm.land/fantasy"
9 "github.com/charmbracelet/crush/internal/session"
10 "github.com/stretchr/testify/require"
11)
12
13func TestUsageIsZero(t *testing.T) {
14 t.Parallel()
15
16 require.True(t, usageIsZero(fantasy.Usage{}))
17 require.False(t, usageIsZero(fantasy.Usage{InputTokens: 1}))
18 require.False(t, usageIsZero(fantasy.Usage{OutputTokens: 1}))
19 require.False(t, usageIsZero(fantasy.Usage{TotalTokens: 1}))
20 require.False(t, usageIsZero(fantasy.Usage{ReasoningTokens: 1}))
21 require.False(t, usageIsZero(fantasy.Usage{CacheCreationTokens: 1}))
22 require.False(t, usageIsZero(fantasy.Usage{CacheReadTokens: 1}))
23}
24
25func TestFallbackStepUsageKeepsProviderUsage(t *testing.T) {
26 t.Parallel()
27
28 usage := fantasy.Usage{
29 InputTokens: 10,
30 OutputTokens: 5,
31 TotalTokens: 15,
32 }
33 step := fantasy.StepResult{
34 Response: fantasy.Response{Usage: usage},
35 }
36
37 fallbackUsage, estimated := fallbackStepUsage(nil, step)
38 require.False(t, estimated)
39 require.Equal(t, usage, fallbackUsage)
40}
41
42func TestFallbackStepUsageEstimatesPromptAndAssistantText(t *testing.T) {
43 t.Parallel()
44
45 messages := []fantasy.Message{
46 fantasy.NewUserMessage("please explain the implementation details"),
47 }
48 step := fantasy.StepResult{
49 Response: fantasy.Response{
50 Content: fantasy.ResponseContent{
51 fantasy.TextContent{Text: "the implementation stores state safely"},
52 },
53 },
54 }
55
56 usage, estimated := fallbackStepUsage(messages, step)
57 require.True(t, estimated)
58 require.Positive(t, usage.InputTokens)
59 require.Positive(t, usage.OutputTokens)
60 require.Equal(t, usage.InputTokens+usage.OutputTokens, usage.TotalTokens)
61}
62
63func TestFallbackStepUsageEstimatesReasoning(t *testing.T) {
64 t.Parallel()
65
66 messages := []fantasy.Message{
67 {
68 Role: fantasy.MessageRoleAssistant,
69 Content: []fantasy.MessagePart{
70 fantasy.ReasoningPart{Text: "first reason about the request"},
71 },
72 },
73 }
74 step := fantasy.StepResult{
75 Response: fantasy.Response{
76 Content: fantasy.ResponseContent{
77 fantasy.ReasoningContent{Text: "second reason about the answer"},
78 },
79 },
80 }
81
82 usage, estimated := fallbackStepUsage(messages, step)
83 require.True(t, estimated)
84 require.Positive(t, usage.InputTokens)
85 require.Positive(t, usage.OutputTokens)
86}
87
88func TestFallbackStepUsageEstimatesToolCalls(t *testing.T) {
89 t.Parallel()
90
91 step := fantasy.StepResult{
92 Response: fantasy.Response{
93 Content: fantasy.ResponseContent{
94 fantasy.ToolCallContent{
95 ToolCallID: "tool-call-1",
96 ToolName: "view",
97 Input: `{"file_path":"/tmp/example.go"}`,
98 },
99 },
100 },
101 }
102
103 usage, estimated := fallbackStepUsage(nil, step)
104 require.True(t, estimated)
105 require.Zero(t, usage.InputTokens)
106 require.Positive(t, usage.OutputTokens)
107 require.Equal(t, usage.OutputTokens, usage.TotalTokens)
108}
109
110func TestFallbackStepUsageEstimatesToolResults(t *testing.T) {
111 t.Parallel()
112
113 messages := []fantasy.Message{
114 {
115 Role: fantasy.MessageRoleTool,
116 Content: []fantasy.MessagePart{
117 fantasy.ToolResultPart{
118 ToolCallID: "tool-call-1",
119 Output: fantasy.ToolResultOutputContentText{
120 Text: "file contents returned by the tool",
121 },
122 },
123 fantasy.ToolResultPart{
124 ToolCallID: "tool-call-2",
125 Output: fantasy.ToolResultOutputContentError{
126 Error: errors.New("permission denied"),
127 },
128 },
129 fantasy.ToolResultPart{
130 ToolCallID: "tool-call-3",
131 Output: fantasy.ToolResultOutputContentMedia{
132 MediaType: "image/png",
133 Text: "screenshot",
134 Data: "abc123",
135 },
136 },
137 },
138 },
139 }
140
141 usage, estimated := fallbackStepUsage(messages, fantasy.StepResult{})
142 require.True(t, estimated)
143 require.Positive(t, usage.InputTokens)
144 require.Zero(t, usage.OutputTokens)
145 require.Equal(t, usage.InputTokens, usage.TotalTokens)
146}
147
148func TestFallbackStepUsageReturnsZeroWithoutContent(t *testing.T) {
149 t.Parallel()
150
151 usage, estimated := fallbackStepUsage(nil, fantasy.StepResult{})
152 require.False(t, estimated)
153 require.True(t, usageIsZero(usage))
154}
155
156func TestUpdateSessionUsageSkipsEstimatedCost(t *testing.T) {
157 t.Parallel()
158
159 agent := &sessionAgent{}
160 currentSession := &session.Session{ID: "session-id", Cost: 1.25}
161 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
162 usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
163
164 agent.updateSessionUsage(model, currentSession, usage, nil, true)
165
166 require.Equal(t, 1.25, currentSession.Cost)
167 require.Equal(t, int64(1000), currentSession.PromptTokens)
168 require.Equal(t, int64(2000), currentSession.CompletionTokens)
169}
170
171func TestUpdateSessionUsageKeepsCountersForZeroUsage(t *testing.T) {
172 t.Parallel()
173
174 agent := &sessionAgent{}
175 currentSession := &session.Session{
176 ID: "session-id",
177 PromptTokens: 123,
178 CompletionTokens: 456,
179 Cost: 1.25,
180 }
181 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
182
183 agent.updateSessionUsage(model, currentSession, fantasy.Usage{}, nil, false)
184
185 require.Equal(t, 1.25, currentSession.Cost)
186 require.Equal(t, int64(123), currentSession.PromptTokens)
187 require.Equal(t, int64(456), currentSession.CompletionTokens)
188}
189
190func TestUpdateSessionUsageAddsProviderCost(t *testing.T) {
191 t.Parallel()
192
193 agent := &sessionAgent{}
194 currentSession := &session.Session{ID: "session-id", Cost: 1.25}
195 model := Model{CatwalkCfg: catwalk.Model{CostPer1MIn: 10, CostPer1MOut: 20}}
196 usage := fantasy.Usage{InputTokens: 1000, OutputTokens: 2000}
197
198 agent.updateSessionUsage(model, currentSession, usage, nil, false)
199
200 require.Equal(t, 1.3, currentSession.Cost)
201 require.Equal(t, int64(1000), currentSession.PromptTokens)
202 require.Equal(t, int64(2000), currentSession.CompletionTokens)
203}