1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	"github.com/stretchr/testify/require"
 11	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 12)
 13
 14type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 15
 16type builderPair struct {
 17	name            string
 18	builder         builderFunc
 19	providerOptions ai.ProviderOptions
 20}
 21
 22func testCommon(t *testing.T, pairs []builderPair) {
 23	for _, pair := range pairs {
 24		testSimple(t, pair)
 25		testTool(t, pair)
 26		testMultiTool(t, pair)
 27	}
 28}
 29
 30func testSimple(t *testing.T, pair builderPair) {
 31	checkResult := func(t *testing.T, result *ai.AgentResult) {
 32		option1 := "Oi"
 33		option2 := "Olá"
 34		got := result.Response.Content.Text()
 35		require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 36	}
 37
 38	t.Run("simple "+pair.name, func(t *testing.T) {
 39		r := newRecorder(t)
 40
 41		languageModel, err := pair.builder(r)
 42		require.NoError(t, err, "failed to build language model")
 43
 44		agent := ai.NewAgent(
 45			languageModel,
 46			ai.WithSystemPrompt("You are a helpful assistant"),
 47		)
 48		result, err := agent.Generate(t.Context(), ai.AgentCall{
 49			Prompt:          "Say hi in Portuguese",
 50			ProviderOptions: pair.providerOptions,
 51			MaxOutputTokens: ai.IntOption(4000),
 52		})
 53		require.NoError(t, err, "failed to generate")
 54		checkResult(t, result)
 55	})
 56	t.Run("simple streaming "+pair.name, func(t *testing.T) {
 57		r := newRecorder(t)
 58
 59		languageModel, err := pair.builder(r)
 60		require.NoError(t, err, "failed to build language model")
 61
 62		agent := ai.NewAgent(
 63			languageModel,
 64			ai.WithSystemPrompt("You are a helpful assistant"),
 65		)
 66		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 67			Prompt:          "Say hi in Portuguese",
 68			ProviderOptions: pair.providerOptions,
 69			MaxOutputTokens: ai.IntOption(4000),
 70		})
 71		require.NoError(t, err, "failed to generate")
 72		checkResult(t, result)
 73	})
 74}
 75
 76func testTool(t *testing.T, pair builderPair) {
 77	type WeatherInput struct {
 78		Location string `json:"location" description:"the city"`
 79	}
 80
 81	weatherTool := ai.NewAgentTool(
 82		"weather",
 83		"Get weather information for a location",
 84		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 85			return ai.NewTextResponse("40 C"), nil
 86		},
 87	)
 88	checkResult := func(t *testing.T, result *ai.AgentResult) {
 89		require.Len(t, result.Steps, 2)
 90
 91		var toolCalls []ai.ToolCallContent
 92		for _, content := range result.Steps[0].Content {
 93			if content.GetType() == ai.ContentTypeToolCall {
 94				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
 95			}
 96		}
 97		for _, tc := range toolCalls {
 98			require.False(t, tc.Invalid)
 99		}
100		require.Len(t, toolCalls, 1)
101		require.Equal(t, toolCalls[0].ToolName, "weather")
102
103		want1 := "Florence"
104		want2 := "40"
105		got := result.Response.Content.Text()
106		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
107	}
108
109	t.Run("tool "+pair.name, func(t *testing.T) {
110		r := newRecorder(t)
111
112		languageModel, err := pair.builder(r)
113		require.NoError(t, err, "failed to build language model")
114
115		agent := ai.NewAgent(
116			languageModel,
117			ai.WithSystemPrompt("You are a helpful assistant"),
118			ai.WithTools(weatherTool),
119		)
120		result, err := agent.Generate(t.Context(), ai.AgentCall{
121			Prompt:          "What's the weather in Florence,Italy?",
122			ProviderOptions: pair.providerOptions,
123			MaxOutputTokens: ai.IntOption(4000),
124		})
125		require.NoError(t, err, "failed to generate")
126		checkResult(t, result)
127	})
128	t.Run("tool streaming "+pair.name, func(t *testing.T) {
129		r := newRecorder(t)
130
131		languageModel, err := pair.builder(r)
132		require.NoError(t, err, "failed to build language model")
133
134		agent := ai.NewAgent(
135			languageModel,
136			ai.WithSystemPrompt("You are a helpful assistant"),
137			ai.WithTools(weatherTool),
138		)
139		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
140			Prompt:          "What's the weather in Florence,Italy?",
141			ProviderOptions: pair.providerOptions,
142			MaxOutputTokens: ai.IntOption(4000),
143		})
144		require.NoError(t, err, "failed to generate")
145		checkResult(t, result)
146	})
147}
148
149func testMultiTool(t *testing.T, pair builderPair) {
150	type WeatherInput struct {
151		Location string `json:"location" description:"the city"`
152	}
153
154	type CalculatorInput struct {
155		A int `json:"a" description:"first number"`
156		B int `json:"b" description:"second number"`
157	}
158
159	addTool := ai.NewAgentTool(
160		"add",
161		"Add two numbers",
162		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
163			result := input.A + input.B
164			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
165		},
166	)
167	multiplyTool := ai.NewAgentTool(
168		"multiply",
169		"Multiply two numbers",
170		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
171			result := input.A * input.B
172			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
173		},
174	)
175	checkResult := func(t *testing.T, result *ai.AgentResult) {
176		require.Len(t, result.Steps, 2)
177
178		var toolCalls []ai.ToolCallContent
179		for _, content := range result.Steps[0].Content {
180			if content.GetType() == ai.ContentTypeToolCall {
181				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
182			}
183		}
184		for _, tc := range toolCalls {
185			require.False(t, tc.Invalid)
186		}
187		require.Len(t, toolCalls, 2)
188
189		finalText := result.Response.Content.Text()
190		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
191		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
192	}
193
194	t.Run("multi tool "+pair.name, func(t *testing.T) {
195		r := newRecorder(t)
196
197		languageModel, err := pair.builder(r)
198		require.NoError(t, err, "failed to build language model")
199
200		agent := ai.NewAgent(
201			languageModel,
202			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
203			ai.WithTools(addTool),
204			ai.WithTools(multiplyTool),
205		)
206		result, err := agent.Generate(t.Context(), ai.AgentCall{
207			Prompt:          "Add and multiply the number 2 and 3",
208			ProviderOptions: pair.providerOptions,
209			MaxOutputTokens: ai.IntOption(4000),
210		})
211		require.NoError(t, err, "failed to generate")
212		checkResult(t, result)
213	})
214	t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
215		r := newRecorder(t)
216
217		languageModel, err := pair.builder(r)
218		require.NoError(t, err, "failed to build language model")
219
220		agent := ai.NewAgent(
221			languageModel,
222			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
223			ai.WithTools(addTool),
224			ai.WithTools(multiplyTool),
225		)
226		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
227			Prompt:          "Add and multiply the number 2 and 3",
228			ProviderOptions: pair.providerOptions,
229			MaxOutputTokens: ai.IntOption(4000),
230		})
231		require.NoError(t, err, "failed to generate")
232		checkResult(t, result)
233	})
234}
235
236func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
237	for _, pair := range pairs {
238		t.Run("thinking-"+pair.name, func(t *testing.T) {
239			r := newRecorder(t)
240
241			languageModel, err := pair.builder(r)
242			require.NoError(t, err, "failed to build language model")
243
244			type WeatherInput struct {
245				Location string `json:"location" description:"the city"`
246			}
247
248			weatherTool := ai.NewAgentTool(
249				"weather",
250				"Get weather information for a location",
251				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
252					return ai.NewTextResponse("40 C"), nil
253				},
254			)
255
256			agent := ai.NewAgent(
257				languageModel,
258				ai.WithSystemPrompt("You are a helpful assistant"),
259				ai.WithTools(weatherTool),
260			)
261			result, err := agent.Generate(t.Context(), ai.AgentCall{
262				Prompt:          "What's the weather in Florence, Italy?",
263				ProviderOptions: pair.providerOptions,
264				// ProviderOptions: ai.ProviderOptions{
265				// 	"anthropic": &anthropic.ProviderOptions{
266				// 		Thinking: &anthropic.ThinkingProviderOption{
267				// 			BudgetTokens: 10_000,
268				// 		},
269				// 	},
270				// 	"google": &google.ProviderOptions{
271				// 		ThinkingConfig: &google.ThinkingConfig{
272				// 			ThinkingBudget:  ai.IntOption(100),
273				// 			IncludeThoughts: ai.BoolOption(true),
274				// 		},
275				// 	},
276				// 	"openai": &openai.ProviderOptions{
277				// 		ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
278				// 	},
279				// },
280			})
281			require.NoError(t, err, "failed to generate")
282
283			want1 := "Florence"
284			want2 := "40"
285			got := result.Response.Content.Text()
286			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
287
288			thinkChecks(t, result)
289		})
290	}
291}