common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"os"
  6	"strconv"
  7	"strings"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"github.com/joho/godotenv"
 12	"github.com/stretchr/testify/require"
 13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 14)
 15
 16func init() {
 17	if _, err := os.Stat(".env"); err == nil {
 18		godotenv.Load(".env")
 19	} else {
 20		godotenv.Load(".env.sample")
 21	}
 22}
 23
 24type testModel struct {
 25	name      string
 26	model     string
 27	reasoning bool
 28}
 29
 30type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 31
 32type builderPair struct {
 33	name            string
 34	builder         builderFunc
 35	providerOptions fantasy.ProviderOptions
 36	prepareStep     fantasy.PrepareStepFunction
 37}
 38
 39func testCommon(t *testing.T, pairs []builderPair) {
 40	for _, pair := range pairs {
 41		t.Run(pair.name, func(t *testing.T) {
 42			testSimple(t, pair)
 43			testTool(t, pair)
 44			testMultiTool(t, pair)
 45		})
 46	}
 47}
 48
 49func testSimple(t *testing.T, pair builderPair) {
 50	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
 51		options := []string{"Oi", "oi", "Olá", "olá"}
 52		got := result.Response.Content.Text()
 53		require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
 54	}
 55
 56	t.Run("simple", func(t *testing.T) {
 57		r := newRecorder(t)
 58
 59		languageModel, err := pair.builder(t, r)
 60		require.NoError(t, err, "failed to build language model")
 61
 62		agent := fantasy.NewAgent(
 63			languageModel,
 64			fantasy.WithSystemPrompt("You are a helpful assistant"),
 65		)
 66		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
 67			Prompt:          "Say hi in Portuguese",
 68			ProviderOptions: pair.providerOptions,
 69			MaxOutputTokens: fantasy.Opt(int64(4000)),
 70			PrepareStep:     pair.prepareStep,
 71		})
 72		require.NoError(t, err, "failed to generate")
 73		checkResult(t, result)
 74	})
 75	t.Run("simple streaming", func(t *testing.T) {
 76		r := newRecorder(t)
 77
 78		languageModel, err := pair.builder(t, r)
 79		require.NoError(t, err, "failed to build language model")
 80
 81		agent := fantasy.NewAgent(
 82			languageModel,
 83			fantasy.WithSystemPrompt("You are a helpful assistant"),
 84		)
 85		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
 86			Prompt:          "Say hi in Portuguese",
 87			ProviderOptions: pair.providerOptions,
 88			MaxOutputTokens: fantasy.Opt(int64(4000)),
 89			PrepareStep:     pair.prepareStep,
 90		})
 91		require.NoError(t, err, "failed to generate")
 92		checkResult(t, result)
 93	})
 94}
 95
 96func testTool(t *testing.T, pair builderPair) {
 97	type WeatherInput struct {
 98		Location string `json:"location" description:"the city"`
 99	}
100
101	weatherTool := fantasy.NewAgentTool(
102		"weather",
103		"Get weather information for a location",
104		func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
105			return fantasy.NewTextResponse("40 C"), nil
106		},
107	)
108	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
109		require.GreaterOrEqual(t, len(result.Steps), 2)
110
111		var toolCalls []fantasy.ToolCallContent
112		for _, content := range result.Steps[0].Content {
113			if content.GetType() == fantasy.ContentTypeToolCall {
114				toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
115			}
116		}
117		for _, tc := range toolCalls {
118			require.False(t, tc.Invalid)
119		}
120		require.Len(t, toolCalls, 1)
121		require.Equal(t, toolCalls[0].ToolName, "weather")
122
123		want1 := "Florence"
124		want2 := "40"
125		got := result.Response.Content.Text()
126		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
127	}
128
129	t.Run("tool", func(t *testing.T) {
130		r := newRecorder(t)
131
132		languageModel, err := pair.builder(t, r)
133		require.NoError(t, err, "failed to build language model")
134
135		agent := fantasy.NewAgent(
136			languageModel,
137			fantasy.WithSystemPrompt("You are a helpful assistant"),
138			fantasy.WithTools(weatherTool),
139		)
140		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
141			Prompt:          "What's the weather in Florence,Italy?",
142			ProviderOptions: pair.providerOptions,
143			MaxOutputTokens: fantasy.Opt(int64(4000)),
144			PrepareStep:     pair.prepareStep,
145		})
146		require.NoError(t, err, "failed to generate")
147		checkResult(t, result)
148	})
149	t.Run("tool streaming", func(t *testing.T) {
150		r := newRecorder(t)
151
152		languageModel, err := pair.builder(t, r)
153		require.NoError(t, err, "failed to build language model")
154
155		agent := fantasy.NewAgent(
156			languageModel,
157			fantasy.WithSystemPrompt("You are a helpful assistant"),
158			fantasy.WithTools(weatherTool),
159		)
160		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
161			Prompt:          "What's the weather in Florence,Italy?",
162			ProviderOptions: pair.providerOptions,
163			MaxOutputTokens: fantasy.Opt(int64(4000)),
164			PrepareStep:     pair.prepareStep,
165		})
166		require.NoError(t, err, "failed to generate")
167		checkResult(t, result)
168	})
169}
170
171func testMultiTool(t *testing.T, pair builderPair) {
172	// Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
173	if strings.Contains(pair.name, "azure") {
174		t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
175	}
176	if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
177		t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
178	}
179	if strings.Contains(pair.name, "bedrock") && strings.Contains(pair.name, "claude") {
180		t.Skip("skipping multi-tool tests for bedrock claude as it does not support parallel multi-tool calls")
181	}
182	if strings.Contains(pair.name, "openai") && strings.Contains(pair.name, "o4-mini") {
183		t.Skip("skipping multi-tool tests for openai o4-mini it for some reason is not doing parallel tool calls even if asked")
184	}
185	if strings.Contains(pair.name, "llama-cpp") && strings.Contains(pair.name, "gpt-oss") {
186		t.Skip("skipping multi-tool tests for llama-cpp gpt-oss as it does not support parallel multi-tool calls")
187	}
188
189	type CalculatorInput struct {
190		A int `json:"a" description:"first number"`
191		B int `json:"b" description:"second number"`
192	}
193
194	addTool := fantasy.NewAgentTool(
195		"add",
196		"Add two numbers",
197		func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
198			result := input.A + input.B
199			return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
200		},
201	)
202	multiplyTool := fantasy.NewAgentTool(
203		"multiply",
204		"Multiply two numbers",
205		func(ctx context.Context, input CalculatorInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
206			result := input.A * input.B
207			return fantasy.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
208		},
209	)
210	checkResult := func(t *testing.T, result *fantasy.AgentResult) {
211		require.Len(t, result.Steps, 2)
212
213		var toolCalls []fantasy.ToolCallContent
214		for _, content := range result.Steps[0].Content {
215			if content.GetType() == fantasy.ContentTypeToolCall {
216				toolCalls = append(toolCalls, content.(fantasy.ToolCallContent))
217			}
218		}
219		for _, tc := range toolCalls {
220			require.False(t, tc.Invalid)
221		}
222		require.Len(t, toolCalls, 2)
223
224		finalText := result.Response.Content.Text()
225		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
226		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
227	}
228
229	t.Run("multi tool", func(t *testing.T) {
230		r := newRecorder(t)
231
232		languageModel, err := pair.builder(t, r)
233		require.NoError(t, err, "failed to build language model")
234
235		agent := fantasy.NewAgent(
236			languageModel,
237			fantasy.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
238			fantasy.WithTools(addTool),
239			fantasy.WithTools(multiplyTool),
240		)
241		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
242			Prompt:          "Add and multiply the number 2 and 3",
243			ProviderOptions: pair.providerOptions,
244			MaxOutputTokens: fantasy.Opt(int64(4000)),
245			PrepareStep:     pair.prepareStep,
246		})
247		require.NoError(t, err, "failed to generate")
248		checkResult(t, result)
249	})
250	t.Run("multi tool streaming", func(t *testing.T) {
251		r := newRecorder(t)
252
253		languageModel, err := pair.builder(t, r)
254		require.NoError(t, err, "failed to build language model")
255
256		agent := fantasy.NewAgent(
257			languageModel,
258			fantasy.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
259			fantasy.WithTools(addTool),
260			fantasy.WithTools(multiplyTool),
261		)
262		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
263			Prompt:          "Add and multiply the number 2 and 3",
264			ProviderOptions: pair.providerOptions,
265			MaxOutputTokens: fantasy.Opt(int64(4000)),
266			PrepareStep:     pair.prepareStep,
267		})
268		require.NoError(t, err, "failed to generate")
269		checkResult(t, result)
270	})
271}
272
273func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *fantasy.AgentResult)) {
274	for _, pair := range pairs {
275		t.Run(pair.name, func(t *testing.T) {
276			t.Run("thinking", func(t *testing.T) {
277				r := newRecorder(t)
278
279				languageModel, err := pair.builder(t, r)
280				require.NoError(t, err, "failed to build language model")
281
282				type WeatherInput struct {
283					Location string `json:"location" description:"the city"`
284				}
285
286				weatherTool := fantasy.NewAgentTool(
287					"weather",
288					"Get weather information for a location",
289					func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
290						return fantasy.NewTextResponse("40 C"), nil
291					},
292				)
293
294				agent := fantasy.NewAgent(
295					languageModel,
296					fantasy.WithSystemPrompt("You are a helpful assistant"),
297					fantasy.WithTools(weatherTool),
298				)
299				result, err := agent.Generate(t.Context(), fantasy.AgentCall{
300					Prompt:          "What's the weather in Florence, Italy?",
301					ProviderOptions: pair.providerOptions,
302					PrepareStep:     pair.prepareStep,
303				})
304				require.NoError(t, err, "failed to generate")
305
306				want1 := "Florence"
307				want2 := "40"
308				got := result.Response.Content.Text()
309				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
310
311				thinkChecks(t, result)
312			})
313			t.Run("thinking-streaming", func(t *testing.T) {
314				r := newRecorder(t)
315
316				languageModel, err := pair.builder(t, r)
317				require.NoError(t, err, "failed to build language model")
318
319				type WeatherInput struct {
320					Location string `json:"location" description:"the city"`
321				}
322
323				weatherTool := fantasy.NewAgentTool(
324					"weather",
325					"Get weather information for a location",
326					func(ctx context.Context, input WeatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
327						return fantasy.NewTextResponse("40 C"), nil
328					},
329				)
330
331				agent := fantasy.NewAgent(
332					languageModel,
333					fantasy.WithSystemPrompt("You are a helpful assistant"),
334					fantasy.WithTools(weatherTool),
335				)
336				result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
337					Prompt:          "What's the weather in Florence, Italy?",
338					ProviderOptions: pair.providerOptions,
339					PrepareStep:     pair.prepareStep,
340				})
341				require.NoError(t, err, "failed to generate")
342
343				want1 := "Florence"
344				want2 := "40"
345				got := result.Response.Content.Text()
346				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
347
348				thinkChecks(t, result)
349			})
350		})
351	}
352}
353
// containsAny reports whether s contains at least one of subs as a substring.
// It stops at the first match. An empty sub always matches, and with no subs
// the result is false.
func containsAny(s string, subs ...string) bool {
	matched := false
	for i := 0; i < len(subs) && !matched; i++ {
		matched = strings.Contains(s, subs[i])
	}
	return matched
}