common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	"github.com/stretchr/testify/require"
 11	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 12)
 13
 14type testModel struct {
 15	name      string
 16	model     string
 17	reasoning bool
 18}
 19
 20type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 21
 22type builderPair struct {
 23	name            string
 24	builder         builderFunc
 25	providerOptions ai.ProviderOptions
 26}
 27
 28func testCommon(t *testing.T, pairs []builderPair) {
 29	for _, pair := range pairs {
 30		t.Run(pair.name, func(t *testing.T) {
 31			testSimple(t, pair)
 32			testTool(t, pair)
 33			testMultiTool(t, pair)
 34		})
 35	}
 36}
 37
 38func testSimple(t *testing.T, pair builderPair) {
 39	checkResult := func(t *testing.T, result *ai.AgentResult) {
 40		option1 := "Oi"
 41		option2 := "Olá"
 42		got := result.Response.Content.Text()
 43		require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 44	}
 45
 46	t.Run("simple", func(t *testing.T) {
 47		r := newRecorder(t)
 48
 49		languageModel, err := pair.builder(r)
 50		require.NoError(t, err, "failed to build language model")
 51
 52		agent := ai.NewAgent(
 53			languageModel,
 54			ai.WithSystemPrompt("You are a helpful assistant"),
 55		)
 56		result, err := agent.Generate(t.Context(), ai.AgentCall{
 57			Prompt:          "Say hi in Portuguese",
 58			ProviderOptions: pair.providerOptions,
 59			MaxOutputTokens: ai.IntOption(4000),
 60		})
 61		require.NoError(t, err, "failed to generate")
 62		checkResult(t, result)
 63	})
 64	t.Run("simple streaming", func(t *testing.T) {
 65		r := newRecorder(t)
 66
 67		languageModel, err := pair.builder(r)
 68		require.NoError(t, err, "failed to build language model")
 69
 70		agent := ai.NewAgent(
 71			languageModel,
 72			ai.WithSystemPrompt("You are a helpful assistant"),
 73		)
 74		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 75			Prompt:          "Say hi in Portuguese",
 76			ProviderOptions: pair.providerOptions,
 77			MaxOutputTokens: ai.IntOption(4000),
 78		})
 79		require.NoError(t, err, "failed to generate")
 80		checkResult(t, result)
 81	})
 82}
 83
 84func testTool(t *testing.T, pair builderPair) {
 85	type WeatherInput struct {
 86		Location string `json:"location" description:"the city"`
 87	}
 88
 89	weatherTool := ai.NewAgentTool(
 90		"weather",
 91		"Get weather information for a location",
 92		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 93			return ai.NewTextResponse("40 C"), nil
 94		},
 95	)
 96	checkResult := func(t *testing.T, result *ai.AgentResult) {
 97		require.Len(t, result.Steps, 2)
 98
 99		var toolCalls []ai.ToolCallContent
100		for _, content := range result.Steps[0].Content {
101			if content.GetType() == ai.ContentTypeToolCall {
102				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
103			}
104		}
105		for _, tc := range toolCalls {
106			require.False(t, tc.Invalid)
107		}
108		require.Len(t, toolCalls, 1)
109		require.Equal(t, toolCalls[0].ToolName, "weather")
110
111		want1 := "Florence"
112		want2 := "40"
113		got := result.Response.Content.Text()
114		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
115	}
116
117	t.Run("tool", func(t *testing.T) {
118		r := newRecorder(t)
119
120		languageModel, err := pair.builder(r)
121		require.NoError(t, err, "failed to build language model")
122
123		agent := ai.NewAgent(
124			languageModel,
125			ai.WithSystemPrompt("You are a helpful assistant"),
126			ai.WithTools(weatherTool),
127		)
128		result, err := agent.Generate(t.Context(), ai.AgentCall{
129			Prompt:          "What's the weather in Florence,Italy?",
130			ProviderOptions: pair.providerOptions,
131			MaxOutputTokens: ai.IntOption(4000),
132		})
133		require.NoError(t, err, "failed to generate")
134		checkResult(t, result)
135	})
136	t.Run("tool streaming", func(t *testing.T) {
137		r := newRecorder(t)
138
139		languageModel, err := pair.builder(r)
140		require.NoError(t, err, "failed to build language model")
141
142		agent := ai.NewAgent(
143			languageModel,
144			ai.WithSystemPrompt("You are a helpful assistant"),
145			ai.WithTools(weatherTool),
146		)
147		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
148			Prompt:          "What's the weather in Florence,Italy?",
149			ProviderOptions: pair.providerOptions,
150			MaxOutputTokens: ai.IntOption(4000),
151		})
152		require.NoError(t, err, "failed to generate")
153		checkResult(t, result)
154	})
155}
156
157func testMultiTool(t *testing.T, pair builderPair) {
158	type CalculatorInput struct {
159		A int `json:"a" description:"first number"`
160		B int `json:"b" description:"second number"`
161	}
162
163	addTool := ai.NewAgentTool(
164		"add",
165		"Add two numbers",
166		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
167			result := input.A + input.B
168			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
169		},
170	)
171	multiplyTool := ai.NewAgentTool(
172		"multiply",
173		"Multiply two numbers",
174		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
175			result := input.A * input.B
176			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
177		},
178	)
179	checkResult := func(t *testing.T, result *ai.AgentResult) {
180		require.Len(t, result.Steps, 2)
181
182		var toolCalls []ai.ToolCallContent
183		for _, content := range result.Steps[0].Content {
184			if content.GetType() == ai.ContentTypeToolCall {
185				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
186			}
187		}
188		for _, tc := range toolCalls {
189			require.False(t, tc.Invalid)
190		}
191		require.Len(t, toolCalls, 2)
192
193		finalText := result.Response.Content.Text()
194		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
195		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
196	}
197
198	t.Run("multi tool", func(t *testing.T) {
199		r := newRecorder(t)
200
201		languageModel, err := pair.builder(r)
202		require.NoError(t, err, "failed to build language model")
203
204		agent := ai.NewAgent(
205			languageModel,
206			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
207			ai.WithTools(addTool),
208			ai.WithTools(multiplyTool),
209		)
210		result, err := agent.Generate(t.Context(), ai.AgentCall{
211			Prompt:          "Add and multiply the number 2 and 3",
212			ProviderOptions: pair.providerOptions,
213			MaxOutputTokens: ai.IntOption(4000),
214		})
215		require.NoError(t, err, "failed to generate")
216		checkResult(t, result)
217	})
218	t.Run("multi tool streaming", func(t *testing.T) {
219		r := newRecorder(t)
220
221		languageModel, err := pair.builder(r)
222		require.NoError(t, err, "failed to build language model")
223
224		agent := ai.NewAgent(
225			languageModel,
226			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
227			ai.WithTools(addTool),
228			ai.WithTools(multiplyTool),
229		)
230		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
231			Prompt:          "Add and multiply the number 2 and 3",
232			ProviderOptions: pair.providerOptions,
233			MaxOutputTokens: ai.IntOption(4000),
234		})
235		require.NoError(t, err, "failed to generate")
236		checkResult(t, result)
237	})
238}
239
240func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
241	for _, pair := range pairs {
242		t.Run(pair.name, func(t *testing.T) {
243			t.Run("thinking", func(t *testing.T) {
244				r := newRecorder(t)
245
246				languageModel, err := pair.builder(r)
247				require.NoError(t, err, "failed to build language model")
248
249				type WeatherInput struct {
250					Location string `json:"location" description:"the city"`
251				}
252
253				weatherTool := ai.NewAgentTool(
254					"weather",
255					"Get weather information for a location",
256					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
257						return ai.NewTextResponse("40 C"), nil
258					},
259				)
260
261				agent := ai.NewAgent(
262					languageModel,
263					ai.WithSystemPrompt("You are a helpful assistant"),
264					ai.WithTools(weatherTool),
265				)
266				result, err := agent.Generate(t.Context(), ai.AgentCall{
267					Prompt:          "What's the weather in Florence, Italy?",
268					ProviderOptions: pair.providerOptions,
269				})
270				require.NoError(t, err, "failed to generate")
271
272				want1 := "Florence"
273				want2 := "40"
274				got := result.Response.Content.Text()
275				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
276
277				thinkChecks(t, result)
278			})
279			t.Run("thinking-streaming", func(t *testing.T) {
280				r := newRecorder(t)
281
282				languageModel, err := pair.builder(r)
283				require.NoError(t, err, "failed to build language model")
284
285				type WeatherInput struct {
286					Location string `json:"location" description:"the city"`
287				}
288
289				weatherTool := ai.NewAgentTool(
290					"weather",
291					"Get weather information for a location",
292					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
293						return ai.NewTextResponse("40 C"), nil
294					},
295				)
296
297				agent := ai.NewAgent(
298					languageModel,
299					ai.WithSystemPrompt("You are a helpful assistant"),
300					ai.WithTools(weatherTool),
301				)
302				result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
303					Prompt:          "What's the weather in Florence, Italy?",
304					ProviderOptions: pair.providerOptions,
305				})
306				require.NoError(t, err, "failed to generate")
307
308				want1 := "Florence"
309				want2 := "40"
310				got := result.Response.Content.Text()
311				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
312
313				thinkChecks(t, result)
314			})
315		})
316	}
317}