common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"os"
  6	"strconv"
  7	"strings"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"charm.land/x/vcr"
 12	"github.com/joho/godotenv"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16func init() {
 17	if _, err := os.Stat(".env"); err == nil {
 18		godotenv.Load(".env")
 19	} else {
 20		godotenv.Load(".env.sample")
 21	}
 22}
 23
 24type testModel struct {
 25	name      string
 26	model     string
 27	reasoning bool
 28}
 29
// builderFunc builds the language model a test runs against. The helpers in
// this file call it with the current (sub)test and the VCR recorder created
// for that same test; it returns the model, or an error if the model could
// not be built.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
 31
// builderPair bundles everything the shared test helpers need to exercise
// one provider/model combination.
type builderPair struct {
	// name is used as the subtest name and is also matched against
	// provider/model substrings to skip unsupported scenarios.
	name string
	// builder constructs the language model for each subtest.
	builder builderFunc
	// providerOptions is forwarded unchanged as ProviderOptions on every
	// agent call made by the helpers.
	providerOptions fantasy.ProviderOptions
	// prepareStep is forwarded unchanged as PrepareStep on every agent call
	// made by the helpers.
	prepareStep fantasy.PrepareStepFunction
}
 38
 39func testCommon(t *testing.T, pairs []builderPair) {
 40	for _, pair := range pairs {
 41		t.Run(pair.name, func(t *testing.T) {
 42			testSimple(t, pair)
 43			testTool(t, pair)
 44			testMultiTool(t, pair)
 45		})
 46	}
 47}
 48
 49func testSimple(t *testing.T, pair builderPair) {
 50	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
 51		options := []string{"Oi", "oi", "Olá", "olá"}
 52		got := result.Response.Content.Text()
 53		require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
 54	}
 55
 56	t.Run("simple", func(t *testing.T) {
 57		if strings.HasPrefix(pair.name, "avian-") {
 58			t.Skip("Avian only support streaming")
 59		}
 60
 61		r := vcr.NewRecorder(t)
 62
 63		languageModel, err := pair.builder(t, r)
 64		require.NoError(t, err, "failed to build language model")
 65
 66		agent := fantasy.NewAgent(
 67			languageModel,
 68			fantasy.WithSystemPrompt("You are a helpful assistant"),
 69		)
 70		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
 71			Prompt:          "Say hi in Portuguese",
 72			ProviderOptions: pair.providerOptions,
 73			MaxOutputTokens: fantasy.Opt(int64(4000)),
 74			PrepareStep:     pair.prepareStep,
 75		})
 76		require.NoError(t, err, "failed to generate")
 77		checkResult(t, result)
 78	})
 79
 80	t.Run("simple streaming", func(t *testing.T) {
 81		r := vcr.NewRecorder(t)
 82
 83		languageModel, err := pair.builder(t, r)
 84		require.NoError(t, err, "failed to build language model")
 85
 86		agent := fantasy.NewAgent(
 87			languageModel,
 88			fantasy.WithSystemPrompt("You are a helpful assistant"),
 89		)
 90		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
 91			Prompt:          "Say hi in Portuguese",
 92			ProviderOptions: pair.providerOptions,
 93			MaxOutputTokens: fantasy.Opt(int64(4000)),
 94			PrepareStep:     pair.prepareStep,
 95		})
 96		require.NoError(t, err, "failed to generate")
 97		checkResult(t, result)
 98	})
 99}
