common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	_ "github.com/joho/godotenv/autoload"
 11	"github.com/stretchr/testify/require"
 12	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 13)
 14
 15type testModel struct {
 16	name      string
 17	model     string
 18	reasoning bool
 19}
 20
// builderFunc constructs the language model under test. It receives the
// go-vcr recorder created for the current subtest and returns either the
// model or an error; callers in this file fail the test on a non-nil error.
// NOTE(review): how the recorder is wired into the model (presumably as the
// HTTP transport) is decided by each implementation — confirm there.
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 22
// builderPair bundles everything the shared test scenarios need to exercise
// one provider/model combination.
type builderPair struct {
	// name is used as the subtest name; testMultiTool also inspects it for
	// the substring "azure" to decide whether to skip.
	name            string
	// builder creates the language model for each subtest.
	builder         builderFunc
	// providerOptions is forwarded unchanged as ProviderOptions on every
	// Generate/Stream call made by the scenarios.
	providerOptions ai.ProviderOptions
}
 28
 29func testCommon(t *testing.T, pairs []builderPair) {
 30	for _, pair := range pairs {
 31		t.Run(pair.name, func(t *testing.T) {
 32			testSimple(t, pair)
 33			testTool(t, pair)
 34			testMultiTool(t, pair)
 35		})
 36	}
 37}
 38
 39func testSimple(t *testing.T, pair builderPair) {
 40	checkResult := func(t *testing.T, result *ai.AgentResult) {
 41		options := []string{"Oi", "oi", "Olá", "olá"}
 42		got := result.Response.Content.Text()
 43		require.True(t, containsAny(got, options...), "unexpected response: got %q, want any of: %q", got, options)
 44	}
 45
 46	t.Run("simple", func(t *testing.T) {
 47		r := newRecorder(t)
 48
 49		languageModel, err := pair.builder(r)
 50		require.NoError(t, err, "failed to build language model")
 51
 52		agent := ai.NewAgent(
 53			languageModel,
 54			ai.WithSystemPrompt("You are a helpful assistant"),
 55		)
 56		result, err := agent.Generate(t.Context(), ai.AgentCall{
 57			Prompt:          "Say hi in Portuguese",
 58			ProviderOptions: pair.providerOptions,
 59			MaxOutputTokens: ai.IntOption(4000),
 60		})
 61		require.NoError(t, err, "failed to generate")
 62		checkResult(t, result)
 63	})
 64	t.Run("simple streaming", func(t *testing.T) {
 65		r := newRecorder(t)
 66
 67		languageModel, err := pair.builder(r)
 68		require.NoError(t, err, "failed to build language model")
 69
 70		agent := ai.NewAgent(
 71			languageModel,
 72			ai.WithSystemPrompt("You are a helpful assistant"),
 73		)
 74		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 75			Prompt:          "Say hi in Portuguese",
 76			ProviderOptions: pair.providerOptions,
 77			MaxOutputTokens: ai.IntOption(4000),
 78		})
 79		require.NoError(t, err, "failed to generate")
 80		checkResult(t, result)
 81	})
 82}
 83
 84func testTool(t *testing.T, pair builderPair) {
 85	type WeatherInput struct {
 86		Location string `json:"location" description:"the city"`
 87	}
 88
 89	weatherTool := ai.NewAgentTool(
 90		"weather",
 91		"Get weather information for a location",
 92		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 93			return ai.NewTextResponse("40 C"), nil
 94		},
 95	)
 96	checkResult := func(t *testing.T, result *ai.AgentResult) {
 97		require.Len(t, result.Steps, 2)
 98
 99		var toolCalls []ai.ToolCallContent
100		for _, content := range result.Steps[0].Content {
101			if content.GetType() == ai.ContentTypeToolCall {
102				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
103			}
104		}
105		for _, tc := range toolCalls {
106			require.False(t, tc.Invalid)
107		}
108		require.Len(t, toolCalls, 1)
109		require.Equal(t, toolCalls[0].ToolName, "weather")
110
111		want1 := "Florence"
112		want2 := "40"
113		got := result.Response.Content.Text()
114		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
115	}
116
117	t.Run("tool", func(t *testing.T) {
118		r := newRecorder(t)
119
120		languageModel, err := pair.builder(r)
121		require.NoError(t, err, "failed to build language model")
122
123		agent := ai.NewAgent(
124			languageModel,
125			ai.WithSystemPrompt("You are a helpful assistant"),
126			ai.WithTools(weatherTool),
127		)
128		result, err := agent.Generate(t.Context(), ai.AgentCall{
129			Prompt:          "What's the weather in Florence,Italy?",
130			ProviderOptions: pair.providerOptions,
131			MaxOutputTokens: ai.IntOption(4000),
132		})
133		require.NoError(t, err, "failed to generate")
134		checkResult(t, result)
135	})
136	t.Run("tool streaming", func(t *testing.T) {
137		r := newRecorder(t)
138
139		languageModel, err := pair.builder(r)
140		require.NoError(t, err, "failed to build language model")
141
142		agent := ai.NewAgent(
143			languageModel,
144			ai.WithSystemPrompt("You are a helpful assistant"),
145			ai.WithTools(weatherTool),
146		)
147		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
148			Prompt:          "What's the weather in Florence,Italy?",
149			ProviderOptions: pair.providerOptions,
150			MaxOutputTokens: ai.IntOption(4000),
151		})
152		require.NoError(t, err, "failed to generate")
153		checkResult(t, result)
154	})
155}
156
157func testMultiTool(t *testing.T, pair builderPair) {
158	// Apparently, Azure does not support multi-tools calls at all?
159	if strings.Contains(pair.name, "azure") {
160		t.Skip("skipping multi-tool tests for azure as it does not support parallel multi-tool calls")
161	}
162
163	type CalculatorInput struct {
164		A int `json:"a" description:"first number"`
165		B int `json:"b" description:"second number"`
166	}
167
168	addTool := ai.NewAgentTool(
169		"add",
170		"Add two numbers",
171		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
172			result := input.A + input.B
173			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
174		},
175	)
176	multiplyTool := ai.NewAgentTool(
177		"multiply",
178		"Multiply two numbers",
179		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
180			result := input.A * input.B
181			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
182		},
183	)
184	checkResult := func(t *testing.T, result *ai.AgentResult) {
185		require.Len(t, result.Steps, 2)
186
187		var toolCalls []ai.ToolCallContent
188		for _, content := range result.Steps[0].Content {
189			if content.GetType() == ai.ContentTypeToolCall {
190				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
191			}
192		}
193		for _, tc := range toolCalls {
194			require.False(t, tc.Invalid)
195		}
196		require.Len(t, toolCalls, 2)
197
198		finalText := result.Response.Content.Text()
199		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
200		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
201	}
202
203	t.Run("multi tool", func(t *testing.T) {
204		r := newRecorder(t)
205
206		languageModel, err := pair.builder(r)
207		require.NoError(t, err, "failed to build language model")
208
209		agent := ai.NewAgent(
210			languageModel,
211			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
212			ai.WithTools(addTool),
213			ai.WithTools(multiplyTool),
214		)
215		result, err := agent.Generate(t.Context(), ai.AgentCall{
216			Prompt:          "Add and multiply the number 2 and 3",
217			ProviderOptions: pair.providerOptions,
218			MaxOutputTokens: ai.IntOption(4000),
219		})
220		require.NoError(t, err, "failed to generate")
221		checkResult(t, result)
222	})
223	t.Run("multi tool streaming", func(t *testing.T) {
224		r := newRecorder(t)
225
226		languageModel, err := pair.builder(r)
227		require.NoError(t, err, "failed to build language model")
228
229		agent := ai.NewAgent(
230			languageModel,
231			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
232			ai.WithTools(addTool),
233			ai.WithTools(multiplyTool),
234		)
235		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
236			Prompt:          "Add and multiply the number 2 and 3",
237			ProviderOptions: pair.providerOptions,
238			MaxOutputTokens: ai.IntOption(4000),
239		})
240		require.NoError(t, err, "failed to generate")
241		checkResult(t, result)
242	})
243}
244
245func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
246	for _, pair := range pairs {
247		t.Run(pair.name, func(t *testing.T) {
248			t.Run("thinking", func(t *testing.T) {
249				r := newRecorder(t)
250
251				languageModel, err := pair.builder(r)
252				require.NoError(t, err, "failed to build language model")
253
254				type WeatherInput struct {
255					Location string `json:"location" description:"the city"`
256				}
257
258				weatherTool := ai.NewAgentTool(
259					"weather",
260					"Get weather information for a location",
261					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
262						return ai.NewTextResponse("40 C"), nil
263					},
264				)
265
266				agent := ai.NewAgent(
267					languageModel,
268					ai.WithSystemPrompt("You are a helpful assistant"),
269					ai.WithTools(weatherTool),
270				)
271				result, err := agent.Generate(t.Context(), ai.AgentCall{
272					Prompt:          "What's the weather in Florence, Italy?",
273					ProviderOptions: pair.providerOptions,
274				})
275				require.NoError(t, err, "failed to generate")
276
277				want1 := "Florence"
278				want2 := "40"
279				got := result.Response.Content.Text()
280				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
281
282				thinkChecks(t, result)
283			})
284			t.Run("thinking-streaming", func(t *testing.T) {
285				r := newRecorder(t)
286
287				languageModel, err := pair.builder(r)
288				require.NoError(t, err, "failed to build language model")
289
290				type WeatherInput struct {
291					Location string `json:"location" description:"the city"`
292				}
293
294				weatherTool := ai.NewAgentTool(
295					"weather",
296					"Get weather information for a location",
297					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
298						return ai.NewTextResponse("40 C"), nil
299					},
300				)
301
302				agent := ai.NewAgent(
303					languageModel,
304					ai.WithSystemPrompt("You are a helpful assistant"),
305					ai.WithTools(weatherTool),
306				)
307				result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
308					Prompt:          "What's the weather in Florence, Italy?",
309					ProviderOptions: pair.providerOptions,
310				})
311				require.NoError(t, err, "failed to generate")
312
313				want1 := "Florence"
314				want2 := "40"
315				got := result.Response.Content.Text()
316				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
317
318				thinkChecks(t, result)
319			})
320		})
321	}
322}
323
// containsAny reports whether s contains at least one of subs. It returns
// false when subs is empty, and stops scanning at the first match.
func containsAny(s string, subs ...string) bool {
	found := false
	for i := 0; i < len(subs) && !found; i++ {
		found = strings.Contains(s, subs[i])
	}
	return found
}