common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"os"
  6	"strconv"
  7	"strings"
  8	"testing"
  9
 10	"github.com/charmbracelet/fantasy/ai"
 11	"github.com/joho/godotenv"
 12	"github.com/stretchr/testify/require"
 13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 14)
 15
 16func init() {
 17	if _, err := os.Stat(".env"); err == nil {
 18		godotenv.Load(".env")
 19	} else {
 20		godotenv.Load(".env.sample")
 21	}
 22}
 23
 24type testModel struct {
 25	name      string
 26	model     string
 27	reasoning bool
 28}
 29
 30type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 31
 32type builderPair struct {
 33	name            string
 34	builder         builderFunc
 35	providerOptions ai.ProviderOptions
 36}
 37
 38func testCommon(t *testing.T, pairs []builderPair) {
 39	for _, pair := range pairs {
 40		t.Run(pair.name, func(t *testing.T) {
 41			testSimple(t, pair)
 42			testTool(t, pair)
 43			testMultiTool(t, pair)
 44		})
 45	}
 46}
 47
 48func testSimple(t *testing.T, pair builderPair) {
 49	checkResult := func(t *testing.T, result *ai.AgentResult) {
 50		options := []string{"Oi", "oi", "Olá", "olá"}
 51		got := result.Response.Content.Text()
 52		require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
 53	}
 54
 55	t.Run("simple", func(t *testing.T) {
 56		r := newRecorder(t)
 57
 58		languageModel, err := pair.builder(r)
 59		require.NoError(t, err, "failed to build language model")
 60
 61		agent := ai.NewAgent(
 62			languageModel,
 63			ai.WithSystemPrompt("You are a helpful assistant"),
 64		)
 65		result, err := agent.Generate(t.Context(), ai.AgentCall{
 66			Prompt:          "Say hi in Portuguese",
 67			ProviderOptions: pair.providerOptions,
 68			MaxOutputTokens: ai.IntOption(4000),
 69		})
 70		require.NoError(t, err, "failed to generate")
 71		checkResult(t, result)
 72	})
 73	t.Run("simple streaming", func(t *testing.T) {
 74		r := newRecorder(t)
 75
 76		languageModel, err := pair.builder(r)
 77		require.NoError(t, err, "failed to build language model")
 78
 79		agent := ai.NewAgent(
 80			languageModel,
 81			ai.WithSystemPrompt("You are a helpful assistant"),
 82		)
 83		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 84			Prompt:          "Say hi in Portuguese",
 85			ProviderOptions: pair.providerOptions,
 86			MaxOutputTokens: ai.IntOption(4000),
 87		})
 88		require.NoError(t, err, "failed to generate")
 89		checkResult(t, result)
 90	})
 91}
 92
 93func testTool(t *testing.T, pair builderPair) {
 94	type WeatherInput struct {
 95		Location string `json:"location" description:"the city"`
 96	}
 97
 98	weatherTool := ai.NewAgentTool(
 99		"weather",
100		"Get weather information for a location",
101		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
102			return ai.NewTextResponse("40 C"), nil
103		},
104	)
105	checkResult := func(t *testing.T, result *ai.AgentResult) {
106		require.GreaterOrEqual(t, len(result.Steps), 2)
107
108		var toolCalls []ai.ToolCallContent
109		for _, content := range result.Steps[0].Content {
110			if content.GetType() == ai.ContentTypeToolCall {
111				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
112			}
113		}
114		for _, tc := range toolCalls {
115			require.False(t, tc.Invalid)
116		}
117		require.Len(t, toolCalls, 1)
118		require.Equal(t, toolCalls[0].ToolName, "weather")
119
120		want1 := "Florence"
121		want2 := "40"
122		got := result.Response.Content.Text()
123		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
124	}
125
126	t.Run("tool", func(t *testing.T) {
127		r := newRecorder(t)
128
129		languageModel, err := pair.builder(r)
130		require.NoError(t, err, "failed to build language model")
131
132		agent := ai.NewAgent(
133			languageModel,
134			ai.WithSystemPrompt("You are a helpful assistant"),
135			ai.WithTools(weatherTool),
136		)
137		result, err := agent.Generate(t.Context(), ai.AgentCall{
138			Prompt:          "What's the weather in Florence,Italy?",
139			ProviderOptions: pair.providerOptions,
140			MaxOutputTokens: ai.IntOption(4000),
141		})
142		require.NoError(t, err, "failed to generate")
143		checkResult(t, result)
144	})
145	t.Run("tool streaming", func(t *testing.T) {
146		r := newRecorder(t)
147
148		languageModel, err := pair.builder(r)
149		require.NoError(t, err, "failed to build language model")
150
151		agent := ai.NewAgent(
152			languageModel,
153			ai.WithSystemPrompt("You are a helpful assistant"),
154			ai.WithTools(weatherTool),
155		)
156		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
157			Prompt:          "What's the weather in Florence,Italy?",
158			ProviderOptions: pair.providerOptions,
159			MaxOutputTokens: ai.IntOption(4000),
160		})
161		require.NoError(t, err, "failed to generate")
162		checkResult(t, result)
163	})
164}
165
166func testMultiTool(t *testing.T, pair builderPair) {
167	// Apparently, Azure and Vertex+Anthropic do not support multi-tools calls at all?
168	if strings.Contains(pair.name, "azure") {
169		t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
170	}
171	if strings.Contains(pair.name, "vertex") && strings.Contains(pair.name, "claude") {
172		t.Skip("skipping multi-tool tests for vertex claude as it does not support parallel multi-tool calls")
173	}
174
175	type CalculatorInput struct {
176		A int `json:"a" description:"first number"`
177		B int `json:"b" description:"second number"`
178	}
179
180	addTool := ai.NewAgentTool(
181		"add",
182		"Add two numbers",
183		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
184			result := input.A + input.B
185			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
186		},
187	)
188	multiplyTool := ai.NewAgentTool(
189		"multiply",
190		"Multiply two numbers",
191		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
192			result := input.A * input.B
193			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
194		},
195	)
196	checkResult := func(t *testing.T, result *ai.AgentResult) {
197		require.Len(t, result.Steps, 2)
198
199		var toolCalls []ai.ToolCallContent
200		for _, content := range result.Steps[0].Content {
201			if content.GetType() == ai.ContentTypeToolCall {
202				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
203			}
204		}
205		for _, tc := range toolCalls {
206			require.False(t, tc.Invalid)
207		}
208		require.Len(t, toolCalls, 2)
209
210		finalText := result.Response.Content.Text()
211		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
212		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
213	}
214
215	t.Run("multi tool", func(t *testing.T) {
216		r := newRecorder(t)
217
218		languageModel, err := pair.builder(r)
219		require.NoError(t, err, "failed to build language model")
220
221		agent := ai.NewAgent(
222			languageModel,
223			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
224			ai.WithTools(addTool),
225			ai.WithTools(multiplyTool),
226		)
227		result, err := agent.Generate(t.Context(), ai.AgentCall{
228			Prompt:          "Add and multiply the number 2 and 3",
229			ProviderOptions: pair.providerOptions,
230			MaxOutputTokens: ai.IntOption(4000),
231		})
232		require.NoError(t, err, "failed to generate")
233		checkResult(t, result)
234	})
235	t.Run("multi tool streaming", func(t *testing.T) {
236		r := newRecorder(t)
237
238		languageModel, err := pair.builder(r)
239		require.NoError(t, err, "failed to build language model")
240
241		agent := ai.NewAgent(
242			languageModel,
243			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
244			ai.WithTools(addTool),
245			ai.WithTools(multiplyTool),
246		)
247		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
248			Prompt:          "Add and multiply the number 2 and 3",
249			ProviderOptions: pair.providerOptions,
250			MaxOutputTokens: ai.IntOption(4000),
251		})
252		require.NoError(t, err, "failed to generate")
253		checkResult(t, result)
254	})
255}
256
257func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
258	for _, pair := range pairs {
259		t.Run(pair.name, func(t *testing.T) {
260			t.Run("thinking", func(t *testing.T) {
261				r := newRecorder(t)
262
263				languageModel, err := pair.builder(r)
264				require.NoError(t, err, "failed to build language model")
265
266				type WeatherInput struct {
267					Location string `json:"location" description:"the city"`
268				}
269
270				weatherTool := ai.NewAgentTool(
271					"weather",
272					"Get weather information for a location",
273					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
274						return ai.NewTextResponse("40 C"), nil
275					},
276				)
277
278				agent := ai.NewAgent(
279					languageModel,
280					ai.WithSystemPrompt("You are a helpful assistant"),
281					ai.WithTools(weatherTool),
282				)
283				result, err := agent.Generate(t.Context(), ai.AgentCall{
284					Prompt:          "What's the weather in Florence, Italy?",
285					ProviderOptions: pair.providerOptions,
286				})
287				require.NoError(t, err, "failed to generate")
288
289				want1 := "Florence"
290				want2 := "40"
291				got := result.Response.Content.Text()
292				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
293
294				thinkChecks(t, result)
295			})
296			t.Run("thinking-streaming", func(t *testing.T) {
297				r := newRecorder(t)
298
299				languageModel, err := pair.builder(r)
300				require.NoError(t, err, "failed to build language model")
301
302				type WeatherInput struct {
303					Location string `json:"location" description:"the city"`
304				}
305
306				weatherTool := ai.NewAgentTool(
307					"weather",
308					"Get weather information for a location",
309					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
310						return ai.NewTextResponse("40 C"), nil
311					},
312				)
313
314				agent := ai.NewAgent(
315					languageModel,
316					ai.WithSystemPrompt("You are a helpful assistant"),
317					ai.WithTools(weatherTool),
318				)
319				result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
320					Prompt:          "What's the weather in Florence, Italy?",
321					ProviderOptions: pair.providerOptions,
322				})
323				require.NoError(t, err, "failed to generate")
324
325				want1 := "Florence"
326				want2 := "40"
327				got := result.Response.Content.Text()
328				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
329
330				thinkChecks(t, result)
331			})
332		})
333	}
334}
335
// containsAny reports whether s contains at least one of the substrings in
// subs. It returns false when subs is empty and stops at the first match.
func containsAny(s string, subs ...string) bool {
	matched := false
	for i := 0; i < len(subs) && !matched; i++ {
		matched = strings.Contains(s, subs[i])
	}
	return matched
}