1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	"github.com/charmbracelet/fantasy/anthropic"
 11	"github.com/charmbracelet/fantasy/google"
 12	"github.com/charmbracelet/fantasy/openai"
 13	_ "github.com/joho/godotenv/autoload"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17func TestSimple(t *testing.T) {
 18	for _, pair := range languageModelBuilders {
 19		t.Run(pair.name, func(t *testing.T) {
 20			r := newRecorder(t)
 21
 22			languageModel, err := pair.builder(r)
 23			require.NoError(t, err, "failed to build language model")
 24
 25			agent := ai.NewAgent(
 26				languageModel,
 27				ai.WithSystemPrompt("You are a helpful assistant"),
 28			)
 29			result, err := agent.Generate(t.Context(), ai.AgentCall{
 30				Prompt: "Say hi in Portuguese",
 31			})
 32			require.NoError(t, err, "failed to generate")
 33
 34			option1 := "Oi"
 35			option2 := "Olá"
 36			got := result.Response.Content.Text()
 37			require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 38		})
 39	}
 40}
 41
 42func TestTool(t *testing.T) {
 43	for _, pair := range languageModelBuilders {
 44		t.Run(pair.name, func(t *testing.T) {
 45			r := newRecorder(t)
 46
 47			languageModel, err := pair.builder(r)
 48			require.NoError(t, err, "failed to build language model")
 49
 50			type WeatherInput struct {
 51				Location string `json:"location" description:"the city"`
 52			}
 53
 54			weatherTool := ai.NewAgentTool(
 55				"weather",
 56				"Get weather information for a location",
 57				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 58					return ai.NewTextResponse("40 C"), nil
 59				},
 60			)
 61
 62			agent := ai.NewAgent(
 63				languageModel,
 64				ai.WithSystemPrompt("You are a helpful assistant"),
 65				ai.WithTools(weatherTool),
 66			)
 67			result, err := agent.Generate(t.Context(), ai.AgentCall{
 68				Prompt: "What's the weather in Florence?",
 69			})
 70			require.NoError(t, err, "failed to generate")
 71
 72			want1 := "Florence"
 73			want2 := "40"
 74			got := result.Response.Content.Text()
 75			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
 76		})
 77	}
 78}
 79
 80func TestThinking(t *testing.T) {
 81	for _, pair := range thinkingLanguageModelBuilders {
 82		t.Run(pair.name, func(t *testing.T) {
 83			r := newRecorder(t)
 84
 85			languageModel, err := pair.builder(r)
 86			require.NoError(t, err, "failed to build language model")
 87
 88			type WeatherInput struct {
 89				Location string `json:"location" description:"the city"`
 90			}
 91
 92			weatherTool := ai.NewAgentTool(
 93				"weather",
 94				"Get weather information for a location",
 95				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 96					return ai.NewTextResponse("40 C"), nil
 97				},
 98			)
 99
100			agent := ai.NewAgent(
101				languageModel,
102				ai.WithSystemPrompt("You are a helpful assistant"),
103				ai.WithTools(weatherTool),
104			)
105			result, err := agent.Generate(t.Context(), ai.AgentCall{
106				Prompt: "What's the weather in Florence, Italy?",
107				ProviderOptions: ai.ProviderOptions{
108					"anthropic": &anthropic.ProviderOptions{
109						Thinking: &anthropic.ThinkingProviderOption{
110							BudgetTokens: 10_000,
111						},
112					},
113					"google": &google.ProviderOptions{
114						ThinkingConfig: &google.ThinkingConfig{
115							ThinkingBudget:  ai.IntOption(100),
116							IncludeThoughts: ai.BoolOption(true),
117						},
118					},
119					"openai": &openai.ProviderOptions{
120						ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
121					},
122				},
123			})
124			require.NoError(t, err, "failed to generate")
125
126			want1 := "Florence"
127			want2 := "40"
128			got := result.Response.Content.Text()
129			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
130
131			testThinkingSteps(t, languageModel.Provider(), result.Steps)
132		})
133	}
134}
135
136func TestThinkingStreaming(t *testing.T) {
137	for _, pair := range thinkingLanguageModelBuilders {
138		t.Run(pair.name, func(t *testing.T) {
139			r := newRecorder(t)
140
141			languageModel, err := pair.builder(r)
142			require.NoError(t, err, "failed to build language model")
143
144			type WeatherInput struct {
145				Location string `json:"location" description:"the city"`
146			}
147
148			weatherTool := ai.NewAgentTool(
149				"weather",
150				"Get weather information for a location",
151				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
152					return ai.NewTextResponse("40 C"), nil
153				},
154			)
155
156			agent := ai.NewAgent(
157				languageModel,
158				ai.WithSystemPrompt("You are a helpful assistant"),
159				ai.WithTools(weatherTool),
160			)
161			result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
162				Prompt: "What's the weather in Florence, Italy?",
163				ProviderOptions: ai.ProviderOptions{
164					"anthropic": &anthropic.ProviderOptions{
165						Thinking: &anthropic.ThinkingProviderOption{
166							BudgetTokens: 10_000,
167						},
168					},
169					"google": &google.ProviderOptions{
170						ThinkingConfig: &google.ThinkingConfig{
171							ThinkingBudget:  ai.IntOption(100),
172							IncludeThoughts: ai.BoolOption(true),
173						},
174					},
175					"openai": &openai.ProviderOptions{
176						ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
177					},
178				},
179			})
180			require.NoError(t, err, "failed to generate")
181
182			want1 := "Florence"
183			want2 := "40"
184			got := result.Response.Content.Text()
185			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
186
187			testThinkingSteps(t, languageModel.Provider(), result.Steps)
188		})
189	}
190}
191
192func TestStream(t *testing.T) {
193	for _, pair := range languageModelBuilders {
194		t.Run(pair.name, func(t *testing.T) {
195			r := newRecorder(t)
196
197			languageModel, err := pair.builder(r)
198			require.NoError(t, err, "failed to build language model")
199
200			agent := ai.NewAgent(
201				languageModel,
202				ai.WithSystemPrompt("You are a helpful assistant"),
203			)
204
205			var collectedText strings.Builder
206			textDeltaCount := 0
207			stepCount := 0
208
209			streamCall := ai.AgentStreamCall{
210				Prompt: "Count from 1 to 3 in Spanish",
211				OnTextDelta: func(id, text string) error {
212					textDeltaCount++
213					collectedText.WriteString(text)
214					return nil
215				},
216				OnStepFinish: func(step ai.StepResult) error {
217					stepCount++
218					return nil
219				},
220			}
221
222			result, err := agent.Stream(t.Context(), streamCall)
223			require.NoError(t, err, "failed to stream")
224
225			finalText := result.Response.Content.Text()
226			require.NotEmpty(t, finalText, "expected non-empty response")
227
228			require.True(t, strings.Contains(strings.ToLower(finalText), "uno") &&
229				strings.Contains(strings.ToLower(finalText), "dos") &&
230				strings.Contains(strings.ToLower(finalText), "tres"), "unexpected response: %q", finalText)
231
232			require.Greater(t, textDeltaCount, 0, "expected at least one text delta callback")
233
234			require.Greater(t, stepCount, 0, "expected at least one step finish callback")
235
236			require.NotEmpty(t, collectedText.String(), "expected collected text from deltas to be non-empty")
237		})
238	}
239}
240
241func TestStreamWithTools(t *testing.T) {
242	for _, pair := range languageModelBuilders {
243		t.Run(pair.name, func(t *testing.T) {
244			r := newRecorder(t)
245
246			languageModel, err := pair.builder(r)
247			require.NoError(t, err, "failed to build language model")
248
249			type CalculatorInput struct {
250				A int `json:"a" description:"first number"`
251				B int `json:"b" description:"second number"`
252			}
253
254			calculatorTool := ai.NewAgentTool(
255				"add",
256				"Add two numbers",
257				func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
258					result := input.A + input.B
259					return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
260				},
261			)
262
263			agent := ai.NewAgent(
264				languageModel,
265				ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
266				ai.WithTools(calculatorTool),
267			)
268
269			toolCallCount := 0
270			toolResultCount := 0
271			var collectedText strings.Builder
272
273			streamCall := ai.AgentStreamCall{
274				Prompt: "What is 15 + 27?",
275				OnTextDelta: func(id, text string) error {
276					collectedText.WriteString(text)
277					return nil
278				},
279				OnToolCall: func(toolCall ai.ToolCallContent) error {
280					toolCallCount++
281					require.Equal(t, "add", toolCall.ToolName, "unexpected tool name")
282					return nil
283				},
284				OnToolResult: func(result ai.ToolResultContent) error {
285					toolResultCount++
286					return nil
287				},
288			}
289
290			result, err := agent.Stream(t.Context(), streamCall)
291			require.NoError(t, err, "failed to stream")
292
293			finalText := result.Response.Content.Text()
294			require.Contains(t, finalText, "42", "expected response to contain '42', got: %q", finalText)
295
296			require.Greater(t, toolCallCount, 0, "expected at least one tool call")
297
298			require.Greater(t, toolResultCount, 0, "expected at least one tool result")
299		})
300	}
301}