1package providertests
2
3import (
4 "context"
5 "os"
6 "strconv"
7 "strings"
8 "testing"
9
10 "github.com/charmbracelet/fantasy/ai"
11 "github.com/joho/godotenv"
12 "github.com/stretchr/testify/require"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16func init() {
17 if _, err := os.Stat(".env"); err == nil {
18 godotenv.Load(".env")
19 } else {
20 godotenv.Load(".env.sample")
21 }
22}
23
// testModel groups a name, a model string and a reasoning flag.
// NOTE(review): presumably name is the display/subtest name, model the
// provider-side model identifier, and reasoning marks models that support
// thinking output — confirm against the callers, which are not in this file.
type testModel struct {
	name, model string
	reasoning   bool
}
29
30type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
31
32type builderPair struct {
33 name string
34 builder builderFunc
35 providerOptions ai.ProviderOptions
36}
37
38func testCommon(t *testing.T, pairs []builderPair) {
39 for _, pair := range pairs {
40 t.Run(pair.name, func(t *testing.T) {
41 testSimple(t, pair)
42 testTool(t, pair)
43 testMultiTool(t, pair)
44 })
45 }
46}
47
48func testSimple(t *testing.T, pair builderPair) {
49 checkResult := func(t *testing.T, result *ai.AgentResult) {
50 options := []string{"Oi", "oi", "Olá", "olá"}
51 got := result.Response.Content.Text()
52 require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
53 }
54
55 t.Run("simple", func(t *testing.T) {
56 r := newRecorder(t)
57
58 languageModel, err := pair.builder(r)
59 require.NoError(t, err, "failed to build language model")
60
61 agent := ai.NewAgent(
62 languageModel,
63 ai.WithSystemPrompt("You are a helpful assistant"),
64 )
65 result, err := agent.Generate(t.Context(), ai.AgentCall{
66 Prompt: "Say hi in Portuguese",
67 ProviderOptions: pair.providerOptions,
68 MaxOutputTokens: ai.IntOption(4000),
69 })
70 require.NoError(t, err, "failed to generate")
71 checkResult(t, result)
72 })
73 t.Run("simple streaming", func(t *testing.T) {
74 r := newRecorder(t)
75
76 languageModel, err := pair.builder(r)
77 require.NoError(t, err, "failed to build language model")
78
79 agent := ai.NewAgent(
80 languageModel,
81 ai.WithSystemPrompt("You are a helpful assistant"),
82 )
83 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
84 Prompt: "Say hi in Portuguese",
85 ProviderOptions: pair.providerOptions,
86 MaxOutputTokens: ai.IntOption(4000),
87 })
88 require.NoError(t, err, "failed to generate")
89 checkResult(t, result)
90 })
91}
92
93func testTool(t *testing.T, pair builderPair) {
94 type WeatherInput struct {
95 Location string `json:"location" description:"the city"`
96 }
97
98 weatherTool := ai.NewAgentTool(
99 "weather",
100 "Get weather information for a location",
101 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
102 return ai.NewTextResponse("40 C"), nil
103 },
104 )
105 checkResult := func(t *testing.T, result *ai.AgentResult) {
106 require.GreaterOrEqual(t, len(result.Steps), 2)
107
108 var toolCalls []ai.ToolCallContent
109 for _, content := range result.Steps[0].Content {
110 if content.GetType() == ai.ContentTypeToolCall {
111 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
112 }
113 }
114 for _, tc := range toolCalls {
115 require.False(t, tc.Invalid)
116 }
117 require.Len(t, toolCalls, 1)
118 require.Equal(t, toolCalls[0].ToolName, "weather")
119
120 want1 := "Florence"
121 want2 := "40"
122 got := result.Response.Content.Text()
123 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
124 }
125
126 t.Run("tool", func(t *testing.T) {
127 r := newRecorder(t)
128
129 languageModel, err := pair.builder(r)
130 require.NoError(t, err, "failed to build language model")
131
132 agent := ai.NewAgent(
133 languageModel,
134 ai.WithSystemPrompt("You are a helpful assistant"),
135 ai.WithTools(weatherTool),
136 )
137 result, err := agent.Generate(t.Context(), ai.AgentCall{
138 Prompt: "What's the weather in Florence,Italy?",
139 ProviderOptions: pair.providerOptions,
140 MaxOutputTokens: ai.IntOption(4000),
141 })
142 require.NoError(t, err, "failed to generate")
143 checkResult(t, result)
144 })
145 t.Run("tool streaming", func(t *testing.T) {
146 r := newRecorder(t)
147
148 languageModel, err := pair.builder(r)
149 require.NoError(t, err, "failed to build language model")
150
151 agent := ai.NewAgent(
152 languageModel,
153 ai.WithSystemPrompt("You are a helpful assistant"),
154 ai.WithTools(weatherTool),
155 )
156 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
157 Prompt: "What's the weather in Florence,Italy?",
158 ProviderOptions: pair.providerOptions,
159 MaxOutputTokens: ai.IntOption(4000),
160 })
161 require.NoError(t, err, "failed to generate")
162 checkResult(t, result)
163 })
164}
165
166func testMultiTool(t *testing.T, pair builderPair) {
167 // Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
168 if strings.Contains(pair.name, "azure") {
169 t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
170 }
171 if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
172 t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
173 }
174 if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
175 t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
176 }
177
178 type CalculatorInput struct {
179 A int `json:"a" description:"first number"`
180 B int `json:"b" description:"second number"`
181 }
182
183 addTool := ai.NewAgentTool(
184 "add",
185 "Add two numbers",
186 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
187 result := input.A + input.B
188 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
189 },
190 )
191 multiplyTool := ai.NewAgentTool(
192 "multiply",
193 "Multiply two numbers",
194 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
195 result := input.A * input.B
196 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
197 },
198 )
199 checkResult := func(t *testing.T, result *ai.AgentResult) {
200 require.Len(t, result.Steps, 2)
201
202 var toolCalls []ai.ToolCallContent
203 for _, content := range result.Steps[0].Content {
204 if content.GetType() == ai.ContentTypeToolCall {
205 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
206 }
207 }
208 for _, tc := range toolCalls {
209 require.False(t, tc.Invalid)
210 }
211 require.Len(t, toolCalls, 2)
212
213 finalText := result.Response.Content.Text()
214 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
215 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
216 }
217
218 t.Run("multi tool", func(t *testing.T) {
219 r := newRecorder(t)
220
221 languageModel, err := pair.builder(r)
222 require.NoError(t, err, "failed to build language model")
223
224 agent := ai.NewAgent(
225 languageModel,
226 ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
227 ai.WithTools(addTool),
228 ai.WithTools(multiplyTool),
229 )
230 result, err := agent.Generate(t.Context(), ai.AgentCall{
231 Prompt: "Add and multiply the number 2 and 3",
232 ProviderOptions: pair.providerOptions,
233 MaxOutputTokens: ai.IntOption(4000),
234 })
235 require.NoError(t, err, "failed to generate")
236 checkResult(t, result)
237 })
238 t.Run("multi tool streaming", func(t *testing.T) {
239 r := newRecorder(t)
240
241 languageModel, err := pair.builder(r)
242 require.NoError(t, err, "failed to build language model")
243
244 agent := ai.NewAgent(
245 languageModel,
246 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
247 ai.WithTools(addTool),
248 ai.WithTools(multiplyTool),
249 )
250 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
251 Prompt: "Add and multiply the number 2 and 3",
252 ProviderOptions: pair.providerOptions,
253 MaxOutputTokens: ai.IntOption(4000),
254 })
255 require.NoError(t, err, "failed to generate")
256 checkResult(t, result)
257 })
258}
259
260func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
261 for _, pair := range pairs {
262 t.Run(pair.name, func(t *testing.T) {
263 t.Run("thinking", func(t *testing.T) {
264 r := newRecorder(t)
265
266 languageModel, err := pair.builder(r)
267 require.NoError(t, err, "failed to build language model")
268
269 type WeatherInput struct {
270 Location string `json:"location" description:"the city"`
271 }
272
273 weatherTool := ai.NewAgentTool(
274 "weather",
275 "Get weather information for a location",
276 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
277 return ai.NewTextResponse("40 C"), nil
278 },
279 )
280
281 agent := ai.NewAgent(
282 languageModel,
283 ai.WithSystemPrompt("You are a helpful assistant"),
284 ai.WithTools(weatherTool),
285 )
286 result, err := agent.Generate(t.Context(), ai.AgentCall{
287 Prompt: "What's the weather in Florence, Italy?",
288 ProviderOptions: pair.providerOptions,
289 })
290 require.NoError(t, err, "failed to generate")
291
292 want1 := "Florence"
293 want2 := "40"
294 got := result.Response.Content.Text()
295 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
296
297 thinkChecks(t, result)
298 })
299 t.Run("thinking-streaming", func(t *testing.T) {
300 r := newRecorder(t)
301
302 languageModel, err := pair.builder(r)
303 require.NoError(t, err, "failed to build language model")
304
305 type WeatherInput struct {
306 Location string `json:"location" description:"the city"`
307 }
308
309 weatherTool := ai.NewAgentTool(
310 "weather",
311 "Get weather information for a location",
312 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
313 return ai.NewTextResponse("40 C"), nil
314 },
315 )
316
317 agent := ai.NewAgent(
318 languageModel,
319 ai.WithSystemPrompt("You are a helpful assistant"),
320 ai.WithTools(weatherTool),
321 )
322 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
323 Prompt: "What's the weather in Florence, Italy?",
324 ProviderOptions: pair.providerOptions,
325 })
326 require.NoError(t, err, "failed to generate")
327
328 want1 := "Florence"
329 want2 := "40"
330 got := result.Response.Content.Text()
331 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
332
333 thinkChecks(t, result)
334 })
335 })
336 }
337}
338
// containsAny reports whether s contains at least one of subs as a substring.
// With no candidates it returns false; an empty candidate always matches.
func containsAny(s string, subs ...string) bool {
	found := false
	for i := 0; i < len(subs) && !found; i++ {
		found = strings.Contains(s, subs[i])
	}
	return found
}