1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/stretchr/testify/require"
11 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
12)
13
14type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
15
16type builderPair struct {
17 name string
18 builder builderFunc
19 providerOptions ai.ProviderOptions
20}
21
22func testCommon(t *testing.T, pairs []builderPair) {
23 for _, pair := range pairs {
24 testSimple(t, pair)
25 testTool(t, pair)
26 testMultiTool(t, pair)
27 }
28}
29
30func testSimple(t *testing.T, pair builderPair) {
31 checkResult := func(t *testing.T, result *ai.AgentResult) {
32 option1 := "Oi"
33 option2 := "Olá"
34 got := result.Response.Content.Text()
35 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
36 }
37
38 t.Run("simple "+pair.name, func(t *testing.T) {
39 r := newRecorder(t)
40
41 languageModel, err := pair.builder(r)
42 require.NoError(t, err, "failed to build language model")
43
44 agent := ai.NewAgent(
45 languageModel,
46 ai.WithSystemPrompt("You are a helpful assistant"),
47 )
48 result, err := agent.Generate(t.Context(), ai.AgentCall{
49 Prompt: "Say hi in Portuguese",
50 ProviderOptions: pair.providerOptions,
51 MaxOutputTokens: ai.IntOption(4000),
52 })
53 require.NoError(t, err, "failed to generate")
54 checkResult(t, result)
55 })
56 t.Run("simple streaming "+pair.name, func(t *testing.T) {
57 r := newRecorder(t)
58
59 languageModel, err := pair.builder(r)
60 require.NoError(t, err, "failed to build language model")
61
62 agent := ai.NewAgent(
63 languageModel,
64 ai.WithSystemPrompt("You are a helpful assistant"),
65 )
66 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
67 Prompt: "Say hi in Portuguese",
68 ProviderOptions: pair.providerOptions,
69 MaxOutputTokens: ai.IntOption(4000),
70 })
71 require.NoError(t, err, "failed to generate")
72 checkResult(t, result)
73 })
74}
75
76func testTool(t *testing.T, pair builderPair) {
77 type WeatherInput struct {
78 Location string `json:"location" description:"the city"`
79 }
80
81 weatherTool := ai.NewAgentTool(
82 "weather",
83 "Get weather information for a location",
84 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
85 return ai.NewTextResponse("40 C"), nil
86 },
87 )
88 checkResult := func(t *testing.T, result *ai.AgentResult) {
89 require.Len(t, result.Steps, 2)
90
91 var toolCalls []ai.ToolCallContent
92 for _, content := range result.Steps[0].Content {
93 if content.GetType() == ai.ContentTypeToolCall {
94 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
95 }
96 }
97 for _, tc := range toolCalls {
98 require.False(t, tc.Invalid)
99 }
100 require.Len(t, toolCalls, 1)
101 require.Equal(t, toolCalls[0].ToolName, "weather")
102
103 want1 := "Florence"
104 want2 := "40"
105 got := result.Response.Content.Text()
106 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
107 }
108
109 t.Run("tool "+pair.name, func(t *testing.T) {
110 r := newRecorder(t)
111
112 languageModel, err := pair.builder(r)
113 require.NoError(t, err, "failed to build language model")
114
115 agent := ai.NewAgent(
116 languageModel,
117 ai.WithSystemPrompt("You are a helpful assistant"),
118 ai.WithTools(weatherTool),
119 )
120 result, err := agent.Generate(t.Context(), ai.AgentCall{
121 Prompt: "What's the weather in Florence,Italy?",
122 ProviderOptions: pair.providerOptions,
123 MaxOutputTokens: ai.IntOption(4000),
124 })
125 require.NoError(t, err, "failed to generate")
126 checkResult(t, result)
127 })
128 t.Run("tool streaming "+pair.name, func(t *testing.T) {
129 r := newRecorder(t)
130
131 languageModel, err := pair.builder(r)
132 require.NoError(t, err, "failed to build language model")
133
134 agent := ai.NewAgent(
135 languageModel,
136 ai.WithSystemPrompt("You are a helpful assistant"),
137 ai.WithTools(weatherTool),
138 )
139 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
140 Prompt: "What's the weather in Florence,Italy?",
141 ProviderOptions: pair.providerOptions,
142 MaxOutputTokens: ai.IntOption(4000),
143 })
144 require.NoError(t, err, "failed to generate")
145 checkResult(t, result)
146 })
147}
148
149func testMultiTool(t *testing.T, pair builderPair) {
150 type WeatherInput struct {
151 Location string `json:"location" description:"the city"`
152 }
153
154 type CalculatorInput struct {
155 A int `json:"a" description:"first number"`
156 B int `json:"b" description:"second number"`
157 }
158
159 addTool := ai.NewAgentTool(
160 "add",
161 "Add two numbers",
162 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
163 result := input.A + input.B
164 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
165 },
166 )
167 multiplyTool := ai.NewAgentTool(
168 "multiply",
169 "Multiply two numbers",
170 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
171 result := input.A * input.B
172 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
173 },
174 )
175 checkResult := func(t *testing.T, result *ai.AgentResult) {
176 require.Len(t, result.Steps, 2)
177
178 var toolCalls []ai.ToolCallContent
179 for _, content := range result.Steps[0].Content {
180 if content.GetType() == ai.ContentTypeToolCall {
181 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
182 }
183 }
184 for _, tc := range toolCalls {
185 require.False(t, tc.Invalid)
186 }
187 require.Len(t, toolCalls, 2)
188
189 finalText := result.Response.Content.Text()
190 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
191 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
192 }
193
194 t.Run("multi tool "+pair.name, func(t *testing.T) {
195 r := newRecorder(t)
196
197 languageModel, err := pair.builder(r)
198 require.NoError(t, err, "failed to build language model")
199
200 agent := ai.NewAgent(
201 languageModel,
202 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
203 ai.WithTools(addTool),
204 ai.WithTools(multiplyTool),
205 )
206 result, err := agent.Generate(t.Context(), ai.AgentCall{
207 Prompt: "Add and multiply the number 2 and 3",
208 ProviderOptions: pair.providerOptions,
209 MaxOutputTokens: ai.IntOption(4000),
210 })
211 require.NoError(t, err, "failed to generate")
212 checkResult(t, result)
213 })
214 t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
215 r := newRecorder(t)
216
217 languageModel, err := pair.builder(r)
218 require.NoError(t, err, "failed to build language model")
219
220 agent := ai.NewAgent(
221 languageModel,
222 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
223 ai.WithTools(addTool),
224 ai.WithTools(multiplyTool),
225 )
226 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
227 Prompt: "Add and multiply the number 2 and 3",
228 ProviderOptions: pair.providerOptions,
229 MaxOutputTokens: ai.IntOption(4000),
230 })
231 require.NoError(t, err, "failed to generate")
232 checkResult(t, result)
233 })
234}
235
236func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
237 for _, pair := range pairs {
238 t.Run("thinking-"+pair.name, func(t *testing.T) {
239 r := newRecorder(t)
240
241 languageModel, err := pair.builder(r)
242 require.NoError(t, err, "failed to build language model")
243
244 type WeatherInput struct {
245 Location string `json:"location" description:"the city"`
246 }
247
248 weatherTool := ai.NewAgentTool(
249 "weather",
250 "Get weather information for a location",
251 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
252 return ai.NewTextResponse("40 C"), nil
253 },
254 )
255
256 agent := ai.NewAgent(
257 languageModel,
258 ai.WithSystemPrompt("You are a helpful assistant"),
259 ai.WithTools(weatherTool),
260 )
261 result, err := agent.Generate(t.Context(), ai.AgentCall{
262 Prompt: "What's the weather in Florence, Italy?",
263 ProviderOptions: pair.providerOptions,
264 // ProviderOptions: ai.ProviderOptions{
265 // "anthropic": &anthropic.ProviderOptions{
266 // Thinking: &anthropic.ThinkingProviderOption{
267 // BudgetTokens: 10_000,
268 // },
269 // },
270 // "google": &google.ProviderOptions{
271 // ThinkingConfig: &google.ThinkingConfig{
272 // ThinkingBudget: ai.IntOption(100),
273 // IncludeThoughts: ai.BoolOption(true),
274 // },
275 // },
276 // "openai": &openai.ProviderOptions{
277 // ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
278 // },
279 // },
280 })
281 require.NoError(t, err, "failed to generate")
282
283 want1 := "Florence"
284 want2 := "40"
285 got := result.Response.Content.Text()
286 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
287
288 thinkChecks(t, result)
289 })
290 }
291}