common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"os"
  6	"strconv"
  7	"strings"
  8	"testing"
  9
 10	"github.com/charmbracelet/fantasy/ai"
 11	"github.com/joho/godotenv"
 12	"github.com/stretchr/testify/require"
 13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 14)
 15
 16func init() {
 17	if _, err := os.Stat(".env"); err == nil {
 18		godotenv.Load(".env")
 19	} else {
 20		godotenv.Load(".env.sample")
 21	}
 22}
 23
 24type testModel struct {
 25	name      string
 26	model     string
 27	reasoning bool
 28}
 29
 30type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 31
 32type builderPair struct {
 33	name            string
 34	builder         builderFunc
 35	providerOptions ai.ProviderOptions
 36}
 37
 38func testCommon(t *testing.T, pairs []builderPair) {
 39	for _, pair := range pairs {
 40		t.Run(pair.name, func(t *testing.T) {
 41			testSimple(t, pair)
 42			testTool(t, pair)
 43			testMultiTool(t, pair)
 44		})
 45	}
 46}
 47
 48func testSimple(t *testing.T, pair builderPair) {
 49	checkResult := func(t *testing.T, result *ai.AgentResult) {
 50		options := []string{"Oi", "oi", "Olá", "olá"}
 51		got := result.Response.Content.Text()
 52		require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
 53	}
 54
 55	t.Run("simple", func(t *testing.T) {
 56		r := newRecorder(t)
 57
 58		languageModel, err := pair.builder(r)
 59		require.NoError(t, err, "failed to build language model")
 60
 61		agent := ai.NewAgent(
 62			languageModel,
 63			ai.WithSystemPrompt("You are a helpful assistant"),
 64		)
 65		result, err := agent.Generate(t.Context(), ai.AgentCall{
 66			Prompt:          "Say hi in Portuguese",
 67			ProviderOptions: pair.providerOptions,
 68			MaxOutputTokens: ai.IntOption(4000),
 69		})
 70		require.NoError(t, err, "failed to generate")
 71		checkResult(t, result)
 72	})
 73	t.Run("simple streaming", func(t *testing.T) {
 74		r := newRecorder(t)
 75
 76		languageModel, err := pair.builder(r)
 77		require.NoError(t, err, "failed to build language model")
 78
 79		agent := ai.NewAgent(
 80			languageModel,
 81			ai.WithSystemPrompt("You are a helpful assistant"),
 82		)
 83		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 84			Prompt:          "Say hi in Portuguese",
 85			ProviderOptions: pair.providerOptions,
 86			MaxOutputTokens: ai.IntOption(4000),
 87		})
 88		require.NoError(t, err, "failed to generate")
 89		checkResult(t, result)
 90	})
 91}
 92
 93func testTool(t *testing.T, pair builderPair) {
 94	type WeatherInput struct {
 95		Location string `json:"location" description:"the city"`
 96	}
 97
 98	weatherTool := ai.NewAgentTool(
 99		"weather",
100		"Get weather information for a location",
101		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
102			return ai.NewTextResponse("40 C"), nil
103		},
104	)
105	checkResult := func(t *testing.T, result *ai.AgentResult) {
106		require.GreaterOrEqual(t, len(result.Steps), 2)
107
108		var toolCalls []ai.ToolCallContent
109		for _, content := range result.Steps[0].Content {
110			if content.GetType() == ai.ContentTypeToolCall {
111				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
112			}
113		}
114		for _, tc := range toolCalls {
115			require.False(t, tc.Invalid)
116		}
117		require.Len(t, toolCalls, 1)
118		require.Equal(t, toolCalls[0].ToolName, "weather")
119
120		want1 := "Florence"
121		want2 := "40"
122		got := result.Response.Content.Text()
123		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
124	}
125
126	t.Run("tool", func(t *testing.T) {
127		r := newRecorder(t)
128
129		languageModel, err := pair.builder(r)
130		require.NoError(t, err, "failed to build language model")
131
132		agent := ai.NewAgent(
133			languageModel,
134			ai.WithSystemPrompt("You are a helpful assistant"),
135			ai.WithTools(weatherTool),
136		)
137		result, err := agent.Generate(t.Context(), ai.AgentCall{
138			Prompt:          "What's the weather in Florence,Italy?",
139			ProviderOptions: pair.providerOptions,
140			MaxOutputTokens: ai.IntOption(4000),
141		})
142		require.NoError(t, err, "failed to generate")
143		checkResult(t, result)
144	})
145	t.Run("tool streaming", func(t *testing.T) {
146		r := newRecorder(t)
147
148		languageModel, err := pair.builder(r)
149		require.NoError(t, err, "failed to build language model")
150
151		agent := ai.NewAgent(
152			languageModel,
153			ai.WithSystemPrompt("You are a helpful assistant"),
154			ai.WithTools(weatherTool),
155		)
156		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
157			Prompt:          "What's the weather in Florence,Italy?",
158			ProviderOptions: pair.providerOptions,
159			MaxOutputTokens: ai.IntOption(4000),
160		})
161		require.NoError(t, err, "failed to generate")
162		checkResult(t, result)
163	})
164}
165
166func testMultiTool(t *testing.T, pair builderPair) {
167	// Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
168	if strings.Contains(pair.name, "azure") {
169		t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
170	}
171	if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
172		t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
173	}
174	if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
175		t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
176	}
177
178	type CalculatorInput struct {
179		A int `json:"a" description:"first number"`
180		B int `json:"b" description:"second number"`
181	}
182
183	addTool := ai.NewAgentTool(
184		"add",
185		"Add two numbers",
186		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
187			result := input.A + input.B
188			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
189		},
190	)
191	multiplyTool := ai.NewAgentTool(
192		"multiply",
193		"Multiply two numbers",
194		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
195			result := input.A * input.B
196			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
197		},
198	)
199	checkResult := func(t *testing.T, result *ai.AgentResult) {
200		require.Len(t, result.Steps, 2)
201
202		var toolCalls []ai.ToolCallContent
203		for _, content := range result.Steps[0].Content {
204			if content.GetType() == ai.ContentTypeToolCall {
205				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
206			}
207		}
208		for _, tc := range toolCalls {
209			require.False(t, tc.Invalid)
210		}
211		require.Len(t, toolCalls, 2)
212
213		finalText := result.Response.Content.Text()
214		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
215		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
216	}
217
218	t.Run("multi tool", func(t *testing.T) {
219		r := newRecorder(t)
220
221		languageModel, err := pair.builder(r)
222		require.NoError(t, err, "failed to build language model")
223
224		agent := ai.NewAgent(
225			languageModel,
226			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
227			ai.WithTools(addTool),
228			ai.WithTools(multiplyTool),
229		)
230		result, err := agent.Generate(t.Context(), ai.AgentCall{
231			Prompt:          "Add and multiply the number 2 and 3",
232			ProviderOptions: pair.providerOptions,
233			MaxOutputTokens: ai.IntOption(4000),
234		})
235		require.NoError(t, err, "failed to generate")
236		checkResult(t, result)
237	})
238	t.Run("multi tool streaming", func(t *testing.T) {
239		r := newRecorder(t)
240
241		languageModel, err := pair.builder(r)
242		require.NoError(t, err, "failed to build language model")
243
244		agent := ai.NewAgent(
245			languageModel,
246			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
247			ai.WithTools(addTool),
248			ai.WithTools(multiplyTool),
249		)
250		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
251			Prompt:          "Add and multiply the number 2 and 3",
252			ProviderOptions: pair.providerOptions,
253			MaxOutputTokens: ai.IntOption(4000),
254		})
255		require.NoError(t, err, "failed to generate")
256		checkResult(t, result)
257	})
258}
259
260func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
261	for _, pair := range pairs {
262		t.Run(pair.name, func(t *testing.T) {
263			t.Run("thinking", func(t *testing.T) {
264				r := newRecorder(t)
265
266				languageModel, err := pair.builder(r)
267				require.NoError(t, err, "failed to build language model")
268
269				type WeatherInput struct {
270					Location string `json:"location" description:"the city"`
271				}
272
273				weatherTool := ai.NewAgentTool(
274					"weather",
275					"Get weather information for a location",
276					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
277						return ai.NewTextResponse("40 C"), nil
278					},
279				)
280
281				agent := ai.NewAgent(
282					languageModel,
283					ai.WithSystemPrompt("You are a helpful assistant"),
284					ai.WithTools(weatherTool),
285				)
286				result, err := agent.Generate(t.Context(), ai.AgentCall{
287					Prompt:          "What's the weather in Florence, Italy?",
288					ProviderOptions: pair.providerOptions,
289				})
290				require.NoError(t, err, "failed to generate")
291
292				want1 := "Florence"
293				want2 := "40"
294				got := result.Response.Content.Text()
295				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
296
297				thinkChecks(t, result)
298			})
299			t.Run("thinking-streaming", func(t *testing.T) {
300				r := newRecorder(t)
301
302				languageModel, err := pair.builder(r)
303				require.NoError(t, err, "failed to build language model")
304
305				type WeatherInput struct {
306					Location string `json:"location" description:"the city"`
307				}
308
309				weatherTool := ai.NewAgentTool(
310					"weather",
311					"Get weather information for a location",
312					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
313						return ai.NewTextResponse("40 C"), nil
314					},
315				)
316
317				agent := ai.NewAgent(
318					languageModel,
319					ai.WithSystemPrompt("You are a helpful assistant"),
320					ai.WithTools(weatherTool),
321				)
322				result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
323					Prompt:          "What's the weather in Florence, Italy?",
324					ProviderOptions: pair.providerOptions,
325				})
326				require.NoError(t, err, "failed to generate")
327
328				want1 := "Florence"
329				want2 := "40"
330				got := result.Response.Content.Text()
331				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
332
333				thinkChecks(t, result)
334			})
335		})
336	}
337}
338
// containsAny reports whether s contains at least one of subs. With no
// candidates it returns false.
func containsAny(s string, subs ...string) bool {
	found := false
	// Stop scanning as soon as one candidate matches.
	for i := 0; i < len(subs) && !found; i++ {
		found = strings.Contains(s, subs[i])
	}
	return found
}