1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/stretchr/testify/require"
11 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
12)
13
14type testModel struct {
15 name string
16 model string
17 reasoning bool
18}
19
20type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
21
22type builderPair struct {
23 name string
24 builder builderFunc
25 providerOptions ai.ProviderOptions
26}
27
// testCommon runs the shared provider suite once per builder pair. Each pair
// gets its own subtest, named after the pair. Inside it, testSimple
// (plain prompt), testTool (single tool call) and testMultiTool (two tool
// calls in one step) run in that order. Each of these covers both the
// Generate and Stream paths.
func testCommon(t *testing.T, pairs []builderPair) {
	for i := range pairs {
		current := pairs[i]
		t.Run(current.name, func(t *testing.T) {
			testSimple(t, current)
			testTool(t, current)
			testMultiTool(t, current)
		})
	}
}
37
38func testSimple(t *testing.T, pair builderPair) {
39 checkResult := func(t *testing.T, result *ai.AgentResult) {
40 option1 := "Oi"
41 option2 := "Olá"
42 got := result.Response.Content.Text()
43 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
44 }
45
46 t.Run("simple", func(t *testing.T) {
47 r := newRecorder(t)
48
49 languageModel, err := pair.builder(r)
50 require.NoError(t, err, "failed to build language model")
51
52 agent := ai.NewAgent(
53 languageModel,
54 ai.WithSystemPrompt("You are a helpful assistant"),
55 )
56 result, err := agent.Generate(t.Context(), ai.AgentCall{
57 Prompt: "Say hi in Portuguese",
58 ProviderOptions: pair.providerOptions,
59 MaxOutputTokens: ai.IntOption(4000),
60 })
61 require.NoError(t, err, "failed to generate")
62 checkResult(t, result)
63 })
64 t.Run("simple streaming", func(t *testing.T) {
65 r := newRecorder(t)
66
67 languageModel, err := pair.builder(r)
68 require.NoError(t, err, "failed to build language model")
69
70 agent := ai.NewAgent(
71 languageModel,
72 ai.WithSystemPrompt("You are a helpful assistant"),
73 )
74 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
75 Prompt: "Say hi in Portuguese",
76 ProviderOptions: pair.providerOptions,
77 MaxOutputTokens: ai.IntOption(4000),
78 })
79 require.NoError(t, err, "failed to generate")
80 checkResult(t, result)
81 })
82}
83
// testTool checks a single tool-call round trip. The agent gets a "weather"
// tool that always answers "40 C" and is asked about the weather in
// Florence. It is expected to:
//   - take exactly two steps,
//   - emit exactly one valid "weather" tool call in the first step, and
//   - mention both "Florence" and "40" in the final response text.
//
// The scenario runs once through Generate ("tool") and once through Stream
// ("tool streaming").
func testTool(t *testing.T, pair builderPair) {
	type WeatherInput struct {
		Location string `json:"location" description:"the city"`
	}

	weatherTool := ai.NewAgentTool(
		"weather",
		"Get weather information for a location",
		// The input is ignored on purpose: a fixed reading lets checkResult
		// look for "40" in the final answer.
		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
			return ai.NewTextResponse("40 C"), nil
		},
	)
	checkResult := func(t *testing.T, result *ai.AgentResult) {
		require.Len(t, result.Steps, 2)

		// Collect the tool calls issued in the first step.
		var toolCalls []ai.ToolCallContent
		for _, content := range result.Steps[0].Content {
			if content.GetType() == ai.ContentTypeToolCall {
				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
			}
		}
		for _, tc := range toolCalls {
			require.False(t, tc.Invalid, "tool call %q was marked invalid", tc.ToolName)
		}
		require.Len(t, toolCalls, 1)
		// require.Equal takes (expected, actual); keep the literal first so
		// failure output labels the values correctly.
		require.Equal(t, "weather", toolCalls[0].ToolName)

		want1 := "Florence"
		want2 := "40"
		got := result.Response.Content.Text()
		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
	}

	// NOTE(review): the prompt text below (including the missing space after
	// the comma) is sent to the provider. Recorded cassettes presumably depend
	// on it, so confirm the recorder's request matcher before editing it.
	t.Run("tool", func(t *testing.T) {
		r := newRecorder(t)

		languageModel, err := pair.builder(r)
		require.NoError(t, err, "failed to build language model")

		agent := ai.NewAgent(
			languageModel,
			ai.WithSystemPrompt("You are a helpful assistant"),
			ai.WithTools(weatherTool),
		)
		result, err := agent.Generate(t.Context(), ai.AgentCall{
			Prompt:          "What's the weather in Florence,Italy?",
			ProviderOptions: pair.providerOptions,
			MaxOutputTokens: ai.IntOption(4000),
		})
		require.NoError(t, err, "failed to generate")
		checkResult(t, result)
	})
	t.Run("tool streaming", func(t *testing.T) {
		r := newRecorder(t)

		languageModel, err := pair.builder(r)
		require.NoError(t, err, "failed to build language model")

		agent := ai.NewAgent(
			languageModel,
			ai.WithSystemPrompt("You are a helpful assistant"),
			ai.WithTools(weatherTool),
		)
		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
			Prompt:          "What's the weather in Florence,Italy?",
			ProviderOptions: pair.providerOptions,
			MaxOutputTokens: ai.IntOption(4000),
		})
		require.NoError(t, err, "failed to generate")
		checkResult(t, result)
	})
}
156
// testMultiTool checks that the model can issue several tool calls in one
// step. With "add" and "multiply" tools registered and a prompt asking for
// both operations on 2 and 3, the agent is expected to:
//   - take exactly two steps,
//   - emit two valid tool calls in the first step, and
//   - mention both "5" and "6" in the final response text.
//
// The scenario runs once through Generate ("multi tool") and once through
// Stream ("multi tool streaming").
func testMultiTool(t *testing.T, pair builderPair) {
	type CalculatorInput struct {
		A int `json:"a" description:"first number"`
		B int `json:"b" description:"second number"`
	}

	// strconv.Itoa never produces surrounding whitespace, so the results are
	// returned without any trimming.
	addTool := ai.NewAgentTool(
		"add",
		"Add two numbers",
		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
			return ai.NewTextResponse(strconv.Itoa(input.A + input.B)), nil
		},
	)
	multiplyTool := ai.NewAgentTool(
		"multiply",
		"Multiply two numbers",
		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
			return ai.NewTextResponse(strconv.Itoa(input.A * input.B)), nil
		},
	)
	checkResult := func(t *testing.T, result *ai.AgentResult) {
		require.Len(t, result.Steps, 2)

		// Collect the tool calls issued in the first step.
		var toolCalls []ai.ToolCallContent
		for _, content := range result.Steps[0].Content {
			if content.GetType() == ai.ContentTypeToolCall {
				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
			}
		}
		for _, tc := range toolCalls {
			require.False(t, tc.Invalid, "tool call %q was marked invalid", tc.ToolName)
		}
		require.Len(t, toolCalls, 2)

		finalText := result.Response.Content.Text()
		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
	}

	// NOTE(review): the two subtests below use different system prompts and
	// register tools through two separate WithTools calls. Both are left
	// byte-identical because they are part of the recorded requests.
	// Registering both tools this way assumes WithTools appends rather than
	// replaces; confirm in the ai package.
	t.Run("multi tool", func(t *testing.T) {
		r := newRecorder(t)

		languageModel, err := pair.builder(r)
		require.NoError(t, err, "failed to build language model")

		agent := ai.NewAgent(
			languageModel,
			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
			ai.WithTools(addTool),
			ai.WithTools(multiplyTool),
		)
		result, err := agent.Generate(t.Context(), ai.AgentCall{
			Prompt:          "Add and multiply the number 2 and 3",
			ProviderOptions: pair.providerOptions,
			MaxOutputTokens: ai.IntOption(4000),
		})
		require.NoError(t, err, "failed to generate")
		checkResult(t, result)
	})
	t.Run("multi tool streaming", func(t *testing.T) {
		r := newRecorder(t)

		languageModel, err := pair.builder(r)
		require.NoError(t, err, "failed to build language model")

		agent := ai.NewAgent(
			languageModel,
			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
			ai.WithTools(addTool),
			ai.WithTools(multiplyTool),
		)
		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
			Prompt:          "Add and multiply the number 2 and 3",
			ProviderOptions: pair.providerOptions,
			MaxOutputTokens: ai.IntOption(4000),
		})
		require.NoError(t, err, "failed to generate")
		checkResult(t, result)
	})
}
239
// testThinking runs a weather tool-call scenario against every builder pair,
// in both generate ("thinking") and streaming ("thinking-streaming") mode.
// For each run it checks that the final text mentions "Florence" and "40",
// then passes the result to thinkChecks for caller-supplied assertions.
//
// Unlike testTool, these calls set no MaxOutputTokens limit.
func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
	type WeatherInput struct {
		Location string `json:"location" description:"the city"`
	}

	// Shared by every subtest. The input is ignored on purpose: a fixed
	// reading lets checkResult look for "40" in the final answer.
	weatherTool := ai.NewAgentTool(
		"weather",
		"Get weather information for a location",
		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
			return ai.NewTextResponse("40 C"), nil
		},
	)

	// checkResult asserts on the final text, then runs the caller's checks.
	checkResult := func(t *testing.T, result *ai.AgentResult) {
		want1 := "Florence"
		want2 := "40"
		got := result.Response.Content.Text()
		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)

		thinkChecks(t, result)
	}

	// buildModel creates the recorder for the calling subtest and builds the
	// pair's language model on top of it.
	buildModel := func(t *testing.T, pair builderPair) ai.LanguageModel {
		r := newRecorder(t)

		languageModel, err := pair.builder(r)
		require.NoError(t, err, "failed to build language model")
		return languageModel
	}

	for _, pair := range pairs {
		t.Run(pair.name, func(t *testing.T) {
			t.Run("thinking", func(t *testing.T) {
				agent := ai.NewAgent(
					buildModel(t, pair),
					ai.WithSystemPrompt("You are a helpful assistant"),
					ai.WithTools(weatherTool),
				)
				result, err := agent.Generate(t.Context(), ai.AgentCall{
					Prompt:          "What's the weather in Florence, Italy?",
					ProviderOptions: pair.providerOptions,
				})
				require.NoError(t, err, "failed to generate")
				checkResult(t, result)
			})
			t.Run("thinking-streaming", func(t *testing.T) {
				agent := ai.NewAgent(
					buildModel(t, pair),
					ai.WithSystemPrompt("You are a helpful assistant"),
					ai.WithTools(weatherTool),
				)
				result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
					Prompt:          "What's the weather in Florence, Italy?",
					ProviderOptions: pair.providerOptions,
				})
				require.NoError(t, err, "failed to generate")
				checkResult(t, result)
			})
		})
	}
}