1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/ai/ai"
10 _ "github.com/joho/godotenv/autoload"
11)
12
13func TestSimple(t *testing.T) {
14 for _, pair := range languageModelBuilders {
15 t.Run(pair.name, func(t *testing.T) {
16 r := newRecorder(t)
17
18 languageModel, err := pair.builder(r)
19 if err != nil {
20 t.Fatalf("failed to build language model: %v", err)
21 }
22
23 agent := ai.NewAgent(
24 languageModel,
25 ai.WithSystemPrompt("You are a helpful assistant"),
26 )
27 result, err := agent.Generate(t.Context(), ai.AgentCall{
28 Prompt: "Say hi in Portuguese",
29 })
30 if err != nil {
31 t.Fatalf("failed to generate: %v", err)
32 }
33
34 option1 := "Oi"
35 option2 := "Olá"
36 got := result.Response.Content.Text()
37 if !strings.Contains(got, option1) && !strings.Contains(got, option2) {
38 t.Fatalf("unexpected response: got %q, want %q or %q", got, option1, option2)
39 }
40 })
41 }
42}
43
44func TestTool(t *testing.T) {
45 for _, pair := range languageModelBuilders {
46 t.Run(pair.name, func(t *testing.T) {
47 r := newRecorder(t)
48
49 languageModel, err := pair.builder(r)
50 if err != nil {
51 t.Fatalf("failed to build language model: %v", err)
52 }
53
54 type WeatherInput struct {
55 Location string `json:"location" description:"the city"`
56 }
57
58 weatherTool := ai.NewAgentTool(
59 "weather",
60 "Get weather information for a location",
61 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
62 return ai.NewTextResponse("40 C"), nil
63 },
64 )
65
66 agent := ai.NewAgent(
67 languageModel,
68 ai.WithSystemPrompt("You are a helpful assistant"),
69 ai.WithTools(weatherTool),
70 )
71 result, err := agent.Generate(t.Context(), ai.AgentCall{
72 Prompt: "What's the weather in Florence?",
73 })
74 if err != nil {
75 t.Fatalf("failed to generate: %v", err)
76 }
77
78 want1 := "Florence"
79 want2 := "40"
80 got := result.Response.Content.Text()
81 if !strings.Contains(got, want1) || !strings.Contains(got, want2) {
82 t.Fatalf("unexpected response: got %q, want %q %q", got, want1, want2)
83 }
84 })
85 }
86}
87
88func TestStream(t *testing.T) {
89 for _, pair := range languageModelBuilders {
90 t.Run(pair.name, func(t *testing.T) {
91 r := newRecorder(t)
92
93 languageModel, err := pair.builder(r)
94 if err != nil {
95 t.Fatalf("failed to build language model: %v", err)
96 }
97
98 agent := ai.NewAgent(
99 languageModel,
100 ai.WithSystemPrompt("You are a helpful assistant"),
101 )
102
103 var collectedText strings.Builder
104 textDeltaCount := 0
105 stepCount := 0
106
107 streamCall := ai.AgentStreamCall{
108 Prompt: "Count from 1 to 3 in Spanish",
109 OnTextDelta: func(id, text string) error {
110 textDeltaCount++
111 collectedText.WriteString(text)
112 return nil
113 },
114 OnStepFinish: func(step ai.StepResult) error {
115 stepCount++
116 return nil
117 },
118 }
119
120 result, err := agent.Stream(t.Context(), streamCall)
121 if err != nil {
122 t.Fatalf("failed to stream: %v", err)
123 }
124
125 finalText := result.Response.Content.Text()
126 if finalText == "" {
127 t.Fatal("expected non-empty response")
128 }
129
130 if !strings.Contains(strings.ToLower(finalText), "uno") ||
131 !strings.Contains(strings.ToLower(finalText), "dos") ||
132 !strings.Contains(strings.ToLower(finalText), "tres") {
133 t.Fatalf("unexpected response: %q", finalText)
134 }
135
136 if textDeltaCount == 0 {
137 t.Fatal("expected at least one text delta callback")
138 }
139
140 if stepCount == 0 {
141 t.Fatal("expected at least one step finish callback")
142 }
143
144 if collectedText.String() == "" {
145 t.Fatal("expected collected text from deltas to be non-empty")
146 }
147 })
148 }
149}
150
151func TestStreamWithTools(t *testing.T) {
152 for _, pair := range languageModelBuilders {
153 t.Run(pair.name, func(t *testing.T) {
154 r := newRecorder(t)
155
156 languageModel, err := pair.builder(r)
157 if err != nil {
158 t.Fatalf("failed to build language model: %v", err)
159 }
160
161 type CalculatorInput struct {
162 A int `json:"a" description:"first number"`
163 B int `json:"b" description:"second number"`
164 }
165
166 calculatorTool := ai.NewAgentTool(
167 "add",
168 "Add two numbers",
169 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
170 result := input.A + input.B
171 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
172 },
173 )
174
175 agent := ai.NewAgent(
176 languageModel,
177 ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
178 ai.WithTools(calculatorTool),
179 )
180
181 toolCallCount := 0
182 toolResultCount := 0
183 var collectedText strings.Builder
184
185 streamCall := ai.AgentStreamCall{
186 Prompt: "What is 15 + 27?",
187 OnTextDelta: func(id, text string) error {
188 collectedText.WriteString(text)
189 return nil
190 },
191 OnToolCall: func(toolCall ai.ToolCallContent) error {
192 toolCallCount++
193 if toolCall.ToolName != "add" {
194 t.Errorf("unexpected tool name: %s", toolCall.ToolName)
195 }
196 return nil
197 },
198 OnToolResult: func(result ai.ToolResultContent) error {
199 toolResultCount++
200 return nil
201 },
202 }
203
204 result, err := agent.Stream(t.Context(), streamCall)
205 if err != nil {
206 t.Fatalf("failed to stream: %v", err)
207 }
208
209 finalText := result.Response.Content.Text()
210 if !strings.Contains(finalText, "42") {
211 t.Fatalf("expected response to contain '42', got: %q", finalText)
212 }
213
214 if toolCallCount == 0 {
215 t.Fatal("expected at least one tool call")
216 }
217
218 if toolResultCount == 0 {
219 t.Fatal("expected at least one tool result")
220 }
221 })
222 }
223}