1package providertests
2
3import (
4 "context"
5 "os"
6 "strconv"
7 "strings"
8 "testing"
9
10 "github.com/charmbracelet/fantasy/ai"
11 "github.com/joho/godotenv"
12 "github.com/stretchr/testify/require"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16func init() {
17 if _, err := os.Stat(".env"); err == nil {
18 godotenv.Load(".env")
19 } else {
20 godotenv.Load(".env.sample")
21 }
22}
23
// testModel describes one model entry for the provider tests.
//
// NOTE(review): nothing in this part of the file reads these fields, so the
// field meanings below are inferred from their names; confirm against callers.
type testModel struct {
	name, model string // presumably a display name and the provider's model identifier
	reasoning   bool   // presumably marks models that support reasoning/thinking
}
29
30type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
31
32type builderPair struct {
33 name string
34 builder builderFunc
35 providerOptions ai.ProviderOptions
36}
37
38func testCommon(t *testing.T, pairs []builderPair) {
39 for _, pair := range pairs {
40 t.Run(pair.name, func(t *testing.T) {
41 testSimple(t, pair)
42 testTool(t, pair)
43 testMultiTool(t, pair)
44 })
45 }
46}
47
48func testSimple(t *testing.T, pair builderPair) {
49 checkResult := func(t *testing.T, result *ai.AgentResult) {
50 options := []string{"Oi", "oi", "Olá", "olá"}
51 got := result.Response.Content.Text()
52 require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
53 }
54
55 t.Run("simple", func(t *testing.T) {
56 r := newRecorder(t)
57
58 languageModel, err := pair.builder(r)
59 require.NoError(t, err, "failed to build language model")
60
61 agent := ai.NewAgent(
62 languageModel,
63 ai.WithSystemPrompt("You are a helpful assistant"),
64 )
65 result, err := agent.Generate(t.Context(), ai.AgentCall{
66 Prompt: "Say hi in Portuguese",
67 ProviderOptions: pair.providerOptions,
68 MaxOutputTokens: ai.IntOption(4000),
69 })
70 require.NoError(t, err, "failed to generate")
71 checkResult(t, result)
72 })
73 t.Run("simple streaming", func(t *testing.T) {
74 r := newRecorder(t)
75
76 languageModel, err := pair.builder(r)
77 require.NoError(t, err, "failed to build language model")
78
79 agent := ai.NewAgent(
80 languageModel,
81 ai.WithSystemPrompt("You are a helpful assistant"),
82 )
83 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
84 Prompt: "Say hi in Portuguese",
85 ProviderOptions: pair.providerOptions,
86 MaxOutputTokens: ai.IntOption(4000),
87 })
88 require.NoError(t, err, "failed to generate")
89 checkResult(t, result)
90 })
91}
92
93func testTool(t *testing.T, pair builderPair) {
94 type WeatherInput struct {
95 Location string `json:"location" description:"the city"`
96 }
97
98 weatherTool := ai.NewAgentTool(
99 "weather",
100 "Get weather information for a location",
101 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
102 return ai.NewTextResponse("40 C"), nil
103 },
104 )
105 checkResult := func(t *testing.T, result *ai.AgentResult) {
106 require.GreaterOrEqual(t, len(result.Steps), 2)
107
108 var toolCalls []ai.ToolCallContent
109 for _, content := range result.Steps[0].Content {
110 if content.GetType() == ai.ContentTypeToolCall {
111 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
112 }
113 }
114 for _, tc := range toolCalls {
115 require.False(t, tc.Invalid)
116 }
117 require.Len(t, toolCalls, 1)
118 require.Equal(t, toolCalls[0].ToolName, "weather")
119
120 want1 := "Florence"
121 want2 := "40"
122 got := result.Response.Content.Text()
123 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
124 }
125
126 t.Run("tool", func(t *testing.T) {
127 r := newRecorder(t)
128
129 languageModel, err := pair.builder(r)
130 require.NoError(t, err, "failed to build language model")
131
132 agent := ai.NewAgent(
133 languageModel,
134 ai.WithSystemPrompt("You are a helpful assistant"),
135 ai.WithTools(weatherTool),
136 )
137 result, err := agent.Generate(t.Context(), ai.AgentCall{
138 Prompt: "What's the weather in Florence,Italy?",
139 ProviderOptions: pair.providerOptions,
140 MaxOutputTokens: ai.IntOption(4000),
141 })
142 require.NoError(t, err, "failed to generate")
143 checkResult(t, result)
144 })
145 t.Run("tool streaming", func(t *testing.T) {
146 r := newRecorder(t)
147
148 languageModel, err := pair.builder(r)
149 require.NoError(t, err, "failed to build language model")
150
151 agent := ai.NewAgent(
152 languageModel,
153 ai.WithSystemPrompt("You are a helpful assistant"),
154 ai.WithTools(weatherTool),
155 )
156 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
157 Prompt: "What's the weather in Florence,Italy?",
158 ProviderOptions: pair.providerOptions,
159 MaxOutputTokens: ai.IntOption(4000),
160 })
161 require.NoError(t, err, "failed to generate")
162 checkResult(t, result)
163 })
164}
165
166func testMultiTool(t *testing.T, pair builderPair) {
167 // Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
168 if strings.Contains(pair.name, "azure") {
169 t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
170 }
171 if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
172 t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
173 }
174
175 type CalculatorInput struct {
176 A int `json:"a" description:"first number"`
177 B int `json:"b" description:"second number"`
178 }
179
180 addTool := ai.NewAgentTool(
181 "add",
182 "Add two numbers",
183 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
184 result := input.A + input.B
185 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
186 },
187 )
188 multiplyTool := ai.NewAgentTool(
189 "multiply",
190 "Multiply two numbers",
191 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
192 result := input.A * input.B
193 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
194 },
195 )
196 checkResult := func(t *testing.T, result *ai.AgentResult) {
197 require.Len(t, result.Steps, 2)
198
199 var toolCalls []ai.ToolCallContent
200 for _, content := range result.Steps[0].Content {
201 if content.GetType() == ai.ContentTypeToolCall {
202 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
203 }
204 }
205 for _, tc := range toolCalls {
206 require.False(t, tc.Invalid)
207 }
208 require.Len(t, toolCalls, 2)
209
210 finalText := result.Response.Content.Text()
211 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
212 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
213 }
214
215 t.Run("multi tool", func(t *testing.T) {
216 r := newRecorder(t)
217
218 languageModel, err := pair.builder(r)
219 require.NoError(t, err, "failed to build language model")
220
221 agent := ai.NewAgent(
222 languageModel,
223 ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
224 ai.WithTools(addTool),
225 ai.WithTools(multiplyTool),
226 )
227 result, err := agent.Generate(t.Context(), ai.AgentCall{
228 Prompt: "Add and multiply the number 2 and 3",
229 ProviderOptions: pair.providerOptions,
230 MaxOutputTokens: ai.IntOption(4000),
231 })
232 require.NoError(t, err, "failed to generate")
233 checkResult(t, result)
234 })
235 t.Run("multi tool streaming", func(t *testing.T) {
236 r := newRecorder(t)
237
238 languageModel, err := pair.builder(r)
239 require.NoError(t, err, "failed to build language model")
240
241 agent := ai.NewAgent(
242 languageModel,
243 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
244 ai.WithTools(addTool),
245 ai.WithTools(multiplyTool),
246 )
247 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
248 Prompt: "Add and multiply the number 2 and 3",
249 ProviderOptions: pair.providerOptions,
250 MaxOutputTokens: ai.IntOption(4000),
251 })
252 require.NoError(t, err, "failed to generate")
253 checkResult(t, result)
254 })
255}
256
257func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
258 for _, pair := range pairs {
259 t.Run(pair.name, func(t *testing.T) {
260 t.Run("thinking", func(t *testing.T) {
261 r := newRecorder(t)
262
263 languageModel, err := pair.builder(r)
264 require.NoError(t, err, "failed to build language model")
265
266 type WeatherInput struct {
267 Location string `json:"location" description:"the city"`
268 }
269
270 weatherTool := ai.NewAgentTool(
271 "weather",
272 "Get weather information for a location",
273 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
274 return ai.NewTextResponse("40 C"), nil
275 },
276 )
277
278 agent := ai.NewAgent(
279 languageModel,
280 ai.WithSystemPrompt("You are a helpful assistant"),
281 ai.WithTools(weatherTool),
282 )
283 result, err := agent.Generate(t.Context(), ai.AgentCall{
284 Prompt: "What's the weather in Florence, Italy?",
285 ProviderOptions: pair.providerOptions,
286 })
287 require.NoError(t, err, "failed to generate")
288
289 want1 := "Florence"
290 want2 := "40"
291 got := result.Response.Content.Text()
292 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
293
294 thinkChecks(t, result)
295 })
296 t.Run("thinking-streaming", func(t *testing.T) {
297 r := newRecorder(t)
298
299 languageModel, err := pair.builder(r)
300 require.NoError(t, err, "failed to build language model")
301
302 type WeatherInput struct {
303 Location string `json:"location" description:"the city"`
304 }
305
306 weatherTool := ai.NewAgentTool(
307 "weather",
308 "Get weather information for a location",
309 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
310 return ai.NewTextResponse("40 C"), nil
311 },
312 )
313
314 agent := ai.NewAgent(
315 languageModel,
316 ai.WithSystemPrompt("You are a helpful assistant"),
317 ai.WithTools(weatherTool),
318 )
319 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
320 Prompt: "What's the weather in Florence, Italy?",
321 ProviderOptions: pair.providerOptions,
322 })
323 require.NoError(t, err, "failed to generate")
324
325 want1 := "Florence"
326 want2 := "40"
327 got := result.Response.Content.Text()
328 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
329
330 thinkChecks(t, result)
331 })
332 })
333 }
334}
335
// containsAny reports whether s contains at least one of the given
// substrings. It returns false when no substrings are supplied.
func containsAny(s string, subs ...string) bool {
	found := false
	// Stop scanning as soon as one substring matches.
	for i := 0; i < len(subs) && !found; i++ {
		found = strings.Contains(s, subs[i])
	}
	return found
}