common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"os"
  6	"strconv"
  7	"strings"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"github.com/joho/godotenv"
 12	"github.com/stretchr/testify/require"
 13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 14)
 15
 16func init() {
 17	if _, err := os.Stat(".env"); err == nil {
 18		godotenv.Load(".env")
 19	} else {
 20		godotenv.Load(".env.sample")
 21	}
 22}
 23
 24type testModel struct {
 25	name      string
 26	model     string
 27	reasoning bool
 28}
 29
 30type builderFunc func(r *recorder.Recorder) (fantasy.LanguageModel, error)
 31
 32type builderPair struct {
 33	name            string
 34	builder         builderFunc
 35	providerOptions fantasy.ProviderOptions
 36	prepareStep     fantasy.PrepareStepFunction
 37}
 38
 39func testCommon(t *testing.T, pairs []builderPair) {
 40	for _, pair := range pairs {
 41		t.Run(pair.name, func(t *testing.T) {
 42			testSimple(t, pair)
 43			testTool(t, pair)
 44			testMultiTool(t, pair)
 45		})
 46	}
 47}
 48
 49func testSimple(t *testing.T, pair builderPair) {
 50	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
 51		options := []string{"Oi", "oi", "Olá", "olá"}
 52		got := result.Response.Content.Text()
 53		require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
 54	}
 55
 56	t.Run("simple", func(t *testing.T) {
 57		r := newRecorder(t)
 58
 59		languageModel, err := pair.builder(r)
 60		require.NoError(t, err, "failed to build language model")
 61
 62		agent := fantasy.NewAgent(
 63			languageModel,
 64			fantasy.WithSystemPrompt("You are a helpful assistant"),
 65		)
 66		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
 67			Prompt:          "Say hi in Portuguese",
 68			ProviderOptions: pair.providerOptions,
 69			MaxOutputTokens: fantasy.Opt(int64(4000)),
 70			PrepareStep:     pair.prepareStep,
 71		})
 72		require.NoError(t, err, "failed to generate")
 73		checkResult(t, result)
 74	})
 75	t.Run("simple streaming", func(t *testing.T) {
 76		r := newRecorder(t)
 77
 78		languageModel, err := pair.builder(r)
 79		require.NoError(t, err, "failed to build language model")
 80
 81		agent := fantasy.NewAgent(
 82			languageModel,
 83			fantasy.WithSystemPrompt("You are a helpful assistant"),
 84		)
 85		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
 86			Prompt:          "Say hi in Portuguese",
 87			ProviderOptions: pair.providerOptions,
 88			MaxOutputTokens: fantasy.Opt(int64(4000)),
 89			PrepareStep:     pair.prepareStep,
 90		})
 91		require.NoError(t, err, "failed to generate")
 92		checkResult(t, result)
 93	})
 94}
 95
 96func testTool(t *testing.T, pair builderPair) {
 97	type WeatherInput struct {
 98		Location string `json:"location" description:"the city"`
 99	}
100
101	weatherTool := fantasy.NewAgentTool(
102		"weather",
103		"Get weather information for a location",
104		func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
105			return fantasy.NewTextResponse("40 C"), nil
106		},
107	)
108	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
109		require.GreaterOrEqual(t, len(result.Steps), 2)
110
111		var toolCalls []fantasy.ToolCallContent
112		for _, content := range result.Steps[0].Content {
113			if content.GetType() == fantasy.ContentTypeToolCall {
114				toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
115			}
116		}
117		for _, tc := range toolCalls {
118			require.False(t, tc.Invalid)
119		}
120		require.Len(t, toolCalls, 1)
121		require.Equal(t, toolCalls[0].ToolName, "weather")
122
123		want1 := "Florence"
124		want2 := "40"
125		got := result.Response.Content.Text()
126		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
127	}
128
129	t.Run("tool", func(t *testing.T) {
130		r := newRecorder(t)
131
132		languageModel, err := pair.builder(r)
133		require.NoError(t, err, "failed to build language model")
134
135		agent := fantasy.NewAgent(
136			languageModel,
137			fantasy.WithSystemPrompt("You are a helpful assistant"),
138			fantasy.WithTools(weatherTool),
139		)
140		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
141			Prompt:          "What's the weather in Florence,Italy?",
142			ProviderOptions: pair.providerOptions,
143			MaxOutputTokens: fantasy.Opt(int64(4000)),
144			PrepareStep:     pair.prepareStep,
145		})
146		require.NoError(t, err, "failed to generate")
147		checkResult(t, result)
148	})
149	t.Run("tool streaming", func(t *testing.T) {
150		r := newRecorder(t)
151
152		languageModel, err := pair.builder(r)
153		require.NoError(t, err, "failed to build language model")
154
155		agent := fantasy.NewAgent(
156			languageModel,
157			fantasy.WithSystemPrompt("You are a helpful assistant"),
158			fantasy.WithTools(weatherTool),
159		)
160		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
161			Prompt:          "What's the weather in Florence,Italy?",
162			ProviderOptions: pair.providerOptions,
163			MaxOutputTokens: fantasy.Opt(int64(4000)),
164			PrepareStep:     pair.prepareStep,
165		})
166		require.NoError(t, err, "failed to generate")
167		checkResult(t, result)
168	})
169}
170
171func testMultiTool(t *testing.T, pair builderPair) {
172	// Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
173	if strings.Contains(pair.name, "azure") {
174		t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
175	}
176	if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
177		t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
178	}
179	if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
180		t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
181	}
182	if strings.Contains(pair.name, "openai") && strings.Contains(pair.name, "o4-mini") {
183		t.Skip("skipping multi-tool tests for openai o4-mini it for some reason is not doing parallel tool calls even if asked")
184	}
185
186	type CalculatorInput struct {
187		A int `json:"a" description:"first number"`
188		B int `json:"b" description:"second number"`
189	}
190
191	addTool := fantasy.NewAgentTool(
192		"add",
193		"Add two numbers",
194		func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
195			result := input.A + input.B
196			return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
197		},
198	)
199	multiplyTool := fantasy.NewAgentTool(
200		"multiply",
201		"Multiply two numbers",
202		func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
203			result := input.A * input.B
204			return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
205		},
206	)
207	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
208		require.Len(t, result.Steps, 2)
209
210		var toolCalls []fantasy.ToolCallContent
211		for _, content := range result.Steps[0].Content {
212			if content.GetType() == fantasy.ContentTypeToolCall {
213				toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
214			}
215		}
216		for _, tc := range toolCalls {
217			require.False(t, tc.Invalid)
218		}
219		require.Len(t, toolCalls, 2)
220
221		finalText := result.Response.Content.Text()
222		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
223		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
224	}
225
226	t.Run("multi tool", func(t *testing.T) {
227		r := newRecorder(t)
228
229		languageModel, err := pair.builder(r)
230		require.NoError(t, err, "failed to build language model")
231
232		agent := fantasy.NewAgent(
233			languageModel,
234			fantasy.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
235			fantasy.WithTools(addTool),
236			fantasy.WithTools(multiplyTool),
237		)
238		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
239			Prompt:          "Add and multiply the number 2 and 3",
240			ProviderOptions: pair.providerOptions,
241			MaxOutputTokens: fantasy.Opt(int64(4000)),
242			PrepareStep:     pair.prepareStep,
243		})
244		require.NoError(t, err, "failed to generate")
245		checkResult(t, result)
246	})
247	t.Run("multi tool streaming", func(t *testing.T) {
248		r := newRecorder(t)
249
250		languageModel, err := pair.builder(r)
251		require.NoError(t, err, "failed to build language model")
252
253		agent := fantasy.NewAgent(
254			languageModel,
255			fantasy.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
256			fantasy.WithTools(addTool),
257			fantasy.WithTools(multiplyTool),
258		)
259		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
260			Prompt:          "Add and multiply the number 2 and 3",
261			ProviderOptions: pair.providerOptions,
262			MaxOutputTokens: fantasy.Opt(int64(4000)),
263			PrepareStep:     pair.prepareStep,
264		})
265		require.NoError(t, err, "failed to generate")
266		checkResult(t, result)
267	})
268}
269
270func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *fantasy.AgentResult)) {
271	for _, pair := range pairs {
272		t.Run(pair.name, func(t *testing.T) {
273			t.Run("thinking", func(t *testing.T) {
274				r := newRecorder(t)
275
276				languageModel, err := pair.builder(r)
277				require.NoError(t, err, "failed to build language model")
278
279				type WeatherInput struct {
280					Location string `json:"location" description:"the city"`
281				}
282
283				weatherTool := fantasy.NewAgentTool(
284					"weather",
285					"Get weather information for a location",
286					func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
287						return fantasy.NewTextResponse("40 C"), nil
288					},
289				)
290
291				agent := fantasy.NewAgent(
292					languageModel,
293					fantasy.WithSystemPrompt("You are a helpful assistant"),
294					fantasy.WithTools(weatherTool),
295				)
296				result, err := agent.Generate(t.Context(), fantasy.AgentCall{
297					Prompt:          "What's the weather in Florence, Italy?",
298					ProviderOptions: pair.providerOptions,
299					PrepareStep:     pair.prepareStep,
300				})
301				require.NoError(t, err, "failed to generate")
302
303				want1 := "Florence"
304				want2 := "40"
305				got := result.Response.Content.Text()
306				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
307
308				thinkChecks(t, result)
309			})
310			t.Run("thinking-streaming", func(t *testing.T) {
311				r := newRecorder(t)
312
313				languageModel, err := pair.builder(r)
314				require.NoError(t, err, "failed to build language model")
315
316				type WeatherInput struct {
317					Location string `json:"location" description:"the city"`
318				}
319
320				weatherTool := fantasy.NewAgentTool(
321					"weather",
322					"Get weather information for a location",
323					func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
324						return fantasy.NewTextResponse("40 C"), nil
325					},
326				)
327
328				agent := fantasy.NewAgent(
329					languageModel,
330					fantasy.WithSystemPrompt("You are a helpful assistant"),
331					fantasy.WithTools(weatherTool),
332				)
333				result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
334					Prompt:          "What's the weather in Florence, Italy?",
335					ProviderOptions: pair.providerOptions,
336					PrepareStep:     pair.prepareStep,
337				})
338				require.NoError(t, err, "failed to generate")
339
340				want1 := "Florence"
341				want2 := "40"
342				got := result.Response.Content.Text()
343				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
344
345				thinkChecks(t, result)
346			})
347		})
348	}
349}
350
// containsAny reports whether s contains at least one of subs as a substring.
// It returns false when subs is empty; an empty candidate always matches.
func containsAny(s string, subs ...string) bool {
	found := false
	for _, candidate := range subs {
		if found = strings.Contains(s, candidate); found {
			break
		}
	}
	return found
}