1package providertests
2
3import (
4 "context"
5 "strconv"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/ai/ai"
10 _ "github.com/joho/godotenv/autoload"
11)
12
13func TestSimple(t *testing.T) {
14 for _, pair := range languageModelBuilders {
15 t.Run(pair.name, func(t *testing.T) {
16 r := newRecorder(t)
17
18 languageModel, err := pair.builder(r)
19 if err != nil {
20 t.Fatalf("failed to build language model: %v", err)
21 }
22
23 agent := ai.NewAgent(
24 languageModel,
25 ai.WithSystemPrompt("You are a helpful assistant"),
26 )
27 result, err := agent.Generate(t.Context(), ai.AgentCall{
28 Prompt: "Say hi in Portuguese",
29 })
30 if err != nil {
31 t.Fatalf("failed to generate: %v", err)
32 }
33
34 want := "Olá"
35 got := result.Response.Content.Text()
36 if !strings.Contains(got, want) {
37 t.Fatalf("unexpected response: got %q, want %q", got, want)
38 }
39 })
40 }
41}
42
43func TestTool(t *testing.T) {
44 for _, pair := range languageModelBuilders {
45 t.Run(pair.name, func(t *testing.T) {
46 r := newRecorder(t)
47
48 languageModel, err := pair.builder(r)
49 if err != nil {
50 t.Fatalf("failed to build language model: %v", err)
51 }
52
53 type WeatherInput struct {
54 Location string `json:"location" description:"the city"`
55 }
56
57 weatherTool := ai.NewAgentTool(
58 "weather",
59 "Get weather information for a location",
60 func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
61 return ai.NewTextResponse("40 C"), nil
62 },
63 )
64
65 agent := ai.NewAgent(
66 languageModel,
67 ai.WithSystemPrompt("You are a helpful assistant"),
68 ai.WithTools(weatherTool),
69 )
70 result, err := agent.Generate(t.Context(), ai.AgentCall{
71 Prompt: "What's the weather in Florence?",
72 })
73 if err != nil {
74 t.Fatalf("failed to generate: %v", err)
75 }
76
77 want1 := "Florence"
78 want2 := "40"
79 got := result.Response.Content.Text()
80 if !strings.Contains(got, want1) || !strings.Contains(got, want2) {
81 t.Fatalf("unexpected response: got %q, want %q %q", got, want1, want2)
82 }
83 })
84 }
85}
86
87func TestStream(t *testing.T) {
88 for _, pair := range languageModelBuilders {
89 t.Run(pair.name, func(t *testing.T) {
90 r := newRecorder(t)
91
92 languageModel, err := pair.builder(r)
93 if err != nil {
94 t.Fatalf("failed to build language model: %v", err)
95 }
96
97 agent := ai.NewAgent(
98 languageModel,
99 ai.WithSystemPrompt("You are a helpful assistant"),
100 )
101
102 var collectedText strings.Builder
103 textDeltaCount := 0
104 stepCount := 0
105
106 streamCall := ai.AgentStreamCall{
107 Prompt: "Count from 1 to 3 in Spanish",
108 OnTextDelta: func(id, text string) error {
109 textDeltaCount++
110 collectedText.WriteString(text)
111 return nil
112 },
113 OnStepFinish: func(step ai.StepResult) error {
114 stepCount++
115 return nil
116 },
117 }
118
119 result, err := agent.Stream(t.Context(), streamCall)
120 if err != nil {
121 t.Fatalf("failed to stream: %v", err)
122 }
123
124 finalText := result.Response.Content.Text()
125 if finalText == "" {
126 t.Fatal("expected non-empty response")
127 }
128
129 if !strings.Contains(strings.ToLower(finalText), "uno") ||
130 !strings.Contains(strings.ToLower(finalText), "dos") ||
131 !strings.Contains(strings.ToLower(finalText), "tres") {
132 t.Fatalf("unexpected response: %q", finalText)
133 }
134
135 if textDeltaCount == 0 {
136 t.Fatal("expected at least one text delta callback")
137 }
138
139 if stepCount == 0 {
140 t.Fatal("expected at least one step finish callback")
141 }
142
143 if collectedText.String() == "" {
144 t.Fatal("expected collected text from deltas to be non-empty")
145 }
146 })
147 }
148}
149
150func TestStreamWithTools(t *testing.T) {
151 for _, pair := range languageModelBuilders {
152 t.Run(pair.name, func(t *testing.T) {
153 r := newRecorder(t)
154
155 languageModel, err := pair.builder(r)
156 if err != nil {
157 t.Fatalf("failed to build language model: %v", err)
158 }
159
160 type CalculatorInput struct {
161 A int `json:"a" description:"first number"`
162 B int `json:"b" description:"second number"`
163 }
164
165 calculatorTool := ai.NewAgentTool(
166 "add",
167 "Add two numbers",
168 func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
169 result := input.A + input.B
170 return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
171 },
172 )
173
174 agent := ai.NewAgent(
175 languageModel,
176 ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
177 ai.WithTools(calculatorTool),
178 )
179
180 toolCallCount := 0
181 toolResultCount := 0
182 var collectedText strings.Builder
183
184 streamCall := ai.AgentStreamCall{
185 Prompt: "What is 15 + 27?",
186 OnTextDelta: func(id, text string) error {
187 collectedText.WriteString(text)
188 return nil
189 },
190 OnToolCall: func(toolCall ai.ToolCallContent) error {
191 toolCallCount++
192 if toolCall.ToolName != "add" {
193 t.Errorf("unexpected tool name: %s", toolCall.ToolName)
194 }
195 return nil
196 },
197 OnToolResult: func(result ai.ToolResultContent) error {
198 toolResultCount++
199 return nil
200 },
201 }
202
203 result, err := agent.Stream(t.Context(), streamCall)
204 if err != nil {
205 t.Fatalf("failed to stream: %v", err)
206 }
207
208 finalText := result.Response.Content.Text()
209 if !strings.Contains(finalText, "42") {
210 t.Fatalf("expected response to contain '42', got: %q", finalText)
211 }
212
213 if toolCallCount == 0 {
214 t.Fatal("expected at least one tool call")
215 }
216
217 if toolResultCount == 0 {
218 t.Fatal("expected at least one tool result")
219 }
220 })
221 }
222}