provider_test.go

 1package providertests
 2
 3import (
 4	"context"
 5	"strings"
 6	"testing"
 7
 8	"github.com/charmbracelet/ai/ai"
 9	_ "github.com/joho/godotenv/autoload"
10)
11
12func TestSimple(t *testing.T) {
13	for _, pair := range languageModelBuilders {
14		t.Run(pair.name, func(t *testing.T) {
15			r := newRecorder(t)
16
17			languageModel, err := pair.builder(r)
18			if err != nil {
19				t.Fatalf("failed to build language model: %v", err)
20			}
21
22			agent := ai.NewAgent(
23				languageModel,
24				ai.WithSystemPrompt("You are a helpful assistant"),
25			)
26			result, err := agent.Generate(t.Context(), ai.AgentCall{
27				Prompt: "Say hi in Portuguese",
28			})
29			if err != nil {
30				t.Fatalf("failed to generate: %v", err)
31			}
32
33			want := "Olá"
34			got := result.Response.Content.Text()
35			if !strings.Contains(got, want) {
36				t.Fatalf("unexpected response: got %q, want %q", got, want)
37			}
38		})
39	}
40}
41
42func TestTool(t *testing.T) {
43	for _, pair := range languageModelBuilders {
44		t.Run(pair.name, func(t *testing.T) {
45			r := newRecorder(t)
46
47			languageModel, err := pair.builder(r)
48			if err != nil {
49				t.Fatalf("failed to build language model: %v", err)
50			}
51
52			type WeatherInput struct {
53				Location string `json:"location" description:"the city"`
54			}
55
56			weatherTool := ai.NewAgentTool(
57				"weather",
58				"Get weather information for a location",
59				func(ctx context.Context, input WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
60					return ai.NewTextResponse("40 C"), nil
61				},
62			)
63
64			agent := ai.NewAgent(
65				languageModel,
66				ai.WithSystemPrompt("You are a helpful assistant"),
67				ai.WithTools(weatherTool),
68			)
69			result, err := agent.Generate(t.Context(), ai.AgentCall{
70				Prompt: "What's the weather in Florence?",
71			})
72			if err != nil {
73				t.Fatalf("failed to generate: %v", err)
74			}
75
76			want1 := "Florence"
77			want2 := "40"
78			got := result.Response.Content.Text()
79			if !strings.Contains(got, want1) || !strings.Contains(got, want2) {
80				t.Fatalf("unexpected response: got %q, want %q %q", got, want1, want2)
81			}
82		})
83	}
84}