common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	"github.com/stretchr/testify/require"
 11	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 12)
 13
 14type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 15
 16type builderPair struct {
 17	name            string
 18	builder         builderFunc
 19	providerOptions ai.ProviderOptions
 20}
 21
 22func testCommon(t *testing.T, pairs []builderPair) {
 23	for _, pair := range pairs {
 24		testSimple(t, pair)
 25		testTool(t, pair)
 26		testMultiTool(t, pair)
 27	}
 28}
 29
 30func testSimple(t *testing.T, pair builderPair) {
 31	checkResult := func(t *testing.T, result *ai.AgentResult) {
 32		option1 := "Oi"
 33		option2 := "Olá"
 34		got := result.Response.Content.Text()
 35		require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 36	}
 37
 38	t.Run("simple "+pair.name, func(t *testing.T) {
 39		r := newRecorder(t)
 40
 41		languageModel, err := pair.builder(r)
 42		require.NoError(t, err, "failed to build language model")
 43
 44		agent := ai.NewAgent(
 45			languageModel,
 46			ai.WithSystemPrompt("You are a helpful assistant"),
 47		)
 48		result, err := agent.Generate(t.Context(), ai.AgentCall{
 49			Prompt:          "Say hi in Portuguese",
 50			ProviderOptions: pair.providerOptions,
 51			MaxOutputTokens: ai.IntOption(4000),
 52		})
 53		require.NoError(t, err, "failed to generate")
 54		checkResult(t, result)
 55	})
 56	t.Run("simple streaming "+pair.name, func(t *testing.T) {
 57		r := newRecorder(t)
 58
 59		languageModel, err := pair.builder(r)
 60		require.NoError(t, err, "failed to build language model")
 61
 62		agent := ai.NewAgent(
 63			languageModel,
 64			ai.WithSystemPrompt("You are a helpful assistant"),
 65		)
 66		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 67			Prompt:          "Say hi in Portuguese",
 68			ProviderOptions: pair.providerOptions,
 69			MaxOutputTokens: ai.IntOption(4000),
 70		})
 71		require.NoError(t, err, "failed to generate")
 72		checkResult(t, result)
 73	})
 74}
 75
 76func testTool(t *testing.T, pair builderPair) {
 77	type WeatherInput struct {
 78		Location string `json:"location" description:"the city"`
 79	}
 80
 81	weatherTool := ai.NewAgentTool(
 82		"weather",
 83		"Get weather information for a location",
 84		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 85			return ai.NewTextResponse("40 C"), nil
 86		},
 87	)
 88	checkResult := func(t *testing.T, result *ai.AgentResult) {
 89		require.Len(t, result.Steps, 2)
 90
 91		var toolCalls []ai.ToolCallContent
 92		for _, content := range result.Steps[0].Content {
 93			if content.GetType() == ai.ContentTypeToolCall {
 94				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
 95			}
 96		}
 97		for _, tc := range toolCalls {
 98			require.False(t, tc.Invalid)
 99		}
100		require.Len(t, toolCalls, 1)
101		require.Equal(t, toolCalls[0].ToolName, "weather")
102
103		want1 := "Florence"
104		want2 := "40"
105		got := result.Response.Content.Text()
106		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
107	}
108
109	t.Run("tool "+pair.name, func(t *testing.T) {
110		r := newRecorder(t)
111
112		languageModel, err := pair.builder(r)
113		require.NoError(t, err, "failed to build language model")
114
115		agent := ai.NewAgent(
116			languageModel,
117			ai.WithSystemPrompt("You are a helpful assistant"),
118			ai.WithTools(weatherTool),
119		)
120		result, err := agent.Generate(t.Context(), ai.AgentCall{
121			Prompt:          "What's the weather in Florence,Italy?",
122			ProviderOptions: pair.providerOptions,
123			MaxOutputTokens: ai.IntOption(4000),
124		})
125		require.NoError(t, err, "failed to generate")
126		checkResult(t, result)
127	})
128	t.Run("tool streaming "+pair.name, func(t *testing.T) {
129		r := newRecorder(t)
130
131		languageModel, err := pair.builder(r)
132		require.NoError(t, err, "failed to build language model")
133
134		agent := ai.NewAgent(
135			languageModel,
136			ai.WithSystemPrompt("You are a helpful assistant"),
137			ai.WithTools(weatherTool),
138		)
139		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
140			Prompt:          "What's the weather in Florence,Italy?",
141			ProviderOptions: pair.providerOptions,
142			MaxOutputTokens: ai.IntOption(4000),
143		})
144		require.NoError(t, err, "failed to generate")
145		checkResult(t, result)
146	})
147}
148
149func testMultiTool(t *testing.T, pair builderPair) {
150	type CalculatorInput struct {
151		A int `json:"a" description:"first number"`
152		B int `json:"b" description:"second number"`
153	}
154
155	addTool := ai.NewAgentTool(
156		"add",
157		"Add two numbers",
158		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
159			result := input.A + input.B
160			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
161		},
162	)
163	multiplyTool := ai.NewAgentTool(
164		"multiply",
165		"Multiply two numbers",
166		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
167			result := input.A * input.B
168			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
169		},
170	)
171	checkResult := func(t *testing.T, result *ai.AgentResult) {
172		require.Len(t, result.Steps, 2)
173
174		var toolCalls []ai.ToolCallContent
175		for _, content := range result.Steps[0].Content {
176			if content.GetType() == ai.ContentTypeToolCall {
177				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
178			}
179		}
180		for _, tc := range toolCalls {
181			require.False(t, tc.Invalid)
182		}
183		require.Len(t, toolCalls, 2)
184
185		finalText := result.Response.Content.Text()
186		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
187		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
188	}
189
190	t.Run("multi tool "+pair.name, func(t *testing.T) {
191		r := newRecorder(t)
192
193		languageModel, err := pair.builder(r)
194		require.NoError(t, err, "failed to build language model")
195
196		agent := ai.NewAgent(
197			languageModel,
198			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
199			ai.WithTools(addTool),
200			ai.WithTools(multiplyTool),
201		)
202		result, err := agent.Generate(t.Context(), ai.AgentCall{
203			Prompt:          "Add and multiply the number 2 and 3",
204			ProviderOptions: pair.providerOptions,
205			MaxOutputTokens: ai.IntOption(4000),
206		})
207		require.NoError(t, err, "failed to generate")
208		checkResult(t, result)
209	})
210	t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
211		r := newRecorder(t)
212
213		languageModel, err := pair.builder(r)
214		require.NoError(t, err, "failed to build language model")
215
216		agent := ai.NewAgent(
217			languageModel,
218			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
219			ai.WithTools(addTool),
220			ai.WithTools(multiplyTool),
221		)
222		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
223			Prompt:          "Add and multiply the number 2 and 3",
224			ProviderOptions: pair.providerOptions,
225			MaxOutputTokens: ai.IntOption(4000),
226		})
227		require.NoError(t, err, "failed to generate")
228		checkResult(t, result)
229	})
230}
231
232func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
233	for _, pair := range pairs {
234		t.Run("thinking-"+pair.name, func(t *testing.T) {
235			r := newRecorder(t)
236
237			languageModel, err := pair.builder(r)
238			require.NoError(t, err, "failed to build language model")
239
240			type WeatherInput struct {
241				Location string `json:"location" description:"the city"`
242			}
243
244			weatherTool := ai.NewAgentTool(
245				"weather",
246				"Get weather information for a location",
247				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
248					return ai.NewTextResponse("40 C"), nil
249				},
250			)
251
252			agent := ai.NewAgent(
253				languageModel,
254				ai.WithSystemPrompt("You are a helpful assistant"),
255				ai.WithTools(weatherTool),
256			)
257			result, err := agent.Generate(t.Context(), ai.AgentCall{
258				Prompt:          "What's the weather in Florence, Italy?",
259				ProviderOptions: pair.providerOptions,
260			})
261			require.NoError(t, err, "failed to generate")
262
263			want1 := "Florence"
264			want2 := "40"
265			got := result.Response.Content.Text()
266			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
267
268			thinkChecks(t, result)
269		})
270		t.Run("thinking-streaming-"+pair.name, func(t *testing.T) {
271			r := newRecorder(t)
272
273			languageModel, err := pair.builder(r)
274			require.NoError(t, err, "failed to build language model")
275
276			type WeatherInput struct {
277				Location string `json:"location" description:"the city"`
278			}
279
280			weatherTool := ai.NewAgentTool(
281				"weather",
282				"Get weather information for a location",
283				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
284					return ai.NewTextResponse("40 C"), nil
285				},
286			)
287
288			agent := ai.NewAgent(
289				languageModel,
290				ai.WithSystemPrompt("You are a helpful assistant"),
291				ai.WithTools(weatherTool),
292			)
293			result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
294				Prompt:          "What's the weather in Florence, Italy?",
295				ProviderOptions: pair.providerOptions,
296			})
297			require.NoError(t, err, "failed to generate")
298
299			want1 := "Florence"
300			want2 := "40"
301			got := result.Response.Content.Text()
302			require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
303
304			thinkChecks(t, result)
305		})
306	}
307}