1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/charmbracelet/fantasy/anthropic"
11 "github.com/charmbracelet/fantasy/google"
12 "github.com/charmbracelet/fantasy/openai"
13 _ "github.com/joho/godotenv/autoload"
14 "github.com/stretchr/testify/require"
15)
16
17func TestSimple(t *testing.T) {
18 for _, pair := range languageModelBuilders {
19 t.Run(pair.name, func(t *testing.T) {
20 r := newRecorder(t)
21
22 languageModel, err := pair.builder(r)
23 require.NoError(t, err, "failed to build language model")
24
25 agent := ai.NewAgent(
26 languageModel,
27 ai.WithSystemPrompt("You are a helpful assistant"),
28 )
29 result, err := agent.Generate(t.Context(), ai.AgentCall{
30 Prompt: "Say hi in Portuguese",
31 })
32 require.NoError(t, err, "failed to generate")
33
34 option1 := "Oi"
35 option2 := "Olá"
36 got := result.Response.Content.Text()
37 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
38 })
39 }
40}
41
// TestTool checks that every model in languageModelBuilders, given a stub
// "weather" tool, produces a final answer that mentions both the requested
// city and the temperature the tool returned.
func TestTool(t *testing.T) {
	for _, pair := range languageModelBuilders {
		t.Run(pair.name, func(t *testing.T) {
			r := newRecorder(t)

			languageModel, err := pair.builder(r)
			require.NoError(t, err, "failed to build language model")

			type WeatherInput struct {
				Location string `json:"location" description:"the city"`
			}

			// The stub ignores its input and always answers "40 C". The
			// final answer is then expected to contain "40".
			weatherTool := ai.NewAgentTool(
				"weather",
				"Get weather information for a location",
				func(_ context.Context, _ WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
					return ai.NewTextResponse("40 C"), nil
				},
			)

			agent := ai.NewAgent(
				languageModel,
				ai.WithSystemPrompt("You are a helpful assistant"),
				ai.WithTools(weatherTool),
			)
			result, err := agent.Generate(t.Context(), ai.AgentCall{
				Prompt: "What's the weather in Florence?",
			})
			require.NoError(t, err, "failed to generate")

			// Separate Contains checks so a failure names the missing substring.
			got := result.Response.Content.Text()
			require.Contains(t, got, "Florence", "unexpected response")
			require.Contains(t, got, "40", "unexpected response")
		})
	}
}
79
// TestThinking runs a Generate call with each provider's thinking/reasoning
// option enabled, against every model in thinkingLanguageModelBuilders. It
// checks the final answer and then delegates per-provider checks on the
// recorded steps to testThinkingSteps.
func TestThinking(t *testing.T) {
	for _, pair := range thinkingLanguageModelBuilders {
		t.Run(pair.name, func(t *testing.T) {
			r := newRecorder(t)

			languageModel, err := pair.builder(r)
			require.NoError(t, err, "failed to build language model")

			type WeatherInput struct {
				Location string `json:"location" description:"the city"`
			}

			// The stub always answers "40 C", whatever location is requested.
			weatherTool := ai.NewAgentTool(
				"weather",
				"Get weather information for a location",
				func(_ context.Context, _ WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
					return ai.NewTextResponse("40 C"), nil
				},
			)

			agent := ai.NewAgent(
				languageModel,
				ai.WithSystemPrompt("You are a helpful assistant"),
				ai.WithTools(weatherTool),
			)
			// Options are keyed by provider name. Presumably each provider
			// reads only its own entry, so one map serves every model here.
			result, err := agent.Generate(t.Context(), ai.AgentCall{
				Prompt: "What's the weather in Florence, Italy?",
				ProviderOptions: ai.ProviderOptions{
					"anthropic": &anthropic.ProviderOptions{
						Thinking: &anthropic.ThinkingProviderOption{
							BudgetTokens: 10_000,
						},
					},
					"google": &google.ProviderOptions{
						ThinkingConfig: &google.ThinkingConfig{
							ThinkingBudget:  ai.IntOption(100),
							IncludeThoughts: ai.BoolOption(true),
						},
					},
					"openai": &openai.ProviderOptions{
						ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
					},
				},
			})
			require.NoError(t, err, "failed to generate")

			// Separate Contains checks so a failure names the missing substring.
			got := result.Response.Content.Text()
			require.Contains(t, got, "Florence", "unexpected response")
			require.Contains(t, got, "40", "unexpected response")

			testThinkingSteps(t, languageModel.Provider(), result.Steps)
		})
	}
}
135
// TestThinkingStreaming is the streaming counterpart of TestThinking. It
// drives agent.Stream, with thinking/reasoning enabled, against every model in
// thinkingLanguageModelBuilders. It checks the final answer and hands the
// recorded steps to testThinkingSteps.
func TestThinkingStreaming(t *testing.T) {
	for _, pair := range thinkingLanguageModelBuilders {
		t.Run(pair.name, func(t *testing.T) {
			r := newRecorder(t)

			languageModel, err := pair.builder(r)
			require.NoError(t, err, "failed to build language model")

			type WeatherInput struct {
				Location string `json:"location" description:"the city"`
			}

			// The stub always answers "40 C", whatever location is requested.
			weatherTool := ai.NewAgentTool(
				"weather",
				"Get weather information for a location",
				func(_ context.Context, _ WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
					return ai.NewTextResponse("40 C"), nil
				},
			)

			agent := ai.NewAgent(
				languageModel,
				ai.WithSystemPrompt("You are a helpful assistant"),
				ai.WithTools(weatherTool),
			)
			// Options are keyed by provider name. Presumably each provider
			// reads only its own entry, so one map serves every model here.
			result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
				Prompt: "What's the weather in Florence, Italy?",
				ProviderOptions: ai.ProviderOptions{
					"anthropic": &anthropic.ProviderOptions{
						Thinking: &anthropic.ThinkingProviderOption{
							BudgetTokens: 10_000,
						},
					},
					"google": &google.ProviderOptions{
						ThinkingConfig: &google.ThinkingConfig{
							ThinkingBudget:  ai.IntOption(100),
							IncludeThoughts: ai.BoolOption(true),
						},
					},
					"openai": &openai.ProviderOptions{
						ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
					},
				},
			})
			require.NoError(t, err, "failed to stream")

			// Separate Contains checks so a failure names the missing substring.
			got := result.Response.Content.Text()
			require.Contains(t, got, "Florence", "unexpected response")
			require.Contains(t, got, "40", "unexpected response")

			testThinkingSteps(t, languageModel.Provider(), result.Steps)
		})
	}
}
191
// TestStream exercises agent.Stream without tools against every model in
// languageModelBuilders. It checks three things:
//   - the final text contains the Spanish words for 1, 2 and 3;
//   - the OnTextDelta and OnStepFinish callbacks fired at least once;
//   - the streamed deltas were non-empty.
func TestStream(t *testing.T) {
	for _, pair := range languageModelBuilders {
		t.Run(pair.name, func(t *testing.T) {
			r := newRecorder(t)

			languageModel, err := pair.builder(r)
			require.NoError(t, err, "failed to build language model")

			agent := ai.NewAgent(
				languageModel,
				ai.WithSystemPrompt("You are a helpful assistant"),
			)

			var collectedText strings.Builder
			textDeltaCount := 0
			stepCount := 0

			streamCall := ai.AgentStreamCall{
				Prompt: "Count from 1 to 3 in Spanish",
				OnTextDelta: func(_, text string) error {
					textDeltaCount++
					collectedText.WriteString(text)
					return nil
				},
				OnStepFinish: func(ai.StepResult) error {
					stepCount++
					return nil
				},
			}

			result, err := agent.Stream(t.Context(), streamCall)
			require.NoError(t, err, "failed to stream")

			finalText := result.Response.Content.Text()
			require.NotEmpty(t, finalText, "expected non-empty response")

			// Lower-case once, then check each word on its own so a failure
			// names the word that is missing.
			lower := strings.ToLower(finalText)
			for _, word := range []string{"uno", "dos", "tres"} {
				require.Contains(t, lower, word, "unexpected response: %q", finalText)
			}

			require.Greater(t, textDeltaCount, 0, "expected at least one text delta callback")
			require.Greater(t, stepCount, 0, "expected at least one step finish callback")
			require.NotEmpty(t, collectedText.String(), "expected collected text from deltas to be non-empty")
		})
	}
}
240
// TestStreamWithTools checks that streaming and tool use work together. Each
// model in languageModelBuilders must call the "add" tool for "15 + 27",
// receive the tool result, and stream a final answer containing "42".
func TestStreamWithTools(t *testing.T) {
	for _, pair := range languageModelBuilders {
		t.Run(pair.name, func(t *testing.T) {
			r := newRecorder(t)

			languageModel, err := pair.builder(r)
			require.NoError(t, err, "failed to build language model")

			type CalculatorInput struct {
				A int `json:"a" description:"first number"`
				B int `json:"b" description:"second number"`
			}

			calculatorTool := ai.NewAgentTool(
				"add",
				"Add two numbers",
				func(_ context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
					return ai.NewTextResponse(strconv.Itoa(input.A + input.B)), nil
				},
			)

			agent := ai.NewAgent(
				languageModel,
				ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
				ai.WithTools(calculatorTool),
			)

			// The callbacks only record what they observe. All assertions run
			// after Stream returns, because require's FailNow must be called
			// from the test goroutine. Nothing here guarantees the callbacks
			// run on that goroutine.
			var toolNames []string
			toolResultCount := 0
			var collectedText strings.Builder

			streamCall := ai.AgentStreamCall{
				Prompt: "What is 15 + 27?",
				OnTextDelta: func(_, text string) error {
					collectedText.WriteString(text)
					return nil
				},
				OnToolCall: func(toolCall ai.ToolCallContent) error {
					toolNames = append(toolNames, toolCall.ToolName)
					return nil
				},
				OnToolResult: func(ai.ToolResultContent) error {
					toolResultCount++
					return nil
				},
			}

			result, err := agent.Stream(t.Context(), streamCall)
			require.NoError(t, err, "failed to stream")

			finalText := result.Response.Content.Text()
			require.Contains(t, finalText, "42", "expected response to contain '42', got: %q", finalText)

			require.Greater(t, len(toolNames), 0, "expected at least one tool call")
			for _, name := range toolNames {
				require.Equal(t, "add", name, "unexpected tool name")
			}

			require.Greater(t, toolResultCount, 0, "expected at least one tool result")

			require.NotEmpty(t, collectedText.String(), "expected collected text from deltas to be non-empty")
		})
	}
}