1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/stretchr/testify/require"
11 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
12)
13
// builderFunc constructs the ai.LanguageModel under test. It receives the
// recorder created for the current subtest and returns the model, or an error
// if the model could not be built.
// NOTE(review): presumably the recorder is wired into the provider's HTTP
// client so requests are recorded/replayed — confirm against the builders.
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
15
// builderPair bundles everything the shared test suites need to exercise one
// model configuration.
type builderPair struct {
	// name is appended to each subtest name (e.g. "simple "+name).
	name string
	// builder creates the language model for a subtest from that subtest's recorder.
	builder builderFunc
	// providerOptions is forwarded unchanged as ProviderOptions on every agent call.
	providerOptions ai.ProviderOptions
}
21
22func testCommon(t *testing.T, pairs []builderPair) {
23 for _, pair := range pairs {
24 testSimple(t, pair)
25 testTool(t, pair)
26 testMultiTool(t, pair)
27 }
28}
29
30func testSimple(t *testing.T, pair builderPair) {
31 checkResult := func(t *testing.T, result *ai.AgentResult) {
32 option1 := "Oi"
33 option2 := "Olá"
34 got := result.Response.Content.Text()
35 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
36 }
37
38 t.Run("simple "+pair.name, func(t *testing.T) {
39 r := newRecorder(t)
40
41 languageModel, err := pair.builder(r)
42 require.NoError(t, err, "failed to build language model")
43
44 agent := ai.NewAgent(
45 languageModel,
46 ai.WithSystemPrompt("You are a helpful assistant"),
47 )
48 result, err := agent.Generate(t.Context(), ai.AgentCall{
49 Prompt: "Say hi in Portuguese",
50 ProviderOptions: pair.providerOptions,
51 MaxOutputTokens: ai.IntOption(4000),
52 })
53 require.NoError(t, err, "failed to generate")
54 checkResult(t, result)
55 })
56 t.Run("simple streaming "+pair.name, func(t *testing.T) {
57 r := newRecorder(t)
58
59 languageModel, err := pair.builder(r)
60 require.NoError(t, err, "failed to build language model")
61
62 agent := ai.NewAgent(
63 languageModel,
64 ai.WithSystemPrompt("You are a helpful assistant"),
65 )
66 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
67 Prompt: "Say hi in Portuguese",
68 ProviderOptions: pair.providerOptions,
69 MaxOutputTokens: ai.IntOption(4000),
70 })
71 require.NoError(t, err, "failed to generate")
72 checkResult(t, result)
73 })
74}
75
76func testTool(t *testing.T, pair builderPair) {
77 type WeatherInput struct {
78 Location string `json:"location" description:"the city"`
79 }
80
81 weatherTool := ai.NewAgentTool(
82 "weather",
83 "Get weather information for a location",
84 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
85 return ai.NewTextResponse("40 C"), nil
86 },
87 )
88 checkResult := func(t *testing.T, result *ai.AgentResult) {
89 require.Len(t, result.Steps, 2)
90
91 var toolCalls []ai.ToolCallContent
92 for _, content := range result.Steps[0].Content {
93 if content.GetType() == ai.ContentTypeToolCall {
94 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
95 }
96 }
97 for _, tc := range toolCalls {
98 require.False(t, tc.Invalid)
99 }
100 require.Len(t, toolCalls, 1)
101 require.Equal(t, toolCalls[0].ToolName, "weather")
102
103 want1 := "Florence"
104 want2 := "40"
105 got := result.Response.Content.Text()
106 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
107 }
108
109 t.Run("tool "+pair.name, func(t *testing.T) {
110 r := newRecorder(t)
111
112 languageModel, err := pair.builder(r)
113 require.NoError(t, err, "failed to build language model")
114
115 agent := ai.NewAgent(
116 languageModel,
117 ai.WithSystemPrompt("You are a helpful assistant"),
118 ai.WithTools(weatherTool),
119 )
120 result, err := agent.Generate(t.Context(), ai.AgentCall{
121 Prompt: "What's the weather in Florence,Italy?",
122 ProviderOptions: pair.providerOptions,
123 MaxOutputTokens: ai.IntOption(4000),
124 })
125 require.NoError(t, err, "failed to generate")
126 checkResult(t, result)
127 })
128 t.Run("tool streaming "+pair.name, func(t *testing.T) {
129 r := newRecorder(t)
130
131 languageModel, err := pair.builder(r)
132 require.NoError(t, err, "failed to build language model")
133
134 agent := ai.NewAgent(
135 languageModel,
136 ai.WithSystemPrompt("You are a helpful assistant"),
137 ai.WithTools(weatherTool),
138 )
139 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
140 Prompt: "What's the weather in Florence,Italy?",
141 ProviderOptions: pair.providerOptions,
142 MaxOutputTokens: ai.IntOption(4000),
143 })
144 require.NoError(t, err, "failed to generate")
145 checkResult(t, result)
146 })
147}
148
149func testMultiTool(t *testing.T, pair builderPair) {
150 type CalculatorInput struct {
151 A int `json:"a" description:"first number"`
152 B int `json:"b" description:"second number"`
153 }
154
155 addTool := ai.NewAgentTool(
156 "add",
157 "Add two numbers",
158 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
159 result := input.A + input.B
160 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
161 },
162 )
163 multiplyTool := ai.NewAgentTool(
164 "multiply",
165 "Multiply two numbers",
166 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
167 result := input.A * input.B
168 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
169 },
170 )
171 checkResult := func(t *testing.T, result *ai.AgentResult) {
172 require.Len(t, result.Steps, 2)
173
174 var toolCalls []ai.ToolCallContent
175 for _, content := range result.Steps[0].Content {
176 if content.GetType() == ai.ContentTypeToolCall {
177 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
178 }
179 }
180 for _, tc := range toolCalls {
181 require.False(t, tc.Invalid)
182 }
183 require.Len(t, toolCalls, 2)
184
185 finalText := result.Response.Content.Text()
186 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
187 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
188 }
189
190 t.Run("multi tool "+pair.name, func(t *testing.T) {
191 r := newRecorder(t)
192
193 languageModel, err := pair.builder(r)
194 require.NoError(t, err, "failed to build language model")
195
196 agent := ai.NewAgent(
197 languageModel,
198 ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
199 ai.WithTools(addTool),
200 ai.WithTools(multiplyTool),
201 )
202 result, err := agent.Generate(t.Context(), ai.AgentCall{
203 Prompt: "Add and multiply the number 2 and 3",
204 ProviderOptions: pair.providerOptions,
205 MaxOutputTokens: ai.IntOption(4000),
206 })
207 require.NoError(t, err, "failed to generate")
208 checkResult(t, result)
209 })
210 t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
211 r := newRecorder(t)
212
213 languageModel, err := pair.builder(r)
214 require.NoError(t, err, "failed to build language model")
215
216 agent := ai.NewAgent(
217 languageModel,
218 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
219 ai.WithTools(addTool),
220 ai.WithTools(multiplyTool),
221 )
222 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
223 Prompt: "Add and multiply the number 2 and 3",
224 ProviderOptions: pair.providerOptions,
225 MaxOutputTokens: ai.IntOption(4000),
226 })
227 require.NoError(t, err, "failed to generate")
228 checkResult(t, result)
229 })
230}
231
232func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
233 for _, pair := range pairs {
234 t.Run("thinking-"+pair.name, func(t *testing.T) {
235 r := newRecorder(t)
236
237 languageModel, err := pair.builder(r)
238 require.NoError(t, err, "failed to build language model")
239
240 type WeatherInput struct {
241 Location string `json:"location" description:"the city"`
242 }
243
244 weatherTool := ai.NewAgentTool(
245 "weather",
246 "Get weather information for a location",
247 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
248 return ai.NewTextResponse("40 C"), nil
249 },
250 )
251
252 agent := ai.NewAgent(
253 languageModel,
254 ai.WithSystemPrompt("You are a helpful assistant"),
255 ai.WithTools(weatherTool),
256 )
257 result, err := agent.Generate(t.Context(), ai.AgentCall{
258 Prompt: "What's the weather in Florence, Italy?",
259 ProviderOptions: pair.providerOptions,
260 })
261 require.NoError(t, err, "failed to generate")
262
263 want1 := "Florence"
264 want2 := "40"
265 got := result.Response.Content.Text()
266 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
267
268 thinkChecks(t, result)
269 })
270 t.Run("thinking-streaming-"+pair.name, func(t *testing.T) {
271 r := newRecorder(t)
272
273 languageModel, err := pair.builder(r)
274 require.NoError(t, err, "failed to build language model")
275
276 type WeatherInput struct {
277 Location string `json:"location" description:"the city"`
278 }
279
280 weatherTool := ai.NewAgentTool(
281 "weather",
282 "Get weather information for a location",
283 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
284 return ai.NewTextResponse("40 C"), nil
285 },
286 )
287
288 agent := ai.NewAgent(
289 languageModel,
290 ai.WithSystemPrompt("You are a helpful assistant"),
291 ai.WithTools(weatherTool),
292 )
293 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
294 Prompt: "What's the weather in Florence, Italy?",
295 ProviderOptions: pair.providerOptions,
296 })
297 require.NoError(t, err, "failed to generate")
298
299 want1 := "Florence"
300 want2 := "40"
301 got := result.Response.Content.Text()
302 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
303
304 thinkChecks(t, result)
305 })
306 }
307}