common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	_ "github.com/joho/godotenv/autoload"
 11	"github.com/stretchr/testify/require"
 12	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 13)
 14
 15type testModel struct {
 16	name      string
 17	model     string
 18	reasoning bool
 19}
 20
 21type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 22
 23type builderPair struct {
 24	name            string
 25	builder         builderFunc
 26	providerOptions ai.ProviderOptions
 27}
 28
 29func testCommon(t *testing.T, pairs []builderPair) {
 30	for _, pair := range pairs {
 31		t.Run(pair.name, func(t *testing.T) {
 32			testSimple(t, pair)
 33			testTool(t, pair)
 34			testMultiTool(t, pair)
 35		})
 36	}
 37}
 38
 39func testSimple(t *testing.T, pair builderPair) {
 40	checkResult := func(t *testing.T, result *ai.AgentResult) {
 41		option1 := "Oi"
 42		option2 := "Olá"
 43		got := result.Response.Content.Text()
 44		require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 45	}
 46
 47	t.Run("simple", func(t *testing.T) {
 48		r := newRecorder(t)
 49
 50		languageModel, err := pair.builder(r)
 51		require.NoError(t, err, "failed to build language model")
 52
 53		agent := ai.NewAgent(
 54			languageModel,
 55			ai.WithSystemPrompt("You are a helpful assistant"),
 56		)
 57		result, err := agent.Generate(t.Context(), ai.AgentCall{
 58			Prompt:          "Say hi in Portuguese",
 59			ProviderOptions: pair.providerOptions,
 60			MaxOutputTokens: ai.IntOption(4000),
 61		})
 62		require.NoError(t, err, "failed to generate")
 63		checkResult(t, result)
 64	})
 65	t.Run("simple streaming", func(t *testing.T) {
 66		r := newRecorder(t)
 67
 68		languageModel, err := pair.builder(r)
 69		require.NoError(t, err, "failed to build language model")
 70
 71		agent := ai.NewAgent(
 72			languageModel,
 73			ai.WithSystemPrompt("You are a helpful assistant"),
 74		)
 75		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 76			Prompt:          "Say hi in Portuguese",
 77			ProviderOptions: pair.providerOptions,
 78			MaxOutputTokens: ai.IntOption(4000),
 79		})
 80		require.NoError(t, err, "failed to generate")
 81		checkResult(t, result)
 82	})
 83}
 84
 85func testTool(t *testing.T, pair builderPair) {
 86	type WeatherInput struct {
 87		Location string `json:"location" description:"the city"`
 88	}
 89
 90	weatherTool := ai.NewAgentTool(
 91		"weather",
 92		"Get weather information for a location",
 93		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 94			return ai.NewTextResponse("40 C"), nil
 95		},
 96	)
 97	checkResult := func(t *testing.T, result *ai.AgentResult) {
 98		require.Len(t, result.Steps, 2)
 99
100		var toolCalls []ai.ToolCallContent
101		for _, content := range result.Steps[0].Content {
102			if content.GetType() == ai.ContentTypeToolCall {
103				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
104			}
105		}
106		for _, tc := range toolCalls {
107			require.False(t, tc.Invalid)
108		}
109		require.Len(t, toolCalls, 1)
110		require.Equal(t, toolCalls[0].ToolName, "weather")
111
112		want1 := "Florence"
113		want2 := "40"
114		got := result.Response.Content.Text()
115		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
116	}
117
118	t.Run("tool", func(t *testing.T) {
119		r := newRecorder(t)
120
121		languageModel, err := pair.builder(r)
122		require.NoError(t, err, "failed to build language model")
123
124		agent := ai.NewAgent(
125			languageModel,
126			ai.WithSystemPrompt("You are a helpful assistant"),
127			ai.WithTools(weatherTool),
128		)
129		result, err := agent.Generate(t.Context(), ai.AgentCall{
130			Prompt:          "What's the weather in Florence,Italy?",
131			ProviderOptions: pair.providerOptions,
132			MaxOutputTokens: ai.IntOption(4000),
133		})
134		require.NoError(t, err, "failed to generate")
135		checkResult(t, result)
136	})
137	t.Run("tool streaming", func(t *testing.T) {
138		r := newRecorder(t)
139
140		languageModel, err := pair.builder(r)
141		require.NoError(t, err, "failed to build language model")
142
143		agent := ai.NewAgent(
144			languageModel,
145			ai.WithSystemPrompt("You are a helpful assistant"),
146			ai.WithTools(weatherTool),
147		)
148		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
149			Prompt:          "What's the weather in Florence,Italy?",
150			ProviderOptions: pair.providerOptions,
151			MaxOutputTokens: ai.IntOption(4000),
152		})
153		require.NoError(t, err, "failed to generate")
154		checkResult(t, result)
155	})
156}
157
158func testMultiTool(t *testing.T, pair builderPair) {
159	type CalculatorInput struct {
160		A int `json:"a" description:"first number"`
161		B int `json:"b" description:"second number"`
162	}
163
164	addTool := ai.NewAgentTool(
165		"add",
166		"Add two numbers",
167		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
168			result := input.A + input.B
169			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
170		},
171	)
172	multiplyTool := ai.NewAgentTool(
173		"multiply",
174		"Multiply two numbers",
175		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
176			result := input.A * input.B
177			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
178		},
179	)
180	checkResult := func(t *testing.T, result *ai.AgentResult) {
181		require.Len(t, result.Steps, 2)
182
183		var toolCalls []ai.ToolCallContent
184		for _, content := range result.Steps[0].Content {
185			if content.GetType() == ai.ContentTypeToolCall {
186				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
187			}
188		}
189		for _, tc := range toolCalls {
190			require.False(t, tc.Invalid)
191		}
192		require.Len(t, toolCalls, 2)
193
194		finalText := result.Response.Content.Text()
195		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
196		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
197	}
198
199	t.Run("multi tool", func(t *testing.T) {
200		r := newRecorder(t)
201
202		languageModel, err := pair.builder(r)
203		require.NoError(t, err, "failed to build language model")
204
205		agent := ai.NewAgent(
206			languageModel,
207			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
208			ai.WithTools(addTool),
209			ai.WithTools(multiplyTool),
210		)
211		result, err := agent.Generate(t.Context(), ai.AgentCall{
212			Prompt:          "Add and multiply the number 2 and 3",
213			ProviderOptions: pair.providerOptions,
214			MaxOutputTokens: ai.IntOption(4000),
215		})
216		require.NoError(t, err, "failed to generate")
217		checkResult(t, result)
218	})
219	t.Run("multi tool streaming", func(t *testing.T) {
220		r := newRecorder(t)
221
222		languageModel, err := pair.builder(r)
223		require.NoError(t, err, "failed to build language model")
224
225		agent := ai.NewAgent(
226			languageModel,
227			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
228			ai.WithTools(addTool),
229			ai.WithTools(multiplyTool),
230		)
231		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
232			Prompt:          "Add and multiply the number 2 and 3",
233			ProviderOptions: pair.providerOptions,
234			MaxOutputTokens: ai.IntOption(4000),
235		})
236		require.NoError(t, err, "failed to generate")
237		checkResult(t, result)
238	})
239}
240
241func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T, *ai.AgentResult)) {
242	for _, pair := range pairs {
243		t.Run(pair.name, func(t *testing.T) {
244			t.Run("thinking", func(t *testing.T) {
245				r := newRecorder(t)
246
247				languageModel, err := pair.builder(r)
248				require.NoError(t, err, "failed to build language model")
249
250				type WeatherInput struct {
251					Location string `json:"location" description:"the city"`
252				}
253
254				weatherTool := ai.NewAgentTool(
255					"weather",
256					"Get weather information for a location",
257					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
258						return ai.NewTextResponse("40 C"), nil
259					},
260				)
261
262				agent := ai.NewAgent(
263					languageModel,
264					ai.WithSystemPrompt("You are a helpful assistant"),
265					ai.WithTools(weatherTool),
266				)
267				result, err := agent.Generate(t.Context(), ai.AgentCall{
268					Prompt:          "What's the weather in Florence, Italy?",
269					ProviderOptions: pair.providerOptions,
270				})
271				require.NoError(t, err, "failed to generate")
272
273				want1 := "Florence"
274				want2 := "40"
275				got := result.Response.Content.Text()
276				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
277
278				thinkChecks(t, result)
279			})
280			t.Run("thinking-streaming", func(t *testing.T) {
281				r := newRecorder(t)
282
283				languageModel, err := pair.builder(r)
284				require.NoError(t, err, "failed to build language model")
285
286				type WeatherInput struct {
287					Location string `json:"location" description:"the city"`
288				}
289
290				weatherTool := ai.NewAgentTool(
291					"weather",
292					"Get weather information for a location",
293					func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
294						return ai.NewTextResponse("40 C"), nil
295					},
296				)
297
298				agent := ai.NewAgent(
299					languageModel,
300					ai.WithSystemPrompt("You are a helpful assistant"),
301					ai.WithTools(weatherTool),
302				)
303				result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
304					Prompt:          "What's the weather in Florence, Italy?",
305					ProviderOptions: pair.providerOptions,
306				})
307				require.NoError(t, err, "failed to generate")
308
309				want1 := "Florence"
310				want2 := "40"
311				got := result.Response.Content.Text()
312				require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
313
314				thinkChecks(t, result)
315			})
316		})
317	}
318}