1package providertests
2
3import (
4 "context"
5 "os"
6 "strconv"
7 "strings"
8 "testing"
9
10 "github.com/charmbracelet/fantasy/ai"
11 "github.com/joho/godotenv"
12 "github.com/stretchr/testify/require"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16func init() {
17 if _, err := os.Stat(".env"); err == nil {
18 godotenv.Load(".env")
19 } else {
20 godotenv.Load(".env.sample")
21 }
22}
23
// testModel describes one model exercised by the provider tests.
//
// name is a label for the model and model is presumably the identifier passed
// to the provider; reasoning presumably flags models that produce reasoning
// output (confirm against the call sites).
type testModel struct {
	name, model string
	reasoning   bool
}
29
// builderFunc constructs the language model under test. Each subtest creates
// a fresh recorder and passes it in, presumably so the builder can route the
// provider's HTTP traffic through it (confirm in the concrete builders).
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
31
// builderPair bundles everything needed to run the shared tests against one
// provider/model combination.
type builderPair struct {
	// name is used as the subtest name and is substring-matched to skip
	// tests a provider does not support.
	name string
	// builder is invoked once per subtest with a fresh recorder.
	builder builderFunc
	// providerOptions is forwarded unchanged on every Generate/Stream call.
	providerOptions ai.ProviderOptions
}
37
38func testCommon(t *testing.T, pairs []builderPair) {
39 for _, pair := range pairs {
40 t.Run(pair.name, func(t *testing.T) {
41 testSimple(t, pair)
42 testTool(t, pair)
43 testMultiTool(t, pair)
44 })
45 }
46}
47
48func testSimple(t *testing.T, pair builderPair) {
49 checkResult := func(t *testing.T, result *ai.AgentResult) {
50 options := []string{"Oi", "oi", "Olá", "olá"}
51 got := result.Response.Content.Text()
52 require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
53 }
54
55 t.Run("simple", func(t *testing.T) {
56 r := newRecorder(t)
57
58 languageModel, err := pair.builder(r)
59 require.NoError(t, err, "failed to build language model")
60
61 agent := ai.NewAgent(
62 languageModel,
63 ai.WithSystemPrompt("You are a helpful assistant"),
64 )
65 result, err := agent.Generate(t.Context(), ai.AgentCall{
66 Prompt: "Say hi in Portuguese",
67 ProviderOptions: pair.providerOptions,
68 MaxOutputTokens: ai.Opt(int64(4000)),
69 })
70 require.NoError(t, err, "failed to generate")
71 checkResult(t, result)
72 })
73 t.Run("simple streaming", func(t *testing.T) {
74 r := newRecorder(t)
75
76 languageModel, err := pair.builder(r)
77 require.NoError(t, err, "failed to build language model")
78
79 agent := ai.NewAgent(
80 languageModel,
81 ai.WithSystemPrompt("You are a helpful assistant"),
82 )
83 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
84 Prompt: "Say hi in Portuguese",
85 ProviderOptions: pair.providerOptions,
86 MaxOutputTokens: ai.Opt(int64(4000)),
87 })
88 require.NoError(t, err, "failed to generate")
89 checkResult(t, result)
90 })
91}
92
93func testTool(t *testing.T, pair builderPair) {
94 type WeatherInput struct {
95 Location string `json:"location" description:"the city"`
96 }
97
98 weatherTool := ai.NewAgentTool(
99 "weather",
100 "Get weather information for a location",
101 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
102 return ai.NewTextResponse("40 C"), nil
103 },
104 )
105 checkResult := func(t *testing.T, result *ai.AgentResult) {
106 require.GreaterOrEqual(t, len(result.Steps), 2)
107
108 var toolCalls []ai.ToolCallContent
109 for _, content := range result.Steps[0].Content {
110 if content.GetType() == ai.ContentTypeToolCall {
111 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
112 }
113 }
114 for _, tc := range toolCalls {
115 require.False(t, tc.Invalid)
116 }
117 require.Len(t, toolCalls, 1)
118 require.Equal(t, toolCalls[0].ToolName, "weather")
119
120 want1 := "Florence"
121 want2 := "40"
122 got := result.Response.Content.Text()
123 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
124 }
125
126 t.Run("tool", func(t *testing.T) {
127 r := newRecorder(t)
128
129 languageModel, err := pair.builder(r)
130 require.NoError(t, err, "failed to build language model")
131
132 agent := ai.NewAgent(
133 languageModel,
134 ai.WithSystemPrompt("You are a helpful assistant"),
135 ai.WithTools(weatherTool),
136 )
137 result, err := agent.Generate(t.Context(), ai.AgentCall{
138 Prompt: "What's the weather in Florence,Italy?",
139 ProviderOptions: pair.providerOptions,
140 MaxOutputTokens: ai.Opt(int64(4000)),
141 })
142 require.NoError(t, err, "failed to generate")
143 checkResult(t, result)
144 })
145 t.Run("tool streaming", func(t *testing.T) {
146 r := newRecorder(t)
147
148 languageModel, err := pair.builder(r)
149 require.NoError(t, err, "failed to build language model")
150
151 agent := ai.NewAgent(
152 languageModel,
153 ai.WithSystemPrompt("You are a helpful assistant"),
154 ai.WithTools(weatherTool),
155 )
156 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
157 Prompt: "What's the weather in Florence,Italy?",
158 ProviderOptions: pair.providerOptions,
159 MaxOutputTokens: ai.Opt(int64(4000)),
160 })
161 require.NoError(t, err, "failed to generate")
162 checkResult(t, result)
163 })
164}
165
166func testMultiTool(t *testing.T, pair builderPair) {
167 // Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
168 if strings.Contains(pair.name, "azure") {
169 t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
170 }
171 if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
172 t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
173 }
174 if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
175 t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
176 }
177 if strings.Contains(pair.name, "openai") && strings.Contains(pair.name, "o4-mini") {
178 t.Skip("skipping multi-tool tests for openai o4-mini it for some reason is not doing parallel tool calls even if asked")
179 }
180
181 type CalculatorInput struct {
182 A int `json:"a" description:"first number"`
183 B int `json:"b" description:"second number"`
184 }
185
186 addTool := ai.NewAgentTool(
187 "add",
188 "Add two numbers",
189 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
190 result := input.A + input.B
191 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
192 },
193 )
194 multiplyTool := ai.NewAgentTool(
195 "multiply",
196 "Multiply two numbers",
197 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
198 result := input.A * input.B
199 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
200 },
201 )
202 checkResult := func(t *testing.T, result *ai.AgentResult) {
203 require.Len(t, result.Steps, 2)
204
205 var toolCalls []ai.ToolCallContent
206 for _, content := range result.Steps[0].Content {
207 if content.GetType() == ai.ContentTypeToolCall {
208 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
209 }
210 }
211 for _, tc := range toolCalls {
212 require.False(t, tc.Invalid)
213 }
214 require.Len(t, toolCalls, 2)
215
216 finalText := result.Response.Content.Text()
217 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
218 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
219 }
220
221 t.Run("multi tool", func(t *testing.T) {
222 r := newRecorder(t)
223
224 languageModel, err := pair.builder(r)
225 require.NoError(t, err, "failed to build language model")
226
227 agent := ai.NewAgent(
228 languageModel,
229 ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
230 ai.WithTools(addTool),
231 ai.WithTools(multiplyTool),
232 )
233 result, err := agent.Generate(t.Context(), ai.AgentCall{
234 Prompt: "Add and multiply the number 2 and 3",
235 ProviderOptions: pair.providerOptions,
236 MaxOutputTokens: ai.Opt(int64(4000)),
237 })
238 require.NoError(t, err, "failed to generate")
239 checkResult(t, result)
240 })
241 t.Run("multi tool streaming", func(t *testing.T) {
242 r := newRecorder(t)
243
244 languageModel, err := pair.builder(r)
245 require.NoError(t, err, "failed to build language model")
246
247 agent := ai.NewAgent(
248 languageModel,
249 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
250 ai.WithTools(addTool),
251 ai.WithTools(multiplyTool),
252 )
253 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
254 Prompt: "Add and multiply the number 2 and 3",
255 ProviderOptions: pair.providerOptions,
256 MaxOutputTokens: ai.Opt(int64(4000)),
257 })
258 require.NoError(t, err, "failed to generate")
259 checkResult(t, result)
260 })
261}
262
263func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
264 for _, pair := range pairs {
265 t.Run(pair.name, func(t *testing.T) {
266 t.Run("thinking", func(t *testing.T) {
267 r := newRecorder(t)
268
269 languageModel, err := pair.builder(r)
270 require.NoError(t, err, "failed to build language model")
271
272 type WeatherInput struct {
273 Location string `json:"location" description:"the city"`
274 }
275
276 weatherTool := ai.NewAgentTool(
277 "weather",
278 "Get weather information for a location",
279 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
280 return ai.NewTextResponse("40 C"), nil
281 },
282 )
283
284 agent := ai.NewAgent(
285 languageModel,
286 ai.WithSystemPrompt("You are a helpful assistant"),
287 ai.WithTools(weatherTool),
288 )
289 result, err := agent.Generate(t.Context(), ai.AgentCall{
290 Prompt: "What's the weather in Florence, Italy?",
291 ProviderOptions: pair.providerOptions,
292 })
293 require.NoError(t, err, "failed to generate")
294
295 want1 := "Florence"
296 want2 := "40"
297 got := result.Response.Content.Text()
298 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
299
300 thinkChecks(t, result)
301 })
302 t.Run("thinking-streaming", func(t *testing.T) {
303 r := newRecorder(t)
304
305 languageModel, err := pair.builder(r)
306 require.NoError(t, err, "failed to build language model")
307
308 type WeatherInput struct {
309 Location string `json:"location" description:"the city"`
310 }
311
312 weatherTool := ai.NewAgentTool(
313 "weather",
314 "Get weather information for a location",
315 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
316 return ai.NewTextResponse("40 C"), nil
317 },
318 )
319
320 agent := ai.NewAgent(
321 languageModel,
322 ai.WithSystemPrompt("You are a helpful assistant"),
323 ai.WithTools(weatherTool),
324 )
325 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
326 Prompt: "What's the weather in Florence, Italy?",
327 ProviderOptions: pair.providerOptions,
328 })
329 require.NoError(t, err, "failed to generate")
330
331 want1 := "Florence"
332 want2 := "40"
333 got := result.Response.Content.Text()
334 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
335
336 thinkChecks(t, result)
337 })
338 })
339 }
340}
341
// containsAny reports whether s contains at least one of subs as a substring.
// It returns false when subs is empty; an empty string in subs always matches.
func containsAny(s string, subs ...string) bool {
	found := false
	for i := range subs {
		if strings.Contains(s, subs[i]) {
			found = true
			break
		}
	}
	return found
}