provider_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/ai/ai"
 10	_ "github.com/joho/godotenv/autoload"
 11)
 12
 13func TestSimple(t *testing.T) {
 14	for _, pair := range languageModelBuilders {
 15		t.Run(pair.name, func(t *testing.T) {
 16			r := newRecorder(t)
 17
 18			languageModel, err := pair.builder(r)
 19			if err != nil {
 20				t.Fatalf("failed to build language model: %v", err)
 21			}
 22
 23			agent := ai.NewAgent(
 24				languageModel,
 25				ai.WithSystemPrompt("You are a helpful assistant"),
 26			)
 27			result, err := agent.Generate(t.Context(), ai.AgentCall{
 28				Prompt: "Say hi in Portuguese",
 29			})
 30			if err != nil {
 31				t.Fatalf("failed to generate: %v", err)
 32			}
 33
 34			want := "Olá"
 35			got := result.Response.Content.Text()
 36			if !strings.Contains(got, want) {
 37				t.Fatalf("unexpected response: got %q, want %q", got, want)
 38			}
 39		})
 40	}
 41}
 42
 43func TestTool(t *testing.T) {
 44	for _, pair := range languageModelBuilders {
 45		t.Run(pair.name, func(t *testing.T) {
 46			r := newRecorder(t)
 47
 48			languageModel, err := pair.builder(r)
 49			if err != nil {
 50				t.Fatalf("failed to build language model: %v", err)
 51			}
 52
 53			type WeatherInput struct {
 54				Location string `json:"location" description:"the city"`
 55			}
 56
 57			weatherTool := ai.NewAgentTool(
 58				"weather",
 59				"Get weather information for a location",
 60				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 61					return ai.NewTextResponse("40 C"), nil
 62				},
 63			)
 64
 65			agent := ai.NewAgent(
 66				languageModel,
 67				ai.WithSystemPrompt("You are a helpful assistant"),
 68				ai.WithTools(weatherTool),
 69			)
 70			result, err := agent.Generate(t.Context(), ai.AgentCall{
 71				Prompt: "What's the weather in Florence?",
 72			})
 73			if err != nil {
 74				t.Fatalf("failed to generate: %v", err)
 75			}
 76
 77			want1 := "Florence"
 78			want2 := "40"
 79			got := result.Response.Content.Text()
 80			if !strings.Contains(got, want1) || !strings.Contains(got, want2) {
 81				t.Fatalf("unexpected response: got %q, want %q %q", got, want1, want2)
 82			}
 83		})
 84	}
 85}
 86
 87func TestStream(t *testing.T) {
 88	for _, pair := range languageModelBuilders {
 89		t.Run(pair.name, func(t *testing.T) {
 90			r := newRecorder(t)
 91
 92			languageModel, err := pair.builder(r)
 93			if err != nil {
 94				t.Fatalf("failed to build language model: %v", err)
 95			}
 96
 97			agent := ai.NewAgent(
 98				languageModel,
 99				ai.WithSystemPrompt("You are a helpful assistant"),
100			)
101
102			var collectedText strings.Builder
103			textDeltaCount := 0
104			stepCount := 0
105
106			streamCall := ai.AgentStreamCall{
107				Prompt: "Count from 1 to 3 in Spanish",
108				OnTextDelta: func(id, text string) error {
109					textDeltaCount++
110					collectedText.WriteString(text)
111					return nil
112				},
113				OnStepFinish: func(step ai.StepResult) error {
114					stepCount++
115					return nil
116				},
117			}
118
119			result, err := agent.Stream(t.Context(), streamCall)
120			if err != nil {
121				t.Fatalf("failed to stream: %v", err)
122			}
123
124			finalText := result.Response.Content.Text()
125			if finalText == "" {
126				t.Fatal("expected non-empty response")
127			}
128
129			if !strings.Contains(strings.ToLower(finalText), "uno") ||
130				!strings.Contains(strings.ToLower(finalText), "dos") ||
131				!strings.Contains(strings.ToLower(finalText), "tres") {
132				t.Fatalf("unexpected response: %q", finalText)
133			}
134
135			if textDeltaCount == 0 {
136				t.Fatal("expected at least one text delta callback")
137			}
138
139			if stepCount == 0 {
140				t.Fatal("expected at least one step finish callback")
141			}
142
143			if collectedText.String() == "" {
144				t.Fatal("expected collected text from deltas to be non-empty")
145			}
146		})
147	}
148}
149
150func TestStreamWithTools(t *testing.T) {
151	for _, pair := range languageModelBuilders {
152		t.Run(pair.name, func(t *testing.T) {
153			r := newRecorder(t)
154
155			languageModel, err := pair.builder(r)
156			if err != nil {
157				t.Fatalf("failed to build language model: %v", err)
158			}
159
160			type CalculatorInput struct {
161				A int `json:"a" description:"first number"`
162				B int `json:"b" description:"second number"`
163			}
164
165			calculatorTool := ai.NewAgentTool(
166				"add",
167				"Add two numbers",
168				func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
169					result := input.A + input.B
170					return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
171				},
172			)
173
174			agent := ai.NewAgent(
175				languageModel,
176				ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
177				ai.WithTools(calculatorTool),
178			)
179
180			toolCallCount := 0
181			toolResultCount := 0
182			var collectedText strings.Builder
183
184			streamCall := ai.AgentStreamCall{
185				Prompt: "What is 15 + 27?",
186				OnTextDelta: func(id, text string) error {
187					collectedText.WriteString(text)
188					return nil
189				},
190				OnToolCall: func(toolCall ai.ToolCallContent) error {
191					toolCallCount++
192					if toolCall.ToolName != "add" {
193						t.Errorf("unexpected tool name: %s", toolCall.ToolName)
194					}
195					return nil
196				},
197				OnToolResult: func(result ai.ToolResultContent) error {
198					toolResultCount++
199					return nil
200				},
201			}
202
203			result, err := agent.Stream(t.Context(), streamCall)
204			if err != nil {
205				t.Fatalf("failed to stream: %v", err)
206			}
207
208			finalText := result.Response.Content.Text()
209			if !strings.Contains(finalText, "42") {
210				t.Fatalf("expected response to contain '42', got: %q", finalText)
211			}
212
213			if toolCallCount == 0 {
214				t.Fatal("expected at least one tool call")
215			}
216
217			if toolResultCount == 0 {
218				t.Fatal("expected at least one tool result")
219			}
220		})
221	}
222}