1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	"github.com/charmbracelet/fantasy/anthropic"
 11	"github.com/charmbracelet/fantasy/google"
 12	"github.com/charmbracelet/fantasy/openai"
 13	"github.com/charmbracelet/fantasy/openrouter"
 14	_ "github.com/joho/godotenv/autoload"
 15	"github.com/stretchr/testify/require"
 16)
 17
 18func TestSimple(t *testing.T) {
 19	for _, pair := range languageModelBuilders {
 20		t.Run(pair.name, func(t *testing.T) {
 21			r := newRecorder(t)
 22
 23			languageModel, err := pair.builder(r)
 24			require.NoError(t, err, "failed to build language model")
 25
 26			agent := ai.NewAgent(
 27				languageModel,
 28				ai.WithSystemPrompt("You are a helpful assistant"),
 29			)
 30			result, err := agent.Generate(t.Context(), ai.AgentCall{
 31				Prompt: "Say hi in Portuguese",
 32			})
 33			require.NoError(t, err, "failed to generate")
 34
 35			option1 := "Oi"
 36			option2 := "Olá"
 37			got := result.Response.Content.Text()
 38			require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 39		})
 40	}
 41}
 42
 43func TestTool(t *testing.T) {
 44	for _, pair := range languageModelBuilders {
 45		t.Run(pair.name, func(t *testing.T) {
 46			r := newRecorder(t)
 47
 48			languageModel, err := pair.builder(r)
 49			require.NoError(t, err, "failed to build language model")
 50
 51			type WeatherInput struct {
 52				Location string `json:"location" description:"the city"`
 53			}
 54
 55			weatherTool := ai.NewAgentTool(
 56				"weather",
 57				"Get weather information for a location",
 58				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 59					return ai.NewTextResponse("40 C"), nil
 60				},
 61			)
 62
 63			agent := ai.NewAgent(
 64				languageModel,
 65				ai.WithSystemPrompt("You are a helpful assistant"),
 66				ai.WithTools(weatherTool),
 67			)
 68			result, err := agent.Generate(t.Context(), ai.AgentCall{
 69				Prompt: "What's the weather in Florence?",
 70			})
 71			require.NoError(t, err, "failed to generate")
 72
 73			want1 := "Florence"
 74			want2 := "40"
 75			got := result.Response.Content.Text()
 76			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
 77		})
 78	}
 79}
 80
 81func TestThinking(t *testing.T) {
 82	for _, pair := range thinkingLanguageModelBuilders {
 83		t.Run(pair.name, func(t *testing.T) {
 84			r := newRecorder(t)
 85
 86			languageModel, err := pair.builder(r)
 87			require.NoError(t, err, "failed to build language model")
 88
 89			type WeatherInput struct {
 90				Location string `json:"location" description:"the city"`
 91			}
 92
 93			weatherTool := ai.NewAgentTool(
 94				"weather",
 95				"Get weather information for a location",
 96				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 97					return ai.NewTextResponse("40 C"), nil
 98				},
 99			)
100
101			agent := ai.NewAgent(
102				languageModel,
103				ai.WithSystemPrompt("You are a helpful assistant"),
104				ai.WithTools(weatherTool),
105			)
106			result, err := agent.Generate(t.Context(), ai.AgentCall{
107				Prompt: "What's the weather in Florence, Italy?",
108				ProviderOptions: ai.ProviderOptions{
109					"anthropic": &anthropic.ProviderOptions{
110						Thinking: &anthropic.ThinkingProviderOption{
111							BudgetTokens: 10_000,
112						},
113					},
114					"google": &google.ProviderOptions{
115						ThinkingConfig: &google.ThinkingConfig{
116							ThinkingBudget:  ai.IntOption(100),
117							IncludeThoughts: ai.BoolOption(true),
118						},
119					},
120					"openai": &openai.ProviderOptions{
121						ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
122					},
123					"openrouter": &openrouter.ProviderOptions{
124						Reasoning: &openrouter.ReasoningOptions{
125							Effort: openrouter.ReasoningEffortOption(openrouter.ReasoningEffortHigh),
126						},
127					},
128				},
129			})
130			require.NoError(t, err, "failed to generate")
131
132			want1 := "Florence"
133			want2 := "40"
134			got := result.Response.Content.Text()
135			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
136
137			testThinking(t, languageModel.Provider(), result.Steps)
138		})
139	}
140}
141
142func TestThinkingStreaming(t *testing.T) {
143	for _, pair := range thinkingLanguageModelBuilders {
144		t.Run(pair.name, func(t *testing.T) {
145			r := newRecorder(t)
146
147			languageModel, err := pair.builder(r)
148			require.NoError(t, err, "failed to build language model")
149
150			type WeatherInput struct {
151				Location string `json:"location" description:"the city"`
152			}
153
154			weatherTool := ai.NewAgentTool(
155				"weather",
156				"Get weather information for a location",
157				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
158					return ai.NewTextResponse("40 C"), nil
159				},
160			)
161
162			agent := ai.NewAgent(
163				languageModel,
164				ai.WithSystemPrompt("You are a helpful assistant"),
165				ai.WithTools(weatherTool),
166			)
167			result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
168				Prompt: "What's the weather in Florence, Italy?",
169				ProviderOptions: ai.ProviderOptions{
170					"anthropic": &anthropic.ProviderOptions{
171						Thinking: &anthropic.ThinkingProviderOption{
172							BudgetTokens: 10_000,
173						},
174					},
175					"google": &google.ProviderOptions{
176						ThinkingConfig: &google.ThinkingConfig{
177							ThinkingBudget:  ai.IntOption(100),
178							IncludeThoughts: ai.BoolOption(true),
179						},
180					},
181					"openai": &openai.ProviderOptions{
182						ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
183					},
184				},
185			})
186			require.NoError(t, err, "failed to generate")
187
188			want1 := "Florence"
189			want2 := "40"
190			got := result.Response.Content.Text()
191			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
192
193			testThinking(t, languageModel.Provider(), result.Steps)
194		})
195	}
196}
197
198func TestStream(t *testing.T) {
199	for _, pair := range languageModelBuilders {
200		t.Run(pair.name, func(t *testing.T) {
201			r := newRecorder(t)
202
203			languageModel, err := pair.builder(r)
204			require.NoError(t, err, "failed to build language model")
205
206			agent := ai.NewAgent(
207				languageModel,
208				ai.WithSystemPrompt("You are a helpful assistant"),
209			)
210
211			var collectedText strings.Builder
212			textDeltaCount := 0
213			stepCount := 0
214
215			streamCall := ai.AgentStreamCall{
216				Prompt: "Count from 1 to 3 in Spanish",
217				OnTextDelta: func(id, text string) error {
218					textDeltaCount++
219					collectedText.WriteString(text)
220					return nil
221				},
222				OnStepFinish: func(step ai.StepResult) error {
223					stepCount++
224					return nil
225				},
226			}
227
228			result, err := agent.Stream(t.Context(), streamCall)
229			require.NoError(t, err, "failed to stream")
230
231			finalText := result.Response.Content.Text()
232			require.NotEmpty(t, finalText, "expected non-empty response")
233
234			require.True(t, strings.Contains(strings.ToLower(finalText), "uno") &&
235				strings.Contains(strings.ToLower(finalText), "dos") &&
236				strings.Contains(strings.ToLower(finalText), "tres"), "unexpected response: %q", finalText)
237
238			require.Greater(t, textDeltaCount, 0, "expected at least one text delta callback")
239
240			require.Greater(t, stepCount, 0, "expected at least one step finish callback")
241
242			require.NotEmpty(t, collectedText.String(), "expected collected text from deltas to be non-empty")
243		})
244	}
245}
246
247func TestStreamWithTools(t *testing.T) {
248	for _, pair := range languageModelBuilders {
249		t.Run(pair.name, func(t *testing.T) {
250			r := newRecorder(t)
251
252			languageModel, err := pair.builder(r)
253			require.NoError(t, err, "failed to build language model")
254
255			type CalculatorInput struct {
256				A int `json:"a" description:"first number"`
257				B int `json:"b" description:"second number"`
258			}
259
260			calculatorTool := ai.NewAgentTool(
261				"add",
262				"Add two numbers",
263				func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
264					result := input.A + input.B
265					return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
266				},
267			)
268
269			agent := ai.NewAgent(
270				languageModel,
271				ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
272				ai.WithTools(calculatorTool),
273			)
274
275			toolCallCount := 0
276			toolResultCount := 0
277			var collectedText strings.Builder
278
279			streamCall := ai.AgentStreamCall{
280				Prompt: "What is 15 + 27?",
281				OnTextDelta: func(id, text string) error {
282					collectedText.WriteString(text)
283					return nil
284				},
285				OnToolCall: func(toolCall ai.ToolCallContent) error {
286					toolCallCount++
287					require.Equal(t, "add", toolCall.ToolName, "unexpected tool name")
288					return nil
289				},
290				OnToolResult: func(result ai.ToolResultContent) error {
291					toolResultCount++
292					return nil
293				},
294			}
295
296			result, err := agent.Stream(t.Context(), streamCall)
297			require.NoError(t, err, "failed to stream")
298
299			finalText := result.Response.Content.Text()
300			require.Contains(t, finalText, "42", "expected response to contain '42', got: %q", finalText)
301
302			require.Greater(t, toolCallCount, 0, "expected at least one tool call")
303
304			require.Greater(t, toolResultCount, 0, "expected at least one tool result")
305		})
306	}
307}
308
309func TestStreamWithMultipleTools(t *testing.T) {
310	for _, pair := range languageModelBuilders {
311		t.Run(pair.name, func(t *testing.T) {
312			r := newRecorder(t)
313
314			languageModel, err := pair.builder(r)
315			require.NoError(t, err, "failed to build language model")
316
317			type CalculatorInput struct {
318				A int `json:"a" description:"first number"`
319				B int `json:"b" description:"second number"`
320			}
321
322			addTool := ai.NewAgentTool(
323				"add",
324				"Add two numbers",
325				func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
326					result := input.A + input.B
327					return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
328				},
329			)
330			multiplyTool := ai.NewAgentTool(
331				"multiply",
332				"Multiply two numbers",
333				func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
334					result := input.A * input.B
335					return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
336				},
337			)
338
339			agent := ai.NewAgent(
340				languageModel,
341				ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
342				ai.WithTools(addTool),
343				ai.WithTools(multiplyTool),
344			)
345
346			toolCallCount := 0
347			toolResultCount := 0
348			var collectedText strings.Builder
349
350			streamCall := ai.AgentStreamCall{
351				Prompt: "Add and multiply the number 2 and 3",
352				OnTextDelta: func(id, text string) error {
353					collectedText.WriteString(text)
354					return nil
355				},
356				OnToolCall: func(toolCall ai.ToolCallContent) error {
357					toolCallCount++
358					return nil
359				},
360				OnToolResult: func(result ai.ToolResultContent) error {
361					toolResultCount++
362					return nil
363				},
364			}
365
366			result, err := agent.Stream(t.Context(), streamCall)
367			require.NoError(t, err, "failed to stream")
368			require.Equal(t, len(result.Steps), 2, "expected all tool calls in step 1")
369			finalText := result.Response.Content.Text()
370			require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
371			require.Contains(t, finalText, "6", "expected response to contain '5', got: %q", finalText)
372
373			require.Greater(t, toolCallCount, 0, "expected at least one tool call")
374
375			require.Greater(t, toolResultCount, 0, "expected at least one tool result")
376		})
377	}
378}