1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/charmbracelet/fantasy/anthropic"
11 "github.com/charmbracelet/fantasy/google"
12 "github.com/charmbracelet/fantasy/openai"
13 "github.com/charmbracelet/fantasy/openrouter"
14 _ "github.com/joho/godotenv/autoload"
15 "github.com/stretchr/testify/require"
16)
17
18func TestSimple(t *testing.T) {
19 for _, pair := range languageModelBuilders {
20 t.Run(pair.name, func(t *testing.T) {
21 r := newRecorder(t)
22
23 languageModel, err := pair.builder(r)
24 require.NoError(t, err, "failed to build language model")
25
26 agent := ai.NewAgent(
27 languageModel,
28 ai.WithSystemPrompt("You are a helpful assistant"),
29 )
30 result, err := agent.Generate(t.Context(), ai.AgentCall{
31 Prompt: "Say hi in Portuguese",
32 })
33 require.NoError(t, err, "failed to generate")
34
35 option1 := "Oi"
36 option2 := "Olá"
37 got := result.Response.Content.Text()
38 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
39 })
40 }
41}
42
43func TestTool(t *testing.T) {
44 for _, pair := range languageModelBuilders {
45 t.Run(pair.name, func(t *testing.T) {
46 r := newRecorder(t)
47
48 languageModel, err := pair.builder(r)
49 require.NoError(t, err, "failed to build language model")
50
51 type WeatherInput struct {
52 Location string `json:"location" description:"the city"`
53 }
54
55 weatherTool := ai.NewAgentTool(
56 "weather",
57 "Get weather information for a location",
58 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
59 return ai.NewTextResponse("40 C"), nil
60 },
61 )
62
63 agent := ai.NewAgent(
64 languageModel,
65 ai.WithSystemPrompt("You are a helpful assistant"),
66 ai.WithTools(weatherTool),
67 )
68 result, err := agent.Generate(t.Context(), ai.AgentCall{
69 Prompt: "What's the weather in Florence?",
70 })
71 require.NoError(t, err, "failed to generate")
72
73 want1 := "Florence"
74 want2 := "40"
75 got := result.Response.Content.Text()
76 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
77 })
78 }
79}
80
81func TestThinking(t *testing.T) {
82 for _, pair := range thinkingLanguageModelBuilders {
83 t.Run(pair.name, func(t *testing.T) {
84 r := newRecorder(t)
85
86 languageModel, err := pair.builder(r)
87 require.NoError(t, err, "failed to build language model")
88
89 type WeatherInput struct {
90 Location string `json:"location" description:"the city"`
91 }
92
93 weatherTool := ai.NewAgentTool(
94 "weather",
95 "Get weather information for a location",
96 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
97 return ai.NewTextResponse("40 C"), nil
98 },
99 )
100
101 agent := ai.NewAgent(
102 languageModel,
103 ai.WithSystemPrompt("You are a helpful assistant"),
104 ai.WithTools(weatherTool),
105 )
106 result, err := agent.Generate(t.Context(), ai.AgentCall{
107 Prompt: "What's the weather in Florence, Italy?",
108 ProviderOptions: ai.ProviderOptions{
109 "anthropic": &anthropic.ProviderOptions{
110 Thinking: &anthropic.ThinkingProviderOption{
111 BudgetTokens: 10_000,
112 },
113 },
114 "google": &google.ProviderOptions{
115 ThinkingConfig: &google.ThinkingConfig{
116 ThinkingBudget: ai.IntOption(100),
117 IncludeThoughts: ai.BoolOption(true),
118 },
119 },
120 "openai": &openai.ProviderOptions{
121 ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
122 },
123 "openrouter": &openrouter.ProviderOptions{
124 Reasoning: &openrouter.ReasoningOptions{
125 Effort: openrouter.ReasoningEffortOption(openrouter.ReasoningEffortHigh),
126 },
127 },
128 },
129 })
130 require.NoError(t, err, "failed to generate")
131
132 want1 := "Florence"
133 want2 := "40"
134 got := result.Response.Content.Text()
135 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
136
137 testThinking(t, languageModel.Provider(), result.Steps)
138 })
139 }
140}
141
142func TestThinkingStreaming(t *testing.T) {
143 for _, pair := range thinkingLanguageModelBuilders {
144 t.Run(pair.name, func(t *testing.T) {
145 r := newRecorder(t)
146
147 languageModel, err := pair.builder(r)
148 require.NoError(t, err, "failed to build language model")
149
150 type WeatherInput struct {
151 Location string `json:"location" description:"the city"`
152 }
153
154 weatherTool := ai.NewAgentTool(
155 "weather",
156 "Get weather information for a location",
157 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
158 return ai.NewTextResponse("40 C"), nil
159 },
160 )
161
162 agent := ai.NewAgent(
163 languageModel,
164 ai.WithSystemPrompt("You are a helpful assistant"),
165 ai.WithTools(weatherTool),
166 )
167 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
168 Prompt: "What's the weather in Florence, Italy?",
169 ProviderOptions: ai.ProviderOptions{
170 "anthropic": &anthropic.ProviderOptions{
171 Thinking: &anthropic.ThinkingProviderOption{
172 BudgetTokens: 10_000,
173 },
174 },
175 "google": &google.ProviderOptions{
176 ThinkingConfig: &google.ThinkingConfig{
177 ThinkingBudget: ai.IntOption(100),
178 IncludeThoughts: ai.BoolOption(true),
179 },
180 },
181 "openai": &openai.ProviderOptions{
182 ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
183 },
184 },
185 })
186 require.NoError(t, err, "failed to generate")
187
188 want1 := "Florence"
189 want2 := "40"
190 got := result.Response.Content.Text()
191 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
192
193 testThinking(t, languageModel.Provider(), result.Steps)
194 })
195 }
196}
197
198func TestStream(t *testing.T) {
199 for _, pair := range languageModelBuilders {
200 t.Run(pair.name, func(t *testing.T) {
201 r := newRecorder(t)
202
203 languageModel, err := pair.builder(r)
204 require.NoError(t, err, "failed to build language model")
205
206 agent := ai.NewAgent(
207 languageModel,
208 ai.WithSystemPrompt("You are a helpful assistant"),
209 )
210
211 var collectedText strings.Builder
212 textDeltaCount := 0
213 stepCount := 0
214
215 streamCall := ai.AgentStreamCall{
216 Prompt: "Count from 1 to 3 in Spanish",
217 OnTextDelta: func(id, text string) error {
218 textDeltaCount++
219 collectedText.WriteString(text)
220 return nil
221 },
222 OnStepFinish: func(step ai.StepResult) error {
223 stepCount++
224 return nil
225 },
226 }
227
228 result, err := agent.Stream(t.Context(), streamCall)
229 require.NoError(t, err, "failed to stream")
230
231 finalText := result.Response.Content.Text()
232 require.NotEmpty(t, finalText, "expected non-empty response")
233
234 require.True(t, strings.Contains(strings.ToLower(finalText), "uno") &&
235 strings.Contains(strings.ToLower(finalText), "dos") &&
236 strings.Contains(strings.ToLower(finalText), "tres"), "unexpected response: %q", finalText)
237
238 require.Greater(t, textDeltaCount, 0, "expected at least one text delta callback")
239
240 require.Greater(t, stepCount, 0, "expected at least one step finish callback")
241
242 require.NotEmpty(t, collectedText.String(), "expected collected text from deltas to be non-empty")
243 })
244 }
245}
246
247func TestStreamWithTools(t *testing.T) {
248 for _, pair := range languageModelBuilders {
249 t.Run(pair.name, func(t *testing.T) {
250 r := newRecorder(t)
251
252 languageModel, err := pair.builder(r)
253 require.NoError(t, err, "failed to build language model")
254
255 type CalculatorInput struct {
256 A int `json:"a" description:"first number"`
257 B int `json:"b" description:"second number"`
258 }
259
260 calculatorTool := ai.NewAgentTool(
261 "add",
262 "Add two numbers",
263 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
264 result := input.A + input.B
265 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
266 },
267 )
268
269 agent := ai.NewAgent(
270 languageModel,
271 ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
272 ai.WithTools(calculatorTool),
273 )
274
275 toolCallCount := 0
276 toolResultCount := 0
277 var collectedText strings.Builder
278
279 streamCall := ai.AgentStreamCall{
280 Prompt: "What is 15 + 27?",
281 OnTextDelta: func(id, text string) error {
282 collectedText.WriteString(text)
283 return nil
284 },
285 OnToolCall: func(toolCall ai.ToolCallContent) error {
286 toolCallCount++
287 require.Equal(t, "add", toolCall.ToolName, "unexpected tool name")
288 return nil
289 },
290 OnToolResult: func(result ai.ToolResultContent) error {
291 toolResultCount++
292 return nil
293 },
294 }
295
296 result, err := agent.Stream(t.Context(), streamCall)
297 require.NoError(t, err, "failed to stream")
298
299 finalText := result.Response.Content.Text()
300 require.Contains(t, finalText, "42", "expected response to contain '42', got: %q", finalText)
301
302 require.Greater(t, toolCallCount, 0, "expected at least one tool call")
303
304 require.Greater(t, toolResultCount, 0, "expected at least one tool result")
305 })
306 }
307}
308
309func TestStreamWithMultipleTools(t *testing.T) {
310 for _, pair := range languageModelBuilders {
311 t.Run(pair.name, func(t *testing.T) {
312 r := newRecorder(t)
313
314 languageModel, err := pair.builder(r)
315 require.NoError(t, err, "failed to build language model")
316
317 type CalculatorInput struct {
318 A int `json:"a" description:"first number"`
319 B int `json:"b" description:"second number"`
320 }
321
322 addTool := ai.NewAgentTool(
323 "add",
324 "Add two numbers",
325 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
326 result := input.A + input.B
327 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
328 },
329 )
330 multiplyTool := ai.NewAgentTool(
331 "multiply",
332 "Multiply two numbers",
333 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
334 result := input.A * input.B
335 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
336 },
337 )
338
339 agent := ai.NewAgent(
340 languageModel,
341 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
342 ai.WithTools(addTool),
343 ai.WithTools(multiplyTool),
344 )
345
346 toolCallCount := 0
347 toolResultCount := 0
348 var collectedText strings.Builder
349
350 streamCall := ai.AgentStreamCall{
351 Prompt: "Add and multiply the number 2 and 3",
352 OnTextDelta: func(id, text string) error {
353 collectedText.WriteString(text)
354 return nil
355 },
356 OnToolCall: func(toolCall ai.ToolCallContent) error {
357 toolCallCount++
358 return nil
359 },
360 OnToolResult: func(result ai.ToolResultContent) error {
361 toolResultCount++
362 return nil
363 },
364 }
365
366 result, err := agent.Stream(t.Context(), streamCall)
367 require.NoError(t, err, "failed to stream")
368 require.Equal(t, len(result.Steps), 2, "expected all tool calls in step 1")
369 finalText := result.Response.Content.Text()
370 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
371 require.Contains(t, finalText, "6", "expected response to contain '5', got: %q", finalText)
372
373 require.Greater(t, toolCallCount, 0, "expected at least one tool call")
374
375 require.Greater(t, toolResultCount, 0, "expected at least one tool result")
376 })
377 }
378}