common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	"github.com/stretchr/testify/require"
 11	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 12)
 13
 14type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 15
 16type builderPair struct {
 17	name            string
 18	builder         builderFunc
 19	providerOptions ai.ProviderOptions
 20}
 21
 22func testCommon(t *testing.T, pairs []builderPair) {
 23	for _, pair := range pairs {
 24		testSimple(t, pair)
 25		testTool(t, pair)
 26		testMultiTool(t, pair)
 27	}
 28}
 29
 30func testSimple(t *testing.T, pair builderPair) {
 31	checkResult := func(t *testing.T, result *ai.AgentResult) {
 32		option1 := "Oi"
 33		option2 := "Olá"
 34		got := result.Response.Content.Text()
 35		require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 36	}
 37
 38	t.Run("simple "+pair.name, func(t *testing.T) {
 39		r := newRecorder(t)
 40
 41		languageModel, err := pair.builder(r)
 42		require.NoError(t, err, "failed to build language model")
 43
 44		agent := ai.NewAgent(
 45			languageModel,
 46			ai.WithSystemPrompt("You are a helpful assistant"),
 47		)
 48		result, err := agent.Generate(t.Context(), ai.AgentCall{
 49			Prompt:          "Say hi in Portuguese",
 50			ProviderOptions: pair.providerOptions,
 51			MaxOutputTokens: ai.IntOption(4000),
 52		})
 53		require.NoError(t, err, "failed to generate")
 54		checkResult(t, result)
 55	})
 56	t.Run("simple streaming "+pair.name, func(t *testing.T) {
 57		r := newRecorder(t)
 58
 59		languageModel, err := pair.builder(r)
 60		require.NoError(t, err, "failed to build language model")
 61
 62		agent := ai.NewAgent(
 63			languageModel,
 64			ai.WithSystemPrompt("You are a helpful assistant"),
 65		)
 66		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 67			Prompt:          "Say hi in Portuguese",
 68			ProviderOptions: pair.providerOptions,
 69			MaxOutputTokens: ai.IntOption(4000),
 70		})
 71		require.NoError(t, err, "failed to generate")
 72		checkResult(t, result)
 73	})
 74}
 75
 76func testTool(t *testing.T, pair builderPair) {
 77	type WeatherInput struct {
 78		Location string `json:"location" description:"the city"`
 79	}
 80
 81	weatherTool := ai.NewAgentTool(
 82		"weather",
 83		"Get weather information for a location",
 84		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 85			return ai.NewTextResponse("40 C"), nil
 86		},
 87	)
 88	checkResult := func(t *testing.T, result *ai.AgentResult) {
 89		require.Len(t, result.Steps, 2)
 90
 91		var toolCalls []ai.ToolCallContent
 92		for _, content := range result.Steps[0].Content {
 93			if content.GetType() == ai.ContentTypeToolCall {
 94				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
 95			}
 96		}
 97		require.Len(t, toolCalls, 1)
 98		require.Equal(t, toolCalls[0].ToolName, "weather")
 99
100		want1 := "Florence"
101		want2 := "40"
102		got := result.Response.Content.Text()
103		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
104	}
105
106	t.Run("tool "+pair.name, func(t *testing.T) {
107		r := newRecorder(t)
108
109		languageModel, err := pair.builder(r)
110		require.NoError(t, err, "failed to build language model")
111
112		agent := ai.NewAgent(
113			languageModel,
114			ai.WithSystemPrompt("You are a helpful assistant"),
115			ai.WithTools(weatherTool),
116		)
117		result, err := agent.Generate(t.Context(), ai.AgentCall{
118			Prompt:          "What's the weather in Florence,Italy?",
119			ProviderOptions: pair.providerOptions,
120			MaxOutputTokens: ai.IntOption(4000),
121		})
122		require.NoError(t, err, "failed to generate")
123		checkResult(t, result)
124	})
125	t.Run("tool streaming "+pair.name, func(t *testing.T) {
126		r := newRecorder(t)
127
128		languageModel, err := pair.builder(r)
129		require.NoError(t, err, "failed to build language model")
130
131		agent := ai.NewAgent(
132			languageModel,
133			ai.WithSystemPrompt("You are a helpful assistant"),
134			ai.WithTools(weatherTool),
135		)
136		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
137			Prompt:          "What's the weather in Florence,Italy?",
138			ProviderOptions: pair.providerOptions,
139			MaxOutputTokens: ai.IntOption(4000),
140		})
141		require.NoError(t, err, "failed to generate")
142		checkResult(t, result)
143	})
144}
145
146func testMultiTool(t *testing.T, pair builderPair) {
147	type WeatherInput struct {
148		Location string `json:"location" description:"the city"`
149	}
150
151	type CalculatorInput struct {
152		A int `json:"a" description:"first number"`
153		B int `json:"b" description:"second number"`
154	}
155
156	addTool := ai.NewAgentTool(
157		"add",
158		"Add two numbers",
159		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
160			result := input.A + input.B
161			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
162		},
163	)
164	multiplyTool := ai.NewAgentTool(
165		"multiply",
166		"Multiply two numbers",
167		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
168			result := input.A * input.B
169			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
170		},
171	)
172	checkResult := func(t *testing.T, result *ai.AgentResult) {
173		require.Len(t, result.Steps, 2)
174
175		var toolCalls []ai.ToolCallContent
176		for _, content := range result.Steps[0].Content {
177			if content.GetType() == ai.ContentTypeToolCall {
178				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
179			}
180		}
181		require.Len(t, toolCalls, 2)
182
183		finalText := result.Response.Content.Text()
184		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
185		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
186	}
187
188	t.Run("multi tool "+pair.name, func(t *testing.T) {
189		r := newRecorder(t)
190
191		languageModel, err := pair.builder(r)
192		require.NoError(t, err, "failed to build language model")
193
194		agent := ai.NewAgent(
195			languageModel,
196			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
197			ai.WithTools(addTool),
198			ai.WithTools(multiplyTool),
199		)
200		result, err := agent.Generate(t.Context(), ai.AgentCall{
201			Prompt:          "Add and multiply the number 2 and 3",
202			ProviderOptions: pair.providerOptions,
203			MaxOutputTokens: ai.IntOption(4000),
204		})
205		require.NoError(t, err, "failed to generate")
206		checkResult(t, result)
207	})
208	t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
209		r := newRecorder(t)
210
211		languageModel, err := pair.builder(r)
212		require.NoError(t, err, "failed to build language model")
213
214		agent := ai.NewAgent(
215			languageModel,
216			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
217			ai.WithTools(addTool),
218			ai.WithTools(multiplyTool),
219		)
220		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
221			Prompt:          "Add and multiply the number 2 and 3",
222			ProviderOptions: pair.providerOptions,
223			MaxOutputTokens: ai.IntOption(4000),
224		})
225		require.NoError(t, err, "failed to generate")
226		checkResult(t, result)
227	})
228}
229
230func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
231	for _, pair := range pairs {
232		t.Run(pair.name, func(t *testing.T) {
233			r := newRecorder(t)
234
235			languageModel, err := pair.builder(r)
236			require.NoError(t, err, "failed to build language model")
237
238			type WeatherInput struct {
239				Location string `json:"location" description:"the city"`
240			}
241
242			weatherTool := ai.NewAgentTool(
243				"weather",
244				"Get weather information for a location",
245				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
246					return ai.NewTextResponse("40 C"), nil
247				},
248			)
249
250			agent := ai.NewAgent(
251				languageModel,
252				ai.WithSystemPrompt("You are a helpful assistant"),
253				ai.WithTools(weatherTool),
254			)
255			result, err := agent.Generate(t.Context(), ai.AgentCall{
256				Prompt:          "What's the weather in Florence, Italy?",
257				ProviderOptions: pair.providerOptions,
258				// ProviderOptions: ai.ProviderOptions{
259				// 	"anthropic": &anthropic.ProviderOptions{
260				// 		Thinking: &anthropic.ThinkingProviderOption{
261				// 			BudgetTokens: 10_000,
262				// 		},
263				// 	},
264				// 	"google": &google.ProviderOptions{
265				// 		ThinkingConfig: &google.ThinkingConfig{
266				// 			ThinkingBudget:  ai.IntOption(100),
267				// 			IncludeThoughts: ai.BoolOption(true),
268				// 		},
269				// 	},
270				// 	"openai": &openai.ProviderOptions{
271				// 		ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortMedium),
272				// 	},
273				// },
274			})
275			require.NoError(t, err, "failed to generate")
276
277			want1 := "Florence"
278			want2 := "40"
279			got := result.Response.Content.Text()
280			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
281
282			thinkChecks(t, result)
283		})
284	}
285}