common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"os"
  6	"strconv"
  7	"strings"
  8	"testing"
  9
 10	"github.com/charmbracelet/fantasy/ai"
 11	"github.com/joho/godotenv"
 12	"github.com/stretchr/testify/require"
 13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 14)
 15
 16func init() {
 17	if _, err := os.Stat(".env"); err == nil {
 18		godotenv.Load(".env")
 19	} else {
 20		godotenv.Load(".env.sample")
 21	}
 22}
 23
 24type testModel struct {
 25	name      string
 26	model     string
 27	reasoning bool
 28}
 29
// builderFunc constructs the ai.LanguageModel under test from a go-vcr
// recorder. The shared tests call it once per subtest with a recorder obtained
// from newRecorder. How the recorder is wired into the model is up to each
// implementation (presumably as the HTTP transport — TODO confirm).
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 31
// builderPair bundles what the shared tests need for one model configuration:
// a name, a function that builds the model, and the provider options that are
// forwarded with every agent call.
type builderPair struct {
	name            string             // subtest name; also substring-matched by testMultiTool to skip
	builder         builderFunc        // invoked once per subtest with a fresh recorder
	providerOptions ai.ProviderOptions // passed as ProviderOptions on each Generate/Stream call
}
 37
 38func testCommon(t *testing.T, pairs []builderPair) {
 39	for _, pair := range pairs {
 40		t.Run(pair.name, func(t *testing.T) {
 41			testSimple(t, pair)
 42			testTool(t, pair)
 43			testMultiTool(t, pair)
 44		})
 45	}
 46}
 47
 48func testSimple(t *testing.T, pair builderPair) {
 49	checkResult := func(t *testing.T, result *ai.AgentResult) {
 50		options := []string{"Oi", "oi", "Olá", "olá"}
 51		got := result.Response.Content.Text()
 52		require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
 53	}
 54
 55	t.Run("simple", func(t *testing.T) {
 56		r := newRecorder(t)
 57
 58		languageModel, err := pair.builder(r)
 59		require.NoError(t, err, "failed to build language model")
 60
 61		agent := ai.NewAgent(
 62			languageModel,
 63			ai.WithSystemPrompt("You are a helpful assistant"),
 64		)
 65		result, err := agent.Generate(t.Context(), ai.AgentCall{
 66			Prompt:          "Say hi in Portuguese",
 67			ProviderOptions: pair.providerOptions,
 68			MaxOutputTokens: ai.Opt(int64(4000)),
 69		})
 70		require.NoError(t, err, "failed to generate")
 71		checkResult(t, result)
 72	})
 73	t.Run("simple streaming", func(t *testing.T) {
 74		r := newRecorder(t)
 75
 76		languageModel, err := pair.builder(r)
 77		require.NoError(t, err, "failed to build language model")
 78
 79		agent := ai.NewAgent(
 80			languageModel,
 81			ai.WithSystemPrompt("You are a helpful assistant"),
 82		)
 83		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 84			Prompt:          "Say hi in Portuguese",
 85			ProviderOptions: pair.providerOptions,
 86			MaxOutputTokens: ai.Opt(int64(4000)),
 87		})
 88		require.NoError(t, err, "failed to generate")
 89		checkResult(t, result)
 90	})
 91}
 92
 93func testTool(t *testing.T, pair builderPair) {
 94	type WeatherInput struct {
 95		Location string `json:"location" description:"the city"`
 96	}
 97
 98	weatherTool := ai.NewAgentTool(
 99		"weather",
100		"Get weather information for a location",
101		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
102			return ai.NewTextResponse("40 C"), nil
103		},
104	)
105	checkResult := func(t *testing.T, result *ai.AgentResult) {
106		require.GreaterOrEqual(t, len(result.Steps), 2)
107
108		var toolCalls []ai.ToolCallContent
109		for _, content := range result.Steps[0].Content {
110			if content.GetType() == ai.ContentTypeToolCall {
111				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
112			}
113		}
114		for _, tc := range toolCalls {
115			require.False(t, tc.Invalid)
116		}
117		require.Len(t, toolCalls, 1)
118		require.Equal(t, toolCalls[0].ToolName, "weather")
119
120		want1 := "Florence"
121		want2 := "40"
122		got := result.Response.Content.Text()
123		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
124	}
125
126	t.Run("tool", func(t *testing.T) {
127		r := newRecorder(t)
128
129		languageModel, err := pair.builder(r)
130		require.NoError(t, err, "failed to build language model")
131
132		agent := ai.NewAgent(
133			languageModel,
134			ai.WithSystemPrompt("You are a helpful assistant"),
135			ai.WithTools(weatherTool),
136		)
137		result, err := agent.Generate(t.Context(), ai.AgentCall{
138			Prompt:          "What's the weather in Florence,Italy?",
139			ProviderOptions: pair.providerOptions,
140			MaxOutputTokens: ai.Opt(int64(4000)),
141		})
142		require.NoError(t, err, "failed to generate")
143		checkResult(t, result)
144	})
145	t.Run("tool streaming", func(t *testing.T) {
146		r := newRecorder(t)
147
148		languageModel, err := pair.builder(r)
149		require.NoError(t, err, "failed to build language model")
150
151		agent := ai.NewAgent(
152			languageModel,
153			ai.WithSystemPrompt("You are a helpful assistant"),
154			ai.WithTools(weatherTool),
155		)
156		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
157			Prompt:          "What's the weather in Florence,Italy?",
158			ProviderOptions: pair.providerOptions,
159			MaxOutputTokens: ai.Opt(int64(4000)),
160		})
161		require.NoError(t, err, "failed to generate")
162		checkResult(t, result)
163	})
164}
165
166func testMultiTool(t *testing.T, pair builderPair) {
167	// Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
168	if strings.Contains(pair.name, "azure") {
169		t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
170	}
171	if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
172		t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
173	}
174	if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
175		t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
176	}
177	if strings.Contains(pair.name, "openai") && strings.Contains(pair.name, "o4-mini") {
178		t.Skip("skipping multi-tool tests for openai o4-mini it for some reason is not doing parallel tool calls even if asked")
179	}
180
181	type CalculatorInput struct {
182		A int `json:"a" description:"first number"`
183		B int `json:"b" description:"second number"`
184	}
185
186	addTool := ai.NewAgentTool(
187		"add",
188		"Add two numbers",
189		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
190			result := input.A + input.B
191			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
192		},
193	)
194	multiplyTool := ai.NewAgentTool(
195		"multiply",
196		"Multiply two numbers",
197		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
198			result := input.A * input.B
199			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
200		},
201	)
202	checkResult := func(t *testing.T, result *ai.AgentResult) {
203		require.Len(t, result.Steps, 2)
204
205		var toolCalls []ai.ToolCallContent
206		for _, content := range result.Steps[0].Content {
207			if content.GetType() == ai.ContentTypeToolCall {
208				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
209			}
210		}
211		for _, tc := range toolCalls {
212			require.False(t, tc.Invalid)
213		}
214		require.Len(t, toolCalls, 2)
215
216		finalText := result.Response.Content.Text()
217		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
218		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
219	}
220
221	t.Run("multi tool", func(t *testing.T) {
222		r := newRecorder(t)
223
224		languageModel, err := pair.builder(r)
225		require.NoError(t, err, "failed to build language model")
226
227		agent := ai.NewAgent(
228			languageModel,
229			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
230			ai.WithTools(addTool),
231			ai.WithTools(multiplyTool),
232		)
233		result, err := agent.Generate(t.Context(), ai.AgentCall{
234			Prompt:          "Add and multiply the number 2 and 3",
235			ProviderOptions: pair.providerOptions,
236			MaxOutputTokens: ai.Opt(int64(4000)),
237		})
238		require.NoError(t, err, "failed to generate")
239		checkResult(t, result)
240	})
241	t.Run("multi tool streaming", func(t *testing.T) {
242		r := newRecorder(t)
243
244		languageModel, err := pair.builder(r)
245		require.NoError(t, err, "failed to build language model")
246
247		agent := ai.NewAgent(
248			languageModel,
249			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
250			ai.WithTools(addTool),
251			ai.WithTools(multiplyTool),
252		)
253		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
254			Prompt:          "Add and multiply the number 2 and 3",
255			ProviderOptions: pair.providerOptions,
256			MaxOutputTokens: ai.Opt(int64(4000)),
257		})
258		require.NoError(t, err, "failed to generate")
259		checkResult(t, result)
260	})
261}
262
263func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
264	for _, pair := range pairs {
265		t.Run(pair.name, func(t *testing.T) {
266			t.Run("thinking", func(t *testing.T) {
267				r := newRecorder(t)
268
269				languageModel, err := pair.builder(r)
270				require.NoError(t, err, "failed to build language model")
271
272				type WeatherInput struct {
273					Location string `json:"location" description:"the city"`
274				}
275
276				weatherTool := ai.NewAgentTool(
277					"weather",
278					"Get weather information for a location",
279					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
280						return ai.NewTextResponse("40 C"), nil
281					},
282				)
283
284				agent := ai.NewAgent(
285					languageModel,
286					ai.WithSystemPrompt("You are a helpful assistant"),
287					ai.WithTools(weatherTool),
288				)
289				result, err := agent.Generate(t.Context(), ai.AgentCall{
290					Prompt:          "What's the weather in Florence, Italy?",
291					ProviderOptions: pair.providerOptions,
292				})
293				require.NoError(t, err, "failed to generate")
294
295				want1 := "Florence"
296				want2 := "40"
297				got := result.Response.Content.Text()
298				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
299
300				thinkChecks(t, result)
301			})
302			t.Run("thinking-streaming", func(t *testing.T) {
303				r := newRecorder(t)
304
305				languageModel, err := pair.builder(r)
306				require.NoError(t, err, "failed to build language model")
307
308				type WeatherInput struct {
309					Location string `json:"location" description:"the city"`
310				}
311
312				weatherTool := ai.NewAgentTool(
313					"weather",
314					"Get weather information for a location",
315					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
316						return ai.NewTextResponse("40 C"), nil
317					},
318				)
319
320				agent := ai.NewAgent(
321					languageModel,
322					ai.WithSystemPrompt("You are a helpful assistant"),
323					ai.WithTools(weatherTool),
324				)
325				result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
326					Prompt:          "What's the weather in Florence, Italy?",
327					ProviderOptions: pair.providerOptions,
328				})
329				require.NoError(t, err, "failed to generate")
330
331				want1 := "Florence"
332				want2 := "40"
333				got := result.Response.Content.Text()
334				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
335
336				thinkChecks(t, result)
337			})
338		})
339	}
340}
341
// containsAny reports whether s contains at least one of subs as a substring.
// It returns false when no candidates are given; an empty candidate always
// matches.
func containsAny(s string, subs ...string) bool {
	found := false
	for i := 0; i < len(subs) && !found; i++ {
		found = strings.Contains(s, subs[i])
	}
	return found
}