1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 _ "github.com/joho/godotenv/autoload"
11 "github.com/stretchr/testify/require"
12 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
// testModel groups the attributes of a single model entry.
//
// NOTE(review): the type is not referenced in this part of the file; the
// field notes below are inferred from the names only — confirm against the
// callers before relying on them.
type testModel struct {
	// name and model are presumably a human-readable label and the
	// provider-side model identifier, respectively — TODO confirm.
	name, model string
	// reasoning presumably marks models that support reasoning/thinking —
	// TODO confirm.
	reasoning bool
}
20
21type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
22
// builderPair bundles everything needed to run the shared tests against one
// model configuration.
type builderPair struct {
	// name is used as the subtest name for this configuration.
	name string
	// builder creates the language model; it is invoked once per subtest
	// with a fresh recorder.
	builder builderFunc
	// providerOptions is forwarded unchanged as ProviderOptions on every
	// agent call made for this configuration.
	providerOptions ai.ProviderOptions
}
28
29func testCommon(t *testing.T, pairs []builderPair) {
30 for _, pair := range pairs {
31 t.Run(pair.name, func(t *testing.T) {
32 testSimple(t, pair)
33 testTool(t, pair)
34 testMultiTool(t, pair)
35 })
36 }
37}
38
39func testSimple(t *testing.T, pair builderPair) {
40 checkResult := func(t *testing.T, result *ai.AgentResult) {
41 option1 := "Oi"
42 option2 := "Olá"
43 got := result.Response.Content.Text()
44 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
45 }
46
47 t.Run("simple", func(t *testing.T) {
48 r := newRecorder(t)
49
50 languageModel, err := pair.builder(r)
51 require.NoError(t, err, "failed to build language model")
52
53 agent := ai.NewAgent(
54 languageModel,
55 ai.WithSystemPrompt("You are a helpful assistant"),
56 )
57 result, err := agent.Generate(t.Context(), ai.AgentCall{
58 Prompt: "Say hi in Portuguese",
59 ProviderOptions: pair.providerOptions,
60 MaxOutputTokens: ai.IntOption(4000),
61 })
62 require.NoError(t, err, "failed to generate")
63 checkResult(t, result)
64 })
65 t.Run("simple streaming", func(t *testing.T) {
66 r := newRecorder(t)
67
68 languageModel, err := pair.builder(r)
69 require.NoError(t, err, "failed to build language model")
70
71 agent := ai.NewAgent(
72 languageModel,
73 ai.WithSystemPrompt("You are a helpful assistant"),
74 )
75 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
76 Prompt: "Say hi in Portuguese",
77 ProviderOptions: pair.providerOptions,
78 MaxOutputTokens: ai.IntOption(4000),
79 })
80 require.NoError(t, err, "failed to generate")
81 checkResult(t, result)
82 })
83}
84
85func testTool(t *testing.T, pair builderPair) {
86 type WeatherInput struct {
87 Location string `json:"location" description:"the city"`
88 }
89
90 weatherTool := ai.NewAgentTool(
91 "weather",
92 "Get weather information for a location",
93 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
94 return ai.NewTextResponse("40 C"), nil
95 },
96 )
97 checkResult := func(t *testing.T, result *ai.AgentResult) {
98 require.Len(t, result.Steps, 2)
99
100 var toolCalls []ai.ToolCallContent
101 for _, content := range result.Steps[0].Content {
102 if content.GetType() == ai.ContentTypeToolCall {
103 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
104 }
105 }
106 for _, tc := range toolCalls {
107 require.False(t, tc.Invalid)
108 }
109 require.Len(t, toolCalls, 1)
110 require.Equal(t, toolCalls[0].ToolName, "weather")
111
112 want1 := "Florence"
113 want2 := "40"
114 got := result.Response.Content.Text()
115 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
116 }
117
118 t.Run("tool", func(t *testing.T) {
119 r := newRecorder(t)
120
121 languageModel, err := pair.builder(r)
122 require.NoError(t, err, "failed to build language model")
123
124 agent := ai.NewAgent(
125 languageModel,
126 ai.WithSystemPrompt("You are a helpful assistant"),
127 ai.WithTools(weatherTool),
128 )
129 result, err := agent.Generate(t.Context(), ai.AgentCall{
130 Prompt: "What's the weather in Florence,Italy?",
131 ProviderOptions: pair.providerOptions,
132 MaxOutputTokens: ai.IntOption(4000),
133 })
134 require.NoError(t, err, "failed to generate")
135 checkResult(t, result)
136 })
137 t.Run("tool streaming", func(t *testing.T) {
138 r := newRecorder(t)
139
140 languageModel, err := pair.builder(r)
141 require.NoError(t, err, "failed to build language model")
142
143 agent := ai.NewAgent(
144 languageModel,
145 ai.WithSystemPrompt("You are a helpful assistant"),
146 ai.WithTools(weatherTool),
147 )
148 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
149 Prompt: "What's the weather in Florence,Italy?",
150 ProviderOptions: pair.providerOptions,
151 MaxOutputTokens: ai.IntOption(4000),
152 })
153 require.NoError(t, err, "failed to generate")
154 checkResult(t, result)
155 })
156}
157
158func testMultiTool(t *testing.T, pair builderPair) {
159 type CalculatorInput struct {
160 A int `json:"a" description:"first number"`
161 B int `json:"b" description:"second number"`
162 }
163
164 addTool := ai.NewAgentTool(
165 "add",
166 "Add two numbers",
167 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
168 result := input.A + input.B
169 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
170 },
171 )
172 multiplyTool := ai.NewAgentTool(
173 "multiply",
174 "Multiply two numbers",
175 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
176 result := input.A * input.B
177 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
178 },
179 )
180 checkResult := func(t *testing.T, result *ai.AgentResult) {
181 require.Len(t, result.Steps, 2)
182
183 var toolCalls []ai.ToolCallContent
184 for _, content := range result.Steps[0].Content {
185 if content.GetType() == ai.ContentTypeToolCall {
186 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
187 }
188 }
189 for _, tc := range toolCalls {
190 require.False(t, tc.Invalid)
191 }
192 require.Len(t, toolCalls, 2)
193
194 finalText := result.Response.Content.Text()
195 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
196 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
197 }
198
199 t.Run("multi tool", func(t *testing.T) {
200 r := newRecorder(t)
201
202 languageModel, err := pair.builder(r)
203 require.NoError(t, err, "failed to build language model")
204
205 agent := ai.NewAgent(
206 languageModel,
207 ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
208 ai.WithTools(addTool),
209 ai.WithTools(multiplyTool),
210 )
211 result, err := agent.Generate(t.Context(), ai.AgentCall{
212 Prompt: "Add and multiply the number 2 and 3",
213 ProviderOptions: pair.providerOptions,
214 MaxOutputTokens: ai.IntOption(4000),
215 })
216 require.NoError(t, err, "failed to generate")
217 checkResult(t, result)
218 })
219 t.Run("multi tool streaming", func(t *testing.T) {
220 r := newRecorder(t)
221
222 languageModel, err := pair.builder(r)
223 require.NoError(t, err, "failed to build language model")
224
225 agent := ai.NewAgent(
226 languageModel,
227 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
228 ai.WithTools(addTool),
229 ai.WithTools(multiplyTool),
230 )
231 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
232 Prompt: "Add and multiply the number 2 and 3",
233 ProviderOptions: pair.providerOptions,
234 MaxOutputTokens: ai.IntOption(4000),
235 })
236 require.NoError(t, err, "failed to generate")
237 checkResult(t, result)
238 })
239}
240
241func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
242 for _, pair := range pairs {
243 t.Run(pair.name, func(t *testing.T) {
244 t.Run("thinking", func(t *testing.T) {
245 r := newRecorder(t)
246
247 languageModel, err := pair.builder(r)
248 require.NoError(t, err, "failed to build language model")
249
250 type WeatherInput struct {
251 Location string `json:"location" description:"the city"`
252 }
253
254 weatherTool := ai.NewAgentTool(
255 "weather",
256 "Get weather information for a location",
257 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
258 return ai.NewTextResponse("40 C"), nil
259 },
260 )
261
262 agent := ai.NewAgent(
263 languageModel,
264 ai.WithSystemPrompt("You are a helpful assistant"),
265 ai.WithTools(weatherTool),
266 )
267 result, err := agent.Generate(t.Context(), ai.AgentCall{
268 Prompt: "What's the weather in Florence, Italy?",
269 ProviderOptions: pair.providerOptions,
270 })
271 require.NoError(t, err, "failed to generate")
272
273 want1 := "Florence"
274 want2 := "40"
275 got := result.Response.Content.Text()
276 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
277
278 thinkChecks(t, result)
279 })
280 t.Run("thinking-streaming", func(t *testing.T) {
281 r := newRecorder(t)
282
283 languageModel, err := pair.builder(r)
284 require.NoError(t, err, "failed to build language model")
285
286 type WeatherInput struct {
287 Location string `json:"location" description:"the city"`
288 }
289
290 weatherTool := ai.NewAgentTool(
291 "weather",
292 "Get weather information for a location",
293 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
294 return ai.NewTextResponse("40 C"), nil
295 },
296 )
297
298 agent := ai.NewAgent(
299 languageModel,
300 ai.WithSystemPrompt("You are a helpful assistant"),
301 ai.WithTools(weatherTool),
302 )
303 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
304 Prompt: "What's the weather in Florence, Italy?",
305 ProviderOptions: pair.providerOptions,
306 })
307 require.NoError(t, err, "failed to generate")
308
309 want1 := "Florence"
310 want2 := "40"
311 got := result.Response.Content.Text()
312 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
313
314 thinkChecks(t, result)
315 })
316 })
317 }
318}