1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 _ "github.com/joho/godotenv/autoload"
11 "github.com/stretchr/testify/require"
12 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
// testModel describes a single model entry used by the provider tests.
//
// NOTE(review): nothing in the visible part of this file references this
// type; the field notes below are inferred from the field names only and
// should be confirmed against the per-provider test files.
type testModel struct {
	name      string // presumably the human-readable label for the test case
	model     string // presumably the provider-side model identifier
	reasoning bool   // presumably marks models covered by the thinking tests
}
20
// builderFunc constructs the language model under test.
//
// The shared tests call it once per subtest, passing the recorder returned by
// newRecorder for that subtest. How the recorder is used is up to the
// implementation (presumably as the HTTP transport of the provider client,
// so that requests are recorded/replayed — confirm against implementations).
// A non-nil error fails the subtest immediately.
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
22
// builderPair bundles everything the shared tests need to exercise one
// provider/model combination.
type builderPair struct {
	// name is used as the subtest name for this combination. testMultiTool
	// also inspects it: names containing "azure" skip the multi-tool tests.
	name string
	// builder creates the language model for each individual subtest.
	builder builderFunc
	// providerOptions is forwarded unchanged on every agent call made by the
	// shared tests.
	providerOptions ai.ProviderOptions
}
28
29func testCommon(t *testing.T, pairs []builderPair) {
30 for _, pair := range pairs {
31 t.Run(pair.name, func(t *testing.T) {
32 testSimple(t, pair)
33 testTool(t, pair)
34 testMultiTool(t, pair)
35 })
36 }
37}
38
39func testSimple(t *testing.T, pair builderPair) {
40 checkResult := func(t *testing.T, result *ai.AgentResult) {
41 options := []string{"Oi", "oi", "Olá", "olá"}
42 got := result.Response.Content.Text()
43 require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
44 }
45
46 t.Run("simple", func(t *testing.T) {
47 r := newRecorder(t)
48
49 languageModel, err := pair.builder(r)
50 require.NoError(t, err, "failed to build language model")
51
52 agent := ai.NewAgent(
53 languageModel,
54 ai.WithSystemPrompt("You are a helpful assistant"),
55 )
56 result, err := agent.Generate(t.Context(), ai.AgentCall{
57 Prompt: "Say hi in Portuguese",
58 ProviderOptions: pair.providerOptions,
59 MaxOutputTokens: ai.IntOption(4000),
60 })
61 require.NoError(t, err, "failed to generate")
62 checkResult(t, result)
63 })
64 t.Run("simple streaming", func(t *testing.T) {
65 r := newRecorder(t)
66
67 languageModel, err := pair.builder(r)
68 require.NoError(t, err, "failed to build language model")
69
70 agent := ai.NewAgent(
71 languageModel,
72 ai.WithSystemPrompt("You are a helpful assistant"),
73 )
74 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
75 Prompt: "Say hi in Portuguese",
76 ProviderOptions: pair.providerOptions,
77 MaxOutputTokens: ai.IntOption(4000),
78 })
79 require.NoError(t, err, "failed to generate")
80 checkResult(t, result)
81 })
82}
83
84func testTool(t *testing.T, pair builderPair) {
85 type WeatherInput struct {
86 Location string `json:"location" description:"the city"`
87 }
88
89 weatherTool := ai.NewAgentTool(
90 "weather",
91 "Get weather information for a location",
92 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
93 return ai.NewTextResponse("40 C"), nil
94 },
95 )
96 checkResult := func(t *testing.T, result *ai.AgentResult) {
97 require.Len(t, result.Steps, 2)
98
99 var toolCalls []ai.ToolCallContent
100 for _, content := range result.Steps[0].Content {
101 if content.GetType() == ai.ContentTypeToolCall {
102 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
103 }
104 }
105 for _, tc := range toolCalls {
106 require.False(t, tc.Invalid)
107 }
108 require.Len(t, toolCalls, 1)
109 require.Equal(t, toolCalls[0].ToolName, "weather")
110
111 want1 := "Florence"
112 want2 := "40"
113 got := result.Response.Content.Text()
114 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
115 }
116
117 t.Run("tool", func(t *testing.T) {
118 r := newRecorder(t)
119
120 languageModel, err := pair.builder(r)
121 require.NoError(t, err, "failed to build language model")
122
123 agent := ai.NewAgent(
124 languageModel,
125 ai.WithSystemPrompt("You are a helpful assistant"),
126 ai.WithTools(weatherTool),
127 )
128 result, err := agent.Generate(t.Context(), ai.AgentCall{
129 Prompt: "What's the weather in Florence,Italy?",
130 ProviderOptions: pair.providerOptions,
131 MaxOutputTokens: ai.IntOption(4000),
132 })
133 require.NoError(t, err, "failed to generate")
134 checkResult(t, result)
135 })
136 t.Run("tool streaming", func(t *testing.T) {
137 r := newRecorder(t)
138
139 languageModel, err := pair.builder(r)
140 require.NoError(t, err, "failed to build language model")
141
142 agent := ai.NewAgent(
143 languageModel,
144 ai.WithSystemPrompt("You are a helpful assistant"),
145 ai.WithTools(weatherTool),
146 )
147 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
148 Prompt: "What's the weather in Florence,Italy?",
149 ProviderOptions: pair.providerOptions,
150 MaxOutputTokens: ai.IntOption(4000),
151 })
152 require.NoError(t, err, "failed to generate")
153 checkResult(t, result)
154 })
155}
156
157func testMultiTool(t *testing.T, pair builderPair) {
158 // Apparently, Azure does not support multi-tools calls at all?
159 if strings.Contains(pair.name, "azure") {
160 t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
161 }
162
163 type CalculatorInput struct {
164 A int `json:"a" description:"first number"`
165 B int `json:"b" description:"second number"`
166 }
167
168 addTool := ai.NewAgentTool(
169 "add",
170 "Add two numbers",
171 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
172 result := input.A + input.B
173 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
174 },
175 )
176 multiplyTool := ai.NewAgentTool(
177 "multiply",
178 "Multiply two numbers",
179 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
180 result := input.A * input.B
181 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
182 },
183 )
184 checkResult := func(t *testing.T, result *ai.AgentResult) {
185 require.Len(t, result.Steps, 2)
186
187 var toolCalls []ai.ToolCallContent
188 for _, content := range result.Steps[0].Content {
189 if content.GetType() == ai.ContentTypeToolCall {
190 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
191 }
192 }
193 for _, tc := range toolCalls {
194 require.False(t, tc.Invalid)
195 }
196 require.Len(t, toolCalls, 2)
197
198 finalText := result.Response.Content.Text()
199 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
200 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
201 }
202
203 t.Run("multi tool", func(t *testing.T) {
204 r := newRecorder(t)
205
206 languageModel, err := pair.builder(r)
207 require.NoError(t, err, "failed to build language model")
208
209 agent := ai.NewAgent(
210 languageModel,
211 ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
212 ai.WithTools(addTool),
213 ai.WithTools(multiplyTool),
214 )
215 result, err := agent.Generate(t.Context(), ai.AgentCall{
216 Prompt: "Add and multiply the number 2 and 3",
217 ProviderOptions: pair.providerOptions,
218 MaxOutputTokens: ai.IntOption(4000),
219 })
220 require.NoError(t, err, "failed to generate")
221 checkResult(t, result)
222 })
223 t.Run("multi tool streaming", func(t *testing.T) {
224 r := newRecorder(t)
225
226 languageModel, err := pair.builder(r)
227 require.NoError(t, err, "failed to build language model")
228
229 agent := ai.NewAgent(
230 languageModel,
231 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
232 ai.WithTools(addTool),
233 ai.WithTools(multiplyTool),
234 )
235 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
236 Prompt: "Add and multiply the number 2 and 3",
237 ProviderOptions: pair.providerOptions,
238 MaxOutputTokens: ai.IntOption(4000),
239 })
240 require.NoError(t, err, "failed to generate")
241 checkResult(t, result)
242 })
243}
244
245func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
246 for _, pair := range pairs {
247 t.Run(pair.name, func(t *testing.T) {
248 t.Run("thinking", func(t *testing.T) {
249 r := newRecorder(t)
250
251 languageModel, err := pair.builder(r)
252 require.NoError(t, err, "failed to build language model")
253
254 type WeatherInput struct {
255 Location string `json:"location" description:"the city"`
256 }
257
258 weatherTool := ai.NewAgentTool(
259 "weather",
260 "Get weather information for a location",
261 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
262 return ai.NewTextResponse("40 C"), nil
263 },
264 )
265
266 agent := ai.NewAgent(
267 languageModel,
268 ai.WithSystemPrompt("You are a helpful assistant"),
269 ai.WithTools(weatherTool),
270 )
271 result, err := agent.Generate(t.Context(), ai.AgentCall{
272 Prompt: "What's the weather in Florence, Italy?",
273 ProviderOptions: pair.providerOptions,
274 })
275 require.NoError(t, err, "failed to generate")
276
277 want1 := "Florence"
278 want2 := "40"
279 got := result.Response.Content.Text()
280 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
281
282 thinkChecks(t, result)
283 })
284 t.Run("thinking-streaming", func(t *testing.T) {
285 r := newRecorder(t)
286
287 languageModel, err := pair.builder(r)
288 require.NoError(t, err, "failed to build language model")
289
290 type WeatherInput struct {
291 Location string `json:"location" description:"the city"`
292 }
293
294 weatherTool := ai.NewAgentTool(
295 "weather",
296 "Get weather information for a location",
297 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
298 return ai.NewTextResponse("40 C"), nil
299 },
300 )
301
302 agent := ai.NewAgent(
303 languageModel,
304 ai.WithSystemPrompt("You are a helpful assistant"),
305 ai.WithTools(weatherTool),
306 )
307 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
308 Prompt: "What's the weather in Florence, Italy?",
309 ProviderOptions: pair.providerOptions,
310 })
311 require.NoError(t, err, "failed to generate")
312
313 want1 := "Florence"
314 want2 := "40"
315 got := result.Response.Content.Text()
316 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
317
318 thinkChecks(t, result)
319 })
320 })
321 }
322}
323
// containsAny reports whether s contains at least one of subs as a substring.
//
// It returns false when no candidates are given. An empty candidate matches
// every s, because strings.Contains treats "" as contained in any string.
func containsAny(s string, subs ...string) bool {
	matched := false
	// Stop scanning as soon as one candidate matches.
	for i := 0; !matched && i < len(subs); i++ {
		matched = strings.Contains(s, subs[i])
	}
	return matched
}