1package providertests
2
3import (
4 "context"
5 "os"
6 "strconv"
7 "strings"
8 "testing"
9
10 "charm.land/fantasy"
11 "charm.land/x/vcr"
12 "github.com/joho/godotenv"
13 "github.com/stretchr/testify/require"
14)
15
16func init() {
17 if _, err := os.Stat(".env"); err == nil {
18 godotenv.Load(".env")
19 } else {
20 godotenv.Load(".env.sample")
21 }
22}
23
// testModel describes a model to exercise in the provider tests. It is not
// referenced in this part of the file, so the field meanings below are
// inferred from the names — TODO confirm against the call sites:
//   - name: presumably a human-readable label (e.g. used in subtest names).
//   - model: presumably the provider-specific model identifier.
//   - reasoning: presumably marks models that emit reasoning/thinking output.
type testModel struct {
	name      string
	model     string
	reasoning bool
}
29
// builderFunc constructs the language model under test. Callers in this file
// create r with vcr.NewRecorder(t) for each subtest and pass it in; builders
// presumably route the provider's HTTP traffic through r so interactions can
// be recorded and replayed — TODO confirm in the builder implementations.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
31
// builderPair bundles a named model builder with the per-call settings that
// the shared scenarios pass to every agent call.
//
// name is used as the subtest name and is also matched by prefix/substring
// checks (for example "avian-", "azure", "claude") to skip scenarios a
// provider or model does not support. providerOptions and prepareStep are
// forwarded unchanged to each fantasy.AgentCall / fantasy.AgentStreamCall.
type builderPair struct {
	name            string
	builder         builderFunc
	providerOptions fantasy.ProviderOptions
	prepareStep     fantasy.PrepareStepFunction
}
38
39func testCommon(t *testing.T, pairs []builderPair) {
40 for _, pair := range pairs {
41 t.Run(pair.name, func(t *testing.T) {
42 testSimple(t, pair)
43 testTool(t, pair)
44 testMultiTool(t, pair)
45 })
46 }
47}
48
49func testSimple(t *testing.T, pair builderPair) {
50 checkResult := func(t *testing.T, result *fantasy.AgentResult) {
51 options := []string{"Oi", "oi", "Olá", "olá"}
52 got := result.Response.Content.Text()
53 require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
54 }
55
56 t.Run("simple", func(t *testing.T) {
57 if strings.HasPrefix(pair.name, "avian-") {
58 t.Skip("Avian only support streaming")
59 }
60
61 r := vcr.NewRecorder(t)
62
63 languageModel, err := pair.builder(t, r)
64 require.NoError(t, err, "failed to build language model")
65
66 agent := fantasy.NewAgent(
67 languageModel,
68 fantasy.WithSystemPrompt("You are a helpful assistant"),
69 )
70 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
71 Prompt: "Say hi in Portuguese",
72 ProviderOptions: pair.providerOptions,
73 MaxOutputTokens: new(int64(4000)),
74 PrepareStep: pair.prepareStep,
75 })
76 require.NoError(t, err, "failed to generate")
77 checkResult(t, result)
78 })
79
80 t.Run("simple streaming", func(t *testing.T) {
81 r := vcr.NewRecorder(t)
82
83 languageModel, err := pair.builder(t, r)
84 require.NoError(t, err, "failed to build language model")
85
86 agent := fantasy.NewAgent(
87 languageModel,
88 fantasy.WithSystemPrompt("You are a helpful assistant"),
89 )
90 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
91 Prompt: "Say hi in Portuguese",
92 ProviderOptions: pair.providerOptions,
93 MaxOutputTokens: new(int64(4000)),
94 PrepareStep: pair.prepareStep,
95 })
96 require.NoError(t, err, "failed to generate")
97 checkResult(t, result)
98 })
99}
100
101func testTool(t *testing.T, pair builderPair) {
102 type WeatherInput struct {
103 Location string `json:"location" description:"the city"`
104 }
105
106 weatherTool := fantasy.NewAgentTool(
107 "weather",
108 "Get weather information for a location",
109 func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
110 return fantasy.NewTextResponse("40 C"), nil
111 },
112 )
113 checkResult := func(t *testing.T, result *fantasy.AgentResult) {
114 require.GreaterOrEqual(t, len(result.Steps), 2)
115
116 var toolCalls []fantasy.ToolCallContent
117 for _, content := range result.Steps[0].Content {
118 if content.GetType() == fantasy.ContentTypeToolCall {
119 toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
120 }
121 }
122 for _, tc := range toolCalls {
123 require.False(t, tc.Invalid)
124 }
125 require.Len(t, toolCalls, 1)
126 require.Equal(t, toolCalls[0].ToolName, "weather")
127
128 want1 := "Florence"
129 want2 := "40"
130 got := result.Response.Content.Text()
131 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
132 }
133
134 t.Run("tool", func(t *testing.T) {
135 if strings.HasPrefix(pair.name, "avian-") {
136 t.Skip("Avian only support streaming")
137 }
138
139 r := vcr.NewRecorder(t)
140
141 languageModel, err := pair.builder(t, r)
142 require.NoError(t, err, "failed to build language model")
143
144 agent := fantasy.NewAgent(
145 languageModel,
146 fantasy.WithSystemPrompt("You are a helpful assistant"),
147 fantasy.WithTools(weatherTool),
148 )
149 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
150 Prompt: "What's the weather in Florence,Italy?",
151 ProviderOptions: pair.providerOptions,
152 MaxOutputTokens: new(int64(4000)),
153 PrepareStep: pair.prepareStep,
154 })
155 require.NoError(t, err, "failed to generate")
156 checkResult(t, result)
157 })
158
159 t.Run("tool streaming", func(t *testing.T) {
160 r := vcr.NewRecorder(t)
161
162 languageModel, err := pair.builder(t, r)
163 require.NoError(t, err, "failed to build language model")
164
165 agent := fantasy.NewAgent(
166 languageModel,
167 fantasy.WithSystemPrompt("You are a helpful assistant"),
168 fantasy.WithTools(weatherTool),
169 )
170 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
171 Prompt: "What's the weather in Florence,Italy?",
172 ProviderOptions: pair.providerOptions,
173 MaxOutputTokens: new(int64(4000)),
174 PrepareStep: pair.prepareStep,
175 })
176 require.NoError(t, err, "failed to generate")
177 checkResult(t, result)
178 })
179}
180
181func testMultiTool(t *testing.T, pair builderPair) {
182 // Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
183 if strings.Contains(pair.name, "azure") {
184 t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
185 }
186 if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
187 t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
188 }
189 if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
190 t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
191 }
192 if strings.Contains(pair.name, "openai") && strings.Contains(pair.name, "o4-mini") {
193 t.Skip("skipping multi-tool tests for openai o4-mini it for some reason is not doing parallel tool calls even if asked")
194 }
195 if strings.Contains(pair.name, "llama-cpp") && strings.Contains(pair.name, "gpt-oss") {
196 t.Skip("skipping multi-tool tests for llama-cpp gpt-oss as it does not support parallel multi-tool calls")
197 }
198
199 type CalculatorInput struct {
200 A int `json:"a" description:"first number"`
201 B int `json:"b" description:"second number"`
202 }
203
204 addTool := fantasy.NewAgentTool(
205 "add",
206 "Add two numbers",
207 func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
208 result := input.A + input.B
209 return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
210 },
211 )
212 multiplyTool := fantasy.NewAgentTool(
213 "multiply",
214 "Multiply two numbers",
215 func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
216 result := input.A * input.B
217 return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
218 },
219 )
220 checkResult := func(t *testing.T, result *fantasy.AgentResult) {
221 require.Len(t, result.Steps, 2)
222
223 var toolCalls []fantasy.ToolCallContent
224 for _, content := range result.Steps[0].Content {
225 if content.GetType() == fantasy.ContentTypeToolCall {
226 toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
227 }
228 }
229 for _, tc := range toolCalls {
230 require.False(t, tc.Invalid)
231 }
232 require.Len(t, toolCalls, 2)
233
234 finalText := result.Response.Content.Text()
235 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
236 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
237 }
238
239 t.Run("multi tool", func(t *testing.T) {
240 if strings.HasPrefix(pair.name, "avian-") {
241 t.Skip("Avian only support streaming")
242 }
243
244 r := vcr.NewRecorder(t)
245
246 languageModel, err := pair.builder(t, r)
247 require.NoError(t, err, "failed to build language model")
248
249 agent := fantasy.NewAgent(
250 languageModel,
251 fantasy.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
252 fantasy.WithTools(addTool),
253 fantasy.WithTools(multiplyTool),
254 )
255 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
256 Prompt: "Add and multiply the number 2 and 3",
257 ProviderOptions: pair.providerOptions,
258 MaxOutputTokens: new(int64(4000)),
259 PrepareStep: pair.prepareStep,
260 })
261 require.NoError(t, err, "failed to generate")
262 checkResult(t, result)
263 })
264
265 t.Run("multi tool streaming", func(t *testing.T) {
266 r := vcr.NewRecorder(t)
267
268 languageModel, err := pair.builder(t, r)
269 require.NoError(t, err, "failed to build language model")
270
271 agent := fantasy.NewAgent(
272 languageModel,
273 fantasy.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
274 fantasy.WithTools(addTool),
275 fantasy.WithTools(multiplyTool),
276 )
277 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
278 Prompt: "Add and multiply the number 2 and 3",
279 ProviderOptions: pair.providerOptions,
280 MaxOutputTokens: new(int64(4000)),
281 PrepareStep: pair.prepareStep,
282 })
283 require.NoError(t, err, "failed to generate")
284 checkResult(t, result)
285 })
286}
287
288func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *fantasy.AgentResult)) {
289 for _, pair := range pairs {
290 t.Run(pair.name, func(t *testing.T) {
291 t.Run("thinking", func(t *testing.T) {
292 if strings.HasPrefix(pair.name, "avian-") {
293 t.Skip("Avian only support streaming")
294 }
295
296 r := vcr.NewRecorder(t)
297
298 languageModel, err := pair.builder(t, r)
299 require.NoError(t, err, "failed to build language model")
300
301 type WeatherInput struct {
302 Location string `json:"location" description:"the city"`
303 }
304
305 weatherTool := fantasy.NewAgentTool(
306 "weather",
307 "Get weather information for a location",
308 func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
309 return fantasy.NewTextResponse("40 C"), nil
310 },
311 )
312
313 agent := fantasy.NewAgent(
314 languageModel,
315 fantasy.WithSystemPrompt("You are a helpful assistant"),
316 fantasy.WithTools(weatherTool),
317 )
318 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
319 Prompt: "What's the weather in Florence, Italy?",
320 ProviderOptions: pair.providerOptions,
321 PrepareStep: pair.prepareStep,
322 })
323 require.NoError(t, err, "failed to generate")
324
325 want1 := "Florence"
326 want2 := "40"
327 got := result.Response.Content.Text()
328 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
329
330 thinkChecks(t, result)
331 })
332
333 t.Run("thinking-streaming", func(t *testing.T) {
334 r := vcr.NewRecorder(t)
335
336 languageModel, err := pair.builder(t, r)
337 require.NoError(t, err, "failed to build language model")
338
339 type WeatherInput struct {
340 Location string `json:"location" description:"the city"`
341 }
342
343 weatherTool := fantasy.NewAgentTool(
344 "weather",
345 "Get weather information for a location",
346 func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
347 return fantasy.NewTextResponse("40 C"), nil
348 },
349 )
350
351 agent := fantasy.NewAgent(
352 languageModel,
353 fantasy.WithSystemPrompt("You are a helpful assistant"),
354 fantasy.WithTools(weatherTool),
355 )
356 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
357 Prompt: "What's the weather in Florence, Italy?",
358 ProviderOptions: pair.providerOptions,
359 PrepareStep: pair.prepareStep,
360 })
361 require.NoError(t, err, "failed to generate")
362
363 want1 := "Florence"
364 want2 := "40"
365 got := result.Response.Content.Text()
366 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
367
368 thinkChecks(t, result)
369 })
370 })
371 }
372}
373
// containsAny reports whether s contains at least one of subs as a
// substring. It returns false when subs is empty. An empty string in subs
// always matches, because strings.Contains treats "" as a substring of
// every string.
func containsAny(s string, subs ...string) bool {
	matched := false
	for i := 0; i < len(subs) && !matched; i++ {
		matched = strings.Contains(s, subs[i])
	}
	return matched
}