common_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/fantasy/ai"
 10	"github.com/stretchr/testify/require"
 11	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 12)
 13
 14type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 15
 16type builderPair struct {
 17	name            string
 18	builder         builderFunc
 19	providerOptions ai.ProviderOptions
 20}
 21
 22func testCommon(t *testing.T, pairs []builderPair) {
 23	for _, pair := range pairs {
 24		testSimple(t, pair)
 25	}
 26}
 27
 28func testSimple(t *testing.T, pair builderPair) {
 29	checkResult := func(t *testing.T, result *ai.AgentResult) {
 30		option1 := "Oi"
 31		option2 := "Olá"
 32		got := result.Response.Content.Text()
 33		require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
 34	}
 35
 36	t.Run("simple "+pair.name, func(t *testing.T) {
 37		r := newRecorder(t)
 38
 39		languageModel, err := pair.builder(r)
 40		require.NoError(t, err, "failed to build language model")
 41
 42		agent := ai.NewAgent(
 43			languageModel,
 44			ai.WithSystemPrompt("You are a helpful assistant"),
 45		)
 46		result, err := agent.Generate(t.Context(), ai.AgentCall{
 47			Prompt:          "Say hi in Portuguese",
 48			ProviderOptions: pair.providerOptions,
 49			MaxOutputTokens: ai.IntOption(4000),
 50		})
 51		require.NoError(t, err, "failed to generate")
 52		checkResult(t, result)
 53	})
 54	t.Run("simple streaming "+pair.name, func(t *testing.T) {
 55		r := newRecorder(t)
 56
 57		languageModel, err := pair.builder(r)
 58		require.NoError(t, err, "failed to build language model")
 59
 60		agent := ai.NewAgent(
 61			languageModel,
 62			ai.WithSystemPrompt("You are a helpful assistant"),
 63		)
 64		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
 65			Prompt:          "Say hi in Portuguese",
 66			ProviderOptions: pair.providerOptions,
 67			MaxOutputTokens: ai.IntOption(4000),
 68		})
 69		require.NoError(t, err, "failed to generate")
 70		checkResult(t, result)
 71	})
 72}
 73
 74func testTool(t *testing.T, pair builderPair) {
 75	type WeatherInput struct {
 76		Location string `json:"location" description:"the city"`
 77	}
 78
 79	weatherTool := ai.NewAgentTool(
 80		"weather",
 81		"Get weather information for a location",
 82		func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 83			return ai.NewTextResponse("40 C"), nil
 84		},
 85	)
 86	checkResult := func(t *testing.T, result *ai.AgentResult) {
 87		require.Len(t, result.Steps, 2)
 88
 89		var toolCalls []ai.ToolCallContent
 90		for _, content := range result.Steps[0].Content {
 91			if content.GetType() == ai.ContentTypeToolCall {
 92				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
 93			}
 94		}
 95		require.Len(t, toolCalls, 1)
 96		require.Equal(t, toolCalls[0].ToolName, "weather")
 97
 98		want1 := "Florence"
 99		want2 := "40"
100		got := result.Response.Content.Text()
101		require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
102	}
103
104	t.Run("tool "+pair.name, func(t *testing.T) {
105		r := newRecorder(t)
106
107		languageModel, err := pair.builder(r)
108		require.NoError(t, err, "failed to build language model")
109
110		agent := ai.NewAgent(
111			languageModel,
112			ai.WithSystemPrompt("You are a helpful assistant"),
113			ai.WithTools(weatherTool),
114		)
115		result, err := agent.Generate(t.Context(), ai.AgentCall{
116			Prompt:          "What's the weather in Florence,Italy?",
117			ProviderOptions: pair.providerOptions,
118			MaxOutputTokens: ai.IntOption(4000),
119		})
120		require.NoError(t, err, "failed to generate")
121		checkResult(t, result)
122	})
123	t.Run("tool streaming "+pair.name, func(t *testing.T) {
124		r := newRecorder(t)
125
126		languageModel, err := pair.builder(r)
127		require.NoError(t, err, "failed to build language model")
128
129		agent := ai.NewAgent(
130			languageModel,
131			ai.WithSystemPrompt("You are a helpful assistant"),
132			ai.WithTools(weatherTool),
133		)
134		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
135			Prompt:          "What's the weather in Florence,Italy?",
136			ProviderOptions: pair.providerOptions,
137			MaxOutputTokens: ai.IntOption(4000),
138		})
139		require.NoError(t, err, "failed to generate")
140		checkResult(t, result)
141	})
142}
143
144func testMultiTool(t *testing.T, pair builderPair) {
145	type WeatherInput struct {
146		Location string `json:"location" description:"the city"`
147	}
148
149	type CalculatorInput struct {
150		A int `json:"a" description:"first number"`
151		B int `json:"b" description:"second number"`
152	}
153
154	addTool := ai.NewAgentTool(
155		"add",
156		"Add two numbers",
157		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
158			result := input.A + input.B
159			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
160		},
161	)
162	multiplyTool := ai.NewAgentTool(
163		"multiply",
164		"Multiply two numbers",
165		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
166			result := input.A * input.B
167			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
168		},
169	)
170	checkResult := func(t *testing.T, result *ai.AgentResult) {
171		require.Len(t, result.Steps, 2)
172
173		var toolCalls []ai.ToolCallContent
174		for _, content := range result.Steps[0].Content {
175			if content.GetType() == ai.ContentTypeToolCall {
176				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
177			}
178		}
179		require.Len(t, toolCalls, 2)
180
181		finalText := result.Response.Content.Text()
182		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
183		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
184	}
185
186	t.Run("multi tool "+pair.name, func(t *testing.T) {
187		r := newRecorder(t)
188
189		languageModel, err := pair.builder(r)
190		require.NoError(t, err, "failed to build language model")
191
192		agent := ai.NewAgent(
193			languageModel,
194			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
195			ai.WithTools(addTool),
196			ai.WithTools(multiplyTool),
197		)
198		result, err := agent.Generate(t.Context(), ai.AgentCall{
199			Prompt:          "Add and multiply the number 2 and 3",
200			ProviderOptions: pair.providerOptions,
201			MaxOutputTokens: ai.IntOption(4000),
202		})
203		require.NoError(t, err, "failed to generate")
204		checkResult(t, result)
205	})
206	t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
207		r := newRecorder(t)
208
209		languageModel, err := pair.builder(r)
210		require.NoError(t, err, "failed to build language model")
211
212		agent := ai.NewAgent(
213			languageModel,
214			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
215			ai.WithTools(addTool),
216			ai.WithTools(multiplyTool),
217		)
218		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
219			Prompt:          "Add and multiply the number 2 and 3",
220			ProviderOptions: pair.providerOptions,
221			MaxOutputTokens: ai.IntOption(4000),
222		})
223		require.NoError(t, err, "failed to generate")
224		checkResult(t, result)
225	})
226}