100
101func testTool(t *testing.T, pair builderPair) {
102	type WeatherInput struct {
103		Location string `json:"location" description:"the city"`
104	}
105
106	weatherTool := fantasy.NewAgentTool(
107		"weather",
108		"Get weather information for a location",
109		func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
110			return fantasy.NewTextResponse("40 C"), nil
111		},
112	)
113	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
114		require.GreaterOrEqual(t, len(result.Steps), 2)
115
116		var toolCalls []fantasy.ToolCallContent
117		for _, content := range result.Steps[0].Content {
118			if content.GetType() == fantasy.ContentTypeToolCall {
119				toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
120			}
121		}
122		for _, tc := range toolCalls {
123			require.False(t, tc.Invalid)
124		}
125		require.Len(t, toolCalls, 1)
126		require.Equal(t, toolCalls[0].ToolName, "weather")
127
128		want1 := "Florence"
129		want2 := "40"
130		got := result.Response.Content.Text()
131		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
132	}
133
134	t.Run("tool", func(t *testing.T) {
135		if strings.HasPrefix(pair.name, "avian-") {
136			t.Skip("Avian only support streaming")
137		}
138
139		r := vcr.NewRecorder(t)
140
141		languageModel, err := pair.builder(t, r)
142		require.NoError(t, err, "failed to build language model")
143
144		agent := fantasy.NewAgent(
145			languageModel,
146			fantasy.WithSystemPrompt("You are a helpful assistant"),
147			fantasy.WithTools(weatherTool),
148		)
149		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
150			Prompt:          "What's the weather in Florence,Italy?",
151			ProviderOptions: pair.providerOptions,
152			MaxOutputTokens: fantasy.Opt(int64(4000)),
153			PrepareStep:     pair.prepareStep,
154		})
155		require.NoError(t, err, "failed to generate")
156		checkResult(t, result)
157	})
158
159	t.Run("tool streaming", func(t *testing.T) {
160		r := vcr.NewRecorder(t)
161
162		languageModel, err := pair.builder(t, r)
163		require.NoError(t, err, "failed to build language model")
164
165		agent := fantasy.NewAgent(
166			languageModel,
167			fantasy.WithSystemPrompt("You are a helpful assistant"),
168			fantasy.WithTools(weatherTool),
169		)
170		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
171			Prompt:          "What's the weather in Florence,Italy?",
172			ProviderOptions: pair.providerOptions,
173			MaxOutputTokens: fantasy.Opt(int64(4000)),
174			PrepareStep:     pair.prepareStep,
175		})
176		require.NoError(t, err, "failed to generate")
177		checkResult(t, result)
178	})
179}
180
181func testMultiTool(t *testing.T, pair builderPair) {
182	// Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
183	if strings.Contains(pair.name, "azure") {
184		t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
185	}
186	if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
187		t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
188	}
189	if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
190		t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
191	}
192	if strings.Contains(pair.name, "openai") && strings.Contains(pair.name, "o4-mini") {
193		t.Skip("skipping multi-tool tests for openai o4-mini it for some reason is not doing parallel tool calls even if asked")
194	}
195	if strings.Contains(pair.name, "llama-cpp") && strings.Contains(pair.name, "gpt-oss") {
196		t.Skip("skipping multi-tool tests for llama-cpp gpt-oss as it does not support parallel multi-tool calls")
197	}
198
199	type CalculatorInput struct {
200		A int `json:"a" description:"first number"`
201		B int `json:"b" description:"second number"`
202	}
203
204	addTool := fantasy.NewAgentTool(
205		"add",
206		"Add two numbers",
207		func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
208			result := input.A + input.B
209			return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
210		},
211	)
212	multiplyTool := fantasy.NewAgentTool(
213		"multiply",
214		"Multiply two numbers",
215		func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
216			result := input.A * input.B
217			return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
218		},
219	)
220	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
221		require.Len(t, result.Steps, 2)
222
223		var toolCalls []fantasy.ToolCallContent
224		for _, content := range result.Steps[0].Content {
225			if content.GetType() == fantasy.ContentTypeToolCall {
226				toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
227			}
228		}
229		for _, tc := range toolCalls {
230			require.False(t, tc.Invalid)
231		}
232		require.Len(t, toolCalls, 2)
233
234		finalText := result.Response.Content.Text()
235		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
236		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
237	}
238
239	t.Run("multi tool", func(t *testing.T) {
240		if strings.HasPrefix(pair.name, "avian-") {
241			t.Skip("Avian only support streaming")
242		}
243
244		r := vcr.NewRecorder(t)
245
246		languageModel, err := pair.builder(t, r)
247		require.NoError(t, err, "failed to build language model")
248
249		agent := fantasy.NewAgent(
250			languageModel,
251			fantasy.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
252			fantasy.WithTools(addTool),
253			fantasy.WithTools(multiplyTool),
254		)
255		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
256			Prompt:          "Add and multiply the number 2 and 3",
257			ProviderOptions: pair.providerOptions,
258			MaxOutputTokens: fantasy.Opt(int64(4000)),
259			PrepareStep:     pair.prepareStep,
260		})
261		require.NoError(t, err, "failed to generate")
262		checkResult(t, result)
263	})
264
265	t.Run("multi tool streaming", func(t *testing.T) {
266		r := vcr.NewRecorder(t)
267
268		languageModel, err := pair.builder(t, r)
269		require.NoError(t, err, "failed to build language model")
270
271		agent := fantasy.NewAgent(
272			languageModel,
273			fantasy.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
274			fantasy.WithTools(addTool),
275			fantasy.WithTools(multiplyTool),
276		)
277		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
278			Prompt:          "Add and multiply the number 2 and 3",
279			ProviderOptions: pair.providerOptions,
280			MaxOutputTokens: fantasy.Opt(int64(4000)),
281			PrepareStep:     pair.prepareStep,
282		})
283		require.NoError(t, err, "failed to generate")
284		checkResult(t, result)
285	})
286}
287
288func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *fantasy.AgentResult)) {
289	for _, pair := range pairs {
290		t.Run(pair.name, func(t *testing.T) {
291			t.Run("thinking", func(t *testing.T) {
292				if strings.HasPrefix(pair.name, "avian-") {
293					t.Skip("Avian only support streaming")
294				}
295
296				r := vcr.NewRecorder(t)
297
298				languageModel, err := pair.builder(t, r)
299				require.NoError(t, err, "failed to build language model")
300
301				type WeatherInput struct {
302					Location string `json:"location" description:"the city"`
303				}
304
305				weatherTool := fantasy.NewAgentTool(
306					"weather",
307					"Get weather information for a location",
308					func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
309						return fantasy.NewTextResponse("40 C"), nil
310					},
311				)
312
313				agent := fantasy.NewAgent(
314					languageModel,
315					fantasy.WithSystemPrompt("You are a helpful assistant"),
316					fantasy.WithTools(weatherTool),
317				)
318				result, err := agent.Generate(t.Context(), fantasy.AgentCall{
319					Prompt:          "What's the weather in Florence, Italy?",
320					ProviderOptions: pair.providerOptions,
321					PrepareStep:     pair.prepareStep,
322				})
323				require.NoError(t, err, "failed to generate")
324
325				want1 := "Florence"
326				want2 := "40"
327				got := result.Response.Content.Text()
328				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
329
330				thinkChecks(t, result)
331			})
332
333			t.Run("thinking-streaming", func(t *testing.T) {
334				r := vcr.NewRecorder(t)
335
336				languageModel, err := pair.builder(t, r)
337				require.NoError(t, err, "failed to build language model")
338
339				type WeatherInput struct {
340					Location string `json:"location" description:"the city"`
341				}
342
343				weatherTool := fantasy.NewAgentTool(
344					"weather",
345					"Get weather information for a location",
346					func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
347						return fantasy.NewTextResponse("40 C"), nil
348					},
349				)
350
351				agent := fantasy.NewAgent(
352					languageModel,
353					fantasy.WithSystemPrompt("You are a helpful assistant"),
354					fantasy.WithTools(weatherTool),
355				)
356				result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
357					Prompt:          "What's the weather in Florence, Italy?",
358					ProviderOptions: pair.providerOptions,
359					PrepareStep:     pair.prepareStep,
360				})
361				require.NoError(t, err, "failed to generate")
362
363				want1 := "Florence"
364				want2 := "40"
365				got := result.Response.Content.Text()
366				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
367
368				thinkChecks(t, result)
369			})
370		})
371	}
372}
373
// containsAny reports whether s contains at least one of subs as a
// substring. It returns false when subs is empty; an empty string in subs
// matches any s.
func containsAny(s string, subs ...string) bool {
	for i := range subs {
		if !strings.Contains(s, subs[i]) {
			continue
		}
		return true
	}
	return false
}