1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/stretchr/testify/require"
11 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
12)
13
14type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
15
16type builderPair struct {
17 name string
18 builder builderFunc
19 providerOptions ai.ProviderOptions
20}
21
22func testCommon(t *testing.T, pairs []builderPair) {
23 for _, pair := range pairs {
24 testSimple(t, pair)
25 }
26}
27
28func testSimple(t *testing.T, pair builderPair) {
29 checkResult := func(t *testing.T, result *ai.AgentResult) {
30 option1 := "Oi"
31 option2 := "Olá"
32 got := result.Response.Content.Text()
33 require.True(t, strings.Contains(got, option1) || strings.Contains(got, option2), "unexpected response: got %q, want %q or %q", got, option1, option2)
34 }
35
36 t.Run("simple "+pair.name, func(t *testing.T) {
37 r := newRecorder(t)
38
39 languageModel, err := pair.builder(r)
40 require.NoError(t, err, "failed to build language model")
41
42 agent := ai.NewAgent(
43 languageModel,
44 ai.WithSystemPrompt("You are a helpful assistant"),
45 )
46 result, err := agent.Generate(t.Context(), ai.AgentCall{
47 Prompt: "Say hi in Portuguese",
48 ProviderOptions: pair.providerOptions,
49 MaxOutputTokens: ai.IntOption(4000),
50 })
51 require.NoError(t, err, "failed to generate")
52 checkResult(t, result)
53 })
54 t.Run("simple streaming "+pair.name, func(t *testing.T) {
55 r := newRecorder(t)
56
57 languageModel, err := pair.builder(r)
58 require.NoError(t, err, "failed to build language model")
59
60 agent := ai.NewAgent(
61 languageModel,
62 ai.WithSystemPrompt("You are a helpful assistant"),
63 )
64 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
65 Prompt: "Say hi in Portuguese",
66 ProviderOptions: pair.providerOptions,
67 MaxOutputTokens: ai.IntOption(4000),
68 })
69 require.NoError(t, err, "failed to generate")
70 checkResult(t, result)
71 })
72}
73
74func testTool(t *testing.T, pair builderPair) {
75 type WeatherInput struct {
76 Location string `json:"location" description:"the city"`
77 }
78
79 weatherTool := ai.NewAgentTool(
80 "weather",
81 "Get weather information for a location",
82 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
83 return ai.NewTextResponse("40 C"), nil
84 },
85 )
86 checkResult := func(t *testing.T, result *ai.AgentResult) {
87 require.Len(t, result.Steps, 2)
88
89 var toolCalls []ai.ToolCallContent
90 for _, content := range result.Steps[0].Content {
91 if content.GetType() == ai.ContentTypeToolCall {
92 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
93 }
94 }
95 require.Len(t, toolCalls, 1)
96 require.Equal(t, toolCalls[0].ToolName, "weather")
97
98 want1 := "Florence"
99 want2 := "40"
100 got := result.Response.Content.Text()
101 require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
102 }
103
104 t.Run("tool "+pair.name, func(t *testing.T) {
105 r := newRecorder(t)
106
107 languageModel, err := pair.builder(r)
108 require.NoError(t, err, "failed to build language model")
109
110 agent := ai.NewAgent(
111 languageModel,
112 ai.WithSystemPrompt("You are a helpful assistant"),
113 ai.WithTools(weatherTool),
114 )
115 result, err := agent.Generate(t.Context(), ai.AgentCall{
116 Prompt: "What's the weather in Florence,Italy?",
117 ProviderOptions: pair.providerOptions,
118 MaxOutputTokens: ai.IntOption(4000),
119 })
120 require.NoError(t, err, "failed to generate")
121 checkResult(t, result)
122 })
123 t.Run("tool streaming "+pair.name, func(t *testing.T) {
124 r := newRecorder(t)
125
126 languageModel, err := pair.builder(r)
127 require.NoError(t, err, "failed to build language model")
128
129 agent := ai.NewAgent(
130 languageModel,
131 ai.WithSystemPrompt("You are a helpful assistant"),
132 ai.WithTools(weatherTool),
133 )
134 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
135 Prompt: "What's the weather in Florence,Italy?",
136 ProviderOptions: pair.providerOptions,
137 MaxOutputTokens: ai.IntOption(4000),
138 })
139 require.NoError(t, err, "failed to generate")
140 checkResult(t, result)
141 })
142}
143
144func testMultiTool(t *testing.T, pair builderPair) {
145 type WeatherInput struct {
146 Location string `json:"location" description:"the city"`
147 }
148
149 type CalculatorInput struct {
150 A int `json:"a" description:"first number"`
151 B int `json:"b" description:"second number"`
152 }
153
154 addTool := ai.NewAgentTool(
155 "add",
156 "Add two numbers",
157 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
158 result := input.A + input.B
159 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
160 },
161 )
162 multiplyTool := ai.NewAgentTool(
163 "multiply",
164 "Multiply two numbers",
165 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
166 result := input.A * input.B
167 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
168 },
169 )
170 checkResult := func(t *testing.T, result *ai.AgentResult) {
171 require.Len(t, result.Steps, 2)
172
173 var toolCalls []ai.ToolCallContent
174 for _, content := range result.Steps[0].Content {
175 if content.GetType() == ai.ContentTypeToolCall {
176 toolCalls = append(toolCalls, content.(ai.ToolCallContent))
177 }
178 }
179 require.Len(t, toolCalls, 2)
180
181 finalText := result.Response.Content.Text()
182 require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
183 require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
184 }
185
186 t.Run("multi tool "+pair.name, func(t *testing.T) {
187 r := newRecorder(t)
188
189 languageModel, err := pair.builder(r)
190 require.NoError(t, err, "failed to build language model")
191
192 agent := ai.NewAgent(
193 languageModel,
194 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
195 ai.WithTools(addTool),
196 ai.WithTools(multiplyTool),
197 )
198 result, err := agent.Generate(t.Context(), ai.AgentCall{
199 Prompt: "Add and multiply the number 2 and 3",
200 ProviderOptions: pair.providerOptions,
201 MaxOutputTokens: ai.IntOption(4000),
202 })
203 require.NoError(t, err, "failed to generate")
204 checkResult(t, result)
205 })
206 t.Run("multi tool streaming "+pair.name, func(t *testing.T) {
207 r := newRecorder(t)
208
209 languageModel, err := pair.builder(r)
210 require.NoError(t, err, "failed to build language model")
211
212 agent := ai.NewAgent(
213 languageModel,
214 ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
215 ai.WithTools(addTool),
216 ai.WithTools(multiplyTool),
217 )
218 result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
219 Prompt: "Add and multiply the number 2 and 3",
220 ProviderOptions: pair.providerOptions,
221 MaxOutputTokens: ai.IntOption(4000),
222 })
223 require.NoError(t, err, "failed to generate")
224 checkResult(t, result)
225 })
226}