1package providertests
2
3import (
4 "context"
5 "os"
6 "strconv"
7 "strings"
8 "testing"
9
10 "charm.land/fantasy"
11 "github.com/joho/godotenv"
12 "github.com/stretchr/testify/require"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16func init() {
17 if _, err := os.Stat(".env"); err == nil {
18 godotenv.Load(".env")
19 } else {
20 godotenv.Load(".env.sample")
21 }
22}
23
// testModel describes a single model entry for provider tests.
type testModel struct {
	// name and model are free-form strings; presumably a human-readable
	// label and the provider's model identifier — TODO confirm with users
	// of this type.
	name, model string
	// reasoning presumably marks models that support reasoning/thinking —
	// TODO confirm with users of this type.
	reasoning bool
}
29
30type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
31
32type builderPair struct {
33 name string
34 builder builderFunc
35 providerOptions fantasy.ProviderOptions
36 prepareStep fantasy.PrepareStepFunction
37}
38
39func testCommon(t *testing.T, pairs []builderPair) {
40 for _, pair := range pairs {
41 t.Run(pair.name, func(t *testing.T) {
42 testSimple(t, pair)
43 testTool(t, pair)
44 testMultiTool(t, pair)
45 })
46 }
47}
48
49func testSimple(t *testing.T, pair builderPair) {
50 checkResult := func(t *testing.T, result *fantasy.AgentResult) {
51 options := []string{"Oi", "oi", "Olá", "olá"}
52 got := result.Response.Content.Text()
53 require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
54 }
55
56 t.Run("simple", func(t *testing.T) {
57 r := newRecorder(t)
58
59 languageModel, err := pair.builder(t, r)
60 require.NoError(t, err, "failed to build language model")
61
62 agent := fantasy.NewAgent(
63 languageModel,
64 fantasy.WithSystemPrompt("You are a helpful assistant"),
65 )
66 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
67 Prompt: "Say hi in Portuguese",
68 ProviderOptions: pair.providerOptions,
69 MaxOutputTokens: fantasy.Opt(int64(4000)),
70 PrepareStep: pair.prepareStep,
71 })
72 require.NoError(t, err, "failed to generate")
73 checkResult(t, result)
74 })
75 t.Run("simple streaming", func(t *testing.T) {
76 r := newRecorder(t)
77
78 languageModel, err := pair.builder(t, r)
79 require.NoError(t, err, "failed to build language model")
80
81 agent := fantasy.NewAgent(
82 languageModel,
83 fantasy.WithSystemPrompt("You are a helpful assistant"),
84 )
85 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
86 Prompt: "Say hi in Portuguese",
87 ProviderOptions: pair.providerOptions,
88 MaxOutputTokens: fantasy.Opt(int64(4000)),
89 PrepareStep: pair.prepareStep,
90 })
91 require.NoError(t, err, "failed to generate")
92 checkResult(t, result)
93 })
94}
95
96func testTool(t *testing.T, pair builderPair) {
97 type WeatherInput struct {
98 Location string `json:"location" description:"the city"`
99 }
100
101 weatherTool := fantasy.NewAgentTool(
102 "weather",
103 "Get weather information for a location",
104 func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
105 return fantasy.NewTextResponse("40 C"), nil
106 },
107 )
108 checkResult := func(t *testing.T, result *fantasy.AgentResult) {
109 require.GreaterOrEqual(t, len(result.Steps), 2)
110
111 var toolCalls []fantasy.ToolCallContent
112 for _, content := range result.Steps[0].Content {
113 if content.GetType() == fantasy.ContentTypeToolCall {
114 toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
115 }
116 }
117 for _, tc := range toolCalls {
118 require.False(t, tc.Invalid)
119 }
120 require.Len(t, toolCalls, 1)
121 require.Equal(t, toolCalls[0].ToolName, "weather")
122
123 want1 := "Florence"
124 want2 := "40"
125 got := result.Response.Content.Text()
126 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
127 }
128
129 t.Run("tool", func(t *testing.T) {
130 r := newRecorder(t)
131
132 languageModel, err := pair.builder(t, r)
133 require.NoError(t, err, "failed to build language model")
134
135 agent := fantasy.NewAgent(
136 languageModel,
137 fantasy.WithSystemPrompt("You are a helpful assistant"),
138 fantasy.WithTools(weatherTool),
139 )
140 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
141 Prompt: "What's the weather in Florence,Italy?",
142 ProviderOptions: pair.providerOptions,
143 MaxOutputTokens: fantasy.Opt(int64(4000)),
144 PrepareStep: pair.prepareStep,
145 })
146 require.NoError(t, err, "failed to generate")
147 checkResult(t, result)
148 })
149 t.Run("tool streaming", func(t *testing.T) {
150 r := newRecorder(t)
151
152 languageModel, err := pair.builder(t, r)
153 require.NoError(t, err, "failed to build language model")
154
155 agent := fantasy.NewAgent(
156 languageModel,
157 fantasy.WithSystemPrompt("You are a helpful assistant"),
158 fantasy.WithTools(weatherTool),
159 )
160 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
161 Prompt: "What's the weather in Florence,Italy?",
162 ProviderOptions: pair.providerOptions,
163 MaxOutputTokens: fantasy.Opt(int64(4000)),
164 PrepareStep: pair.prepareStep,
165 })
166 require.NoError(t, err, "failed to generate")
167 checkResult(t, result)
168 })
169}
170
171func testMultiTool(t *testing.T, pair builderPair) {
172 // Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
173 if strings.Contains(pair.name, "azure") {
174 t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
175 }
176 if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
177 t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
178 }
179 if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
180 t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
181 }
182 if strings.Contains(pair.name, "openai") && strings.Contains(pair.name, "o4-mini") {
183 t.Skip("skipping multi-tool tests for openai o4-mini it for some reason is not doing parallel tool calls even if asked")
184 }
185
186 type CalculatorInput struct {
187 A int `json:"a" description:"first number"`
188 B int `json:"b" description:"second number"`
189 }
190
191 addTool := fantasy.NewAgentTool(
192 "add",
193 "Add two numbers",
194 func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
195 result := input.A + input.B
196 return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
197 },
198 )
199 multiplyTool := fantasy.NewAgentTool(
200 "multiply",
201 "Multiply two numbers",
202 func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
203 result := input.A * input.B
204 return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
205 },
206 )
207 checkResult := func(t *testing.T, result *fantasy.AgentResult) {
208 require.Len(t, result.Steps, 2)
209
210 var toolCalls []fantasy.ToolCallContent
211 for _, content := range result.Steps[0].Content {
212 if content.GetType() == fantasy.ContentTypeToolCall {
213 toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
214 }
215 }
216 for _, tc := range toolCalls {
217 require.False(t, tc.Invalid)
218 }
219 require.Len(t, toolCalls, 2)
220
221 finalText := result.Response.Content.Text()
222 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
223 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
224 }
225
226 t.Run("multi tool", func(t *testing.T) {
227 r := newRecorder(t)
228
229 languageModel, err := pair.builder(t, r)
230 require.NoError(t, err, "failed to build language model")
231
232 agent := fantasy.NewAgent(
233 languageModel,
234 fantasy.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
235 fantasy.WithTools(addTool),
236 fantasy.WithTools(multiplyTool),
237 )
238 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
239 Prompt: "Add and multiply the number 2 and 3",
240 ProviderOptions: pair.providerOptions,
241 MaxOutputTokens: fantasy.Opt(int64(4000)),
242 PrepareStep: pair.prepareStep,
243 })
244 require.NoError(t, err, "failed to generate")
245 checkResult(t, result)
246 })
247 t.Run("multi tool streaming", func(t *testing.T) {
248 r := newRecorder(t)
249
250 languageModel, err := pair.builder(t, r)
251 require.NoError(t, err, "failed to build language model")
252
253 agent := fantasy.NewAgent(
254 languageModel,
255 fantasy.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
256 fantasy.WithTools(addTool),
257 fantasy.WithTools(multiplyTool),
258 )
259 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
260 Prompt: "Add and multiply the number 2 and 3",
261 ProviderOptions: pair.providerOptions,
262 MaxOutputTokens: fantasy.Opt(int64(4000)),
263 PrepareStep: pair.prepareStep,
264 })
265 require.NoError(t, err, "failed to generate")
266 checkResult(t, result)
267 })
268}
269
270func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *fantasy.AgentResult)) {
271 for _, pair := range pairs {
272 t.Run(pair.name, func(t *testing.T) {
273 t.Run("thinking", func(t *testing.T) {
274 r := newRecorder(t)
275
276 languageModel, err := pair.builder(t, r)
277 require.NoError(t, err, "failed to build language model")
278
279 type WeatherInput struct {
280 Location string `json:"location" description:"the city"`
281 }
282
283 weatherTool := fantasy.NewAgentTool(
284 "weather",
285 "Get weather information for a location",
286 func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
287 return fantasy.NewTextResponse("40 C"), nil
288 },
289 )
290
291 agent := fantasy.NewAgent(
292 languageModel,
293 fantasy.WithSystemPrompt("You are a helpful assistant"),
294 fantasy.WithTools(weatherTool),
295 )
296 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
297 Prompt: "What's the weather in Florence, Italy?",
298 ProviderOptions: pair.providerOptions,
299 PrepareStep: pair.prepareStep,
300 })
301 require.NoError(t, err, "failed to generate")
302
303 want1 := "Florence"
304 want2 := "40"
305 got := result.Response.Content.Text()
306 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
307
308 thinkChecks(t, result)
309 })
310 t.Run("thinking-streaming", func(t *testing.T) {
311 r := newRecorder(t)
312
313 languageModel, err := pair.builder(t, r)
314 require.NoError(t, err, "failed to build language model")
315
316 type WeatherInput struct {
317 Location string `json:"location" description:"the city"`
318 }
319
320 weatherTool := fantasy.NewAgentTool(
321 "weather",
322 "Get weather information for a location",
323 func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
324 return fantasy.NewTextResponse("40 C"), nil
325 },
326 )
327
328 agent := fantasy.NewAgent(
329 languageModel,
330 fantasy.WithSystemPrompt("You are a helpful assistant"),
331 fantasy.WithTools(weatherTool),
332 )
333 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
334 Prompt: "What's the weather in Florence, Italy?",
335 ProviderOptions: pair.providerOptions,
336 PrepareStep: pair.prepareStep,
337 })
338 require.NoError(t, err, "failed to generate")
339
340 want1 := "Florence"
341 want2 := "40"
342 got := result.Response.Content.Text()
343 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
344
345 thinkChecks(t, result)
346 })
347 })
348 }
349}
350
// containsAny reports whether s contains at least one of subs. It returns
// false when no substrings are given.
func containsAny(s string, subs ...string) bool {
	found := false
	// Stop scanning as soon as one candidate matches.
	for i := 0; i < len(subs) && !found; i++ {
		found = strings.Contains(s, subs[i])
	}
	return found
}