1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/stretchr/testify/require"
11 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
12)
13
14type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
15
16type builderPair struct {
17 name string
18 builder builderFunc
19 providerOptions ai.ProviderOptions
20}
21
22func testCommon(t *testing.T, pairs []builderPair) {
23 for _, pair := range pairs {
24 testSimple(t, pair)
25 testTool(t, pair)
26 testMultiTool(t, pair)
27 }
28}
29
30func testSimple(t *testing.T, pair builderPair) {
31 checkResult := func(t *testing.T, result *ai.AgentResult) {
32 option1 := "Oi"
33 option2 := "Olá"
34 got := result.Response.Content.Text()
35 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
36 }
37
38 t.Run("simple "+pair.name, func(t *testing.T) {
39 r := newRecorder(t)
40
41 languageModel, err := pair.builder(r)
42 require.NoError(t, err, "failed to build language model")
43
44 agent := ai.NewAgent(
45 languageModel,
46 ai.WithSystemPrompt("You are a helpful assistant"),
47 )
48 result, err := agent.Generate(t.Context(), ai.AgentCall{
49 Prompt: "Say hi in Portuguese",
50 ProviderOptions: pair.providerOptions,
51 MaxOutputTokens: ai.IntOption(4000),
52 })
53 require.NoError(t, err, "failed to generate")
54 checkResult(t, result)
55 })
56 t.Run("simple streaming "+pair.name, func(t *testing.T) {
57 r := newRecorder(t)
58
59 languageModel, err := pair.builder(r)
60 require.NoError(t, err, "failed to build language model")
61
62 agent := ai.NewAgent(
63 languageModel,
64 ai.WithSystemPrompt("You are a helpful assistant"),
65 )
66 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
67 Prompt: "Say hi in Portuguese",
68 ProviderOptions: pair.providerOptions,
69 MaxOutputTokens: ai.IntOption(4000),
70 })
71 require.NoError(t, err, "failed to generate")
72 checkResult(t, result)
73 })
74}
75
76func testTool(t *testing.T, pair builderPair) {
77 type WeatherInput struct {
78 Location string `json:"location" description:"the city"`
79 }
80
81 weatherTool := ai.NewAgentTool(
82 "weather",
83 "Get weather information for a location",
84 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
85 return ai.NewTextResponse("40 C"), nil
86 },
87 )
88 checkResult := func(t *testing.T, result *ai.AgentResult) {
89 require.Len(t, result.Steps, 2)
90
91 var toolCalls []ai.ToolCallContent
92 for _, content := range result.Steps[0].Content {
93 if content.GetType() == ai.ContentTypeToolCall {
94 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
95 }
96 }
97 require.Len(t, toolCalls, 1)
98 require.Equal(t, toolCalls[0].ToolName, "weather")
99
100 want1 := "Florence"
101 want2 := "40"
102 got := result.Response.Content.Text()
103 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
104 }
105
106 t.Run("tool "+pair.name, func(t *testing.T) {
107 r := newRecorder(t)
108
109 languageModel, err := pair.builder(r)
110 require.NoError(t, err, "failed to build language model")
111
112 agent := ai.NewAgent(
113 languageModel,
114 ai.WithSystemPrompt("You are a helpful assistant"),
115 ai.WithTools(weatherTool),
116 )
117 result, err := agent.Generate(t.Context(), ai.AgentCall{
118 Prompt: "What's the weather in Florence,Italy?",
119 ProviderOptions: pair.providerOptions,
120 MaxOutputTokens: ai.IntOption(4000),
121 })
122 require.NoError(t, err, "failed to generate")
123 checkResult(t, result)
124 })
125 t.Run("tool streaming "+pair.name, func(t *testing.T) {
126 r := newRecorder(t)
127
128 languageModel, err := pair.builder(r)
129 require.NoError(t, err, "failed to build language model")
130
131 agent := ai.NewAgent(
132 languageModel,
133 ai.WithSystemPrompt("You are a helpful assistant"),
134 ai.WithTools(weatherTool),
135 )
136 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
137 Prompt: "What's the weather in Florence,Italy?",
138 ProviderOptions: pair.providerOptions,
139 MaxOutputTokens: ai.IntOption(4000),
140 })
141 require.NoError(t, err, "failed to generate")
142 checkResult(t, result)
143 })
144}
145
146func testMultiTool(t *testing.T, pair builderPair) {
147 type WeatherInput struct {
148 Location string `json:"location" description:"the city"`
149 }
150
151 type CalculatorInput struct {
152 A int `json:"a" description:"first number"`
153 B int `json:"b" description:"second number"`
154 }
155
156 addTool := ai.NewAgentTool(
157 "add",
158 "Add two numbers",
159 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
160 result := input.A + input.B
161 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
162 },
163 )
164 multiplyTool := ai.NewAgentTool(
165 "multiply",
166 "Multiply two numbers",
167 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
168 result := input.A * input.B
169 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
170 },
171 )
172 checkResult := func(t *testing.T, result *ai.AgentResult) {
173 require.Len(t, result.Steps, 2)
174
175 var toolCalls []ai.ToolCallContent
176 for _, content := range result.Steps[0].Content {
177 if content.GetType() == ai.ContentTypeToolCall {
178 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
179 }
180 }
181 require.Len(t, toolCalls, 2)
182
183 finalText := result.Response.Content.Text()
184 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
185 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
186 }
187
188 t.Run("multi tool "+pair.name, func(t *testing.T) {
189 r := newRecorder(t)
190
191 languageModel, err := pair.builder(r)
192 require.NoError(t, err, "failed to build language model")
193
194 agent := ai.NewAgent(
195 languageModel,
196 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
197 ai.WithTools(addTool),
198 ai.WithTools(multiplyTool),
199 )
200 result, err := agent.Generate(t.Context(), ai.AgentCall{
201 Prompt: "Add and multiply the number 2 and 3",
202 ProviderOptions: pair.providerOptions,
203 MaxOutputTokens: ai.IntOption(4000),
204 })
205 require.NoError(t, err, "failed to generate")
206 checkResult(t, result)
207 })
208 t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
209 r := newRecorder(t)
210
211 languageModel, err := pair.builder(r)
212 require.NoError(t, err, "failed to build language model")
213
214 agent := ai.NewAgent(
215 languageModel,
216 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
217 ai.WithTools(addTool),
218 ai.WithTools(multiplyTool),
219 )
220 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
221 Prompt: "Add and multiply the number 2 and 3",
222 ProviderOptions: pair.providerOptions,
223 MaxOutputTokens: ai.IntOption(4000),
224 })
225 require.NoError(t, err, "failed to generate")
226 checkResult(t, result)
227 })
228}
229
230func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
231 for _, pair := range pairs {
232 t.Run(pair.name, func(t *testing.T) {
233 r := newRecorder(t)
234
235 languageModel, err := pair.builder(r)
236 require.NoError(t, err, "failed to build language model")
237
238 type WeatherInput struct {
239 Location string `json:"location" description:"the city"`
240 }
241
242 weatherTool := ai.NewAgentTool(
243 "weather",
244 "Get weather information for a location",
245 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
246 return ai.NewTextResponse("40 C"), nil
247 },
248 )
249
250 agent := ai.NewAgent(
251 languageModel,
252 ai.WithSystemPrompt("You are a helpful assistant"),
253 ai.WithTools(weatherTool),
254 )
255 result, err := agent.Generate(t.Context(), ai.AgentCall{
256 Prompt: "What's the weather in Florence, Italy?",
257 ProviderOptions: pair.providerOptions,
258 // ProviderOptions: ai.ProviderOptions{
259 // "anthropic": &anthropic.ProviderOptions{
260 // Thinking: &anthropic.ThinkingProviderOption{
261 // BudgetTokens: 10_000,
262 // },
263 // },
264 // "google": &google.ProviderOptions{
265 // ThinkingConfig: &google.ThinkingConfig{
266 // ThinkingBudget: ai.IntOption(100),
267 // IncludeThoughts: ai.BoolOption(true),
268 // },
269 // },
270 // "openai": &openai.ProviderOptions{
271 // ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
272 // },
273 // },
274 })
275 require.NoError(t, err, "failed to generate")
276
277 want1 := "Florence"
278 want2 := "40"
279 got := result.Response.Content.Text()
280 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
281
282 thinkChecks(t, result)
283 })
284 }
285}