provider_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"strconv"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/ai/ai"
 10	_ "github.com/joho/godotenv/autoload"
 11)
 12
 13func TestSimple(t *testing.T) {
 14	for _, pair := range languageModelBuilders {
 15		t.Run(pair.name, func(t *testing.T) {
 16			r := newRecorder(t)
 17
 18			languageModel, err := pair.builder(r)
 19			if err != nil {
 20				t.Fatalf("failed to build language model: %v", err)
 21			}
 22
 23			agent := ai.NewAgent(
 24				languageModel,
 25				ai.WithSystemPrompt("You are a helpful assistant"),
 26			)
 27			result, err := agent.Generate(t.Context(), ai.AgentCall{
 28				Prompt: "Say hi in Portuguese",
 29			})
 30			if err != nil {
 31				t.Fatalf("failed to generate: %v", err)
 32			}
 33
 34			option1 := "Oi"
 35			option2 := "Olá"
 36			got := result.Response.Content.Text()
 37			if !strings.Contains(got, option1) && !strings.Contains(got, option2) {
 38				t.Fatalf("unexpected response: got %q, want %q or %q", got, option1, option2)
 39			}
 40		})
 41	}
 42}
 43
 44func TestTool(t *testing.T) {
 45	for _, pair := range languageModelBuilders {
 46		t.Run(pair.name, func(t *testing.T) {
 47			r := newRecorder(t)
 48
 49			languageModel, err := pair.builder(r)
 50			if err != nil {
 51				t.Fatalf("failed to build language model: %v", err)
 52			}
 53
 54			type WeatherInput struct {
 55				Location string `json:"location" description:"the city"`
 56			}
 57
 58			weatherTool := ai.NewAgentTool(
 59				"weather",
 60				"Get weather information for a location",
 61				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 62					return ai.NewTextResponse("40 C"), nil
 63				},
 64			)
 65
 66			agent := ai.NewAgent(
 67				languageModel,
 68				ai.WithSystemPrompt("You are a helpful assistant"),
 69				ai.WithTools(weatherTool),
 70			)
 71			result, err := agent.Generate(t.Context(), ai.AgentCall{
 72				Prompt: "What's the weather in Florence?",
 73			})
 74			if err != nil {
 75				t.Fatalf("failed to generate: %v", err)
 76			}
 77
 78			want1 := "Florence"
 79			want2 := "40"
 80			got := result.Response.Content.Text()
 81			if !strings.Contains(got, want1) || !strings.Contains(got, want2) {
 82				t.Fatalf("unexpected response: got %q, want %q %q", got, want1, want2)
 83			}
 84		})
 85	}
 86}
 87
 88func TestStream(t *testing.T) {
 89	for _, pair := range languageModelBuilders {
 90		t.Run(pair.name, func(t *testing.T) {
 91			r := newRecorder(t)
 92
 93			languageModel, err := pair.builder(r)
 94			if err != nil {
 95				t.Fatalf("failed to build language model: %v", err)
 96			}
 97
 98			agent := ai.NewAgent(
 99				languageModel,
100				ai.WithSystemPrompt("You are a helpful assistant"),
101			)
102
103			var collectedText strings.Builder
104			textDeltaCount := 0
105			stepCount := 0
106
107			streamCall := ai.AgentStreamCall{
108				Prompt: "Count from 1 to 3 in Spanish",
109				OnTextDelta: func(id, text string) error {
110					textDeltaCount++
111					collectedText.WriteString(text)
112					return nil
113				},
114				OnStepFinish: func(step ai.StepResult) error {
115					stepCount++
116					return nil
117				},
118			}
119
120			result, err := agent.Stream(t.Context(), streamCall)
121			if err != nil {
122				t.Fatalf("failed to stream: %v", err)
123			}
124
125			finalText := result.Response.Content.Text()
126			if finalText == "" {
127				t.Fatal("expected non-empty response")
128			}
129
130			if !strings.Contains(strings.ToLower(finalText), "uno") ||
131				!strings.Contains(strings.ToLower(finalText), "dos") ||
132				!strings.Contains(strings.ToLower(finalText), "tres") {
133				t.Fatalf("unexpected response: %q", finalText)
134			}
135
136			if textDeltaCount == 0 {
137				t.Fatal("expected at least one text delta callback")
138			}
139
140			if stepCount == 0 {
141				t.Fatal("expected at least one step finish callback")
142			}
143
144			if collectedText.String() == "" {
145				t.Fatal("expected collected text from deltas to be non-empty")
146			}
147		})
148	}
149}
150
151func TestStreamWithTools(t *testing.T) {
152	for _, pair := range languageModelBuilders {
153		t.Run(pair.name, func(t *testing.T) {
154			r := newRecorder(t)
155
156			languageModel, err := pair.builder(r)
157			if err != nil {
158				t.Fatalf("failed to build language model: %v", err)
159			}
160
161			type CalculatorInput struct {
162				A int `json:"a" description:"first number"`
163				B int `json:"b" description:"second number"`
164			}
165
166			calculatorTool := ai.NewAgentTool(
167				"add",
168				"Add two numbers",
169				func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
170					result := input.A + input.B
171					return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
172				},
173			)
174
175			agent := ai.NewAgent(
176				languageModel,
177				ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
178				ai.WithTools(calculatorTool),
179			)
180
181			toolCallCount := 0
182			toolResultCount := 0
183			var collectedText strings.Builder
184
185			streamCall := ai.AgentStreamCall{
186				Prompt: "What is 15 + 27?",
187				OnTextDelta: func(id, text string) error {
188					collectedText.WriteString(text)
189					return nil
190				},
191				OnToolCall: func(toolCall ai.ToolCallContent) error {
192					toolCallCount++
193					if toolCall.ToolName != "add" {
194						t.Errorf("unexpected tool name: %s", toolCall.ToolName)
195					}
196					return nil
197				},
198				OnToolResult: func(result ai.ToolResultContent) error {
199					toolResultCount++
200					return nil
201				},
202			}
203
204			result, err := agent.Stream(t.Context(), streamCall)
205			if err != nil {
206				t.Fatalf("failed to stream: %v", err)
207			}
208
209			finalText := result.Response.Content.Text()
210			if !strings.Contains(finalText, "42") {
211				t.Fatalf("expected response to contain '42', got: %q", finalText)
212			}
213
214			if toolCallCount == 0 {
215				t.Fatal("expected at least one tool call")
216			}
217
218			if toolResultCount == 0 {
219				t.Fatal("expected at least one tool result")
220			}
221		})
222	}
223}