1package providertests
2
3import (
4 "cmp"
5 "net/http"
6 "os"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/charmbracelet/fantasy/google"
11 "github.com/stretchr/testify/require"
12 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
15var geminiTestModels = []testModel{
16 {"gemini-2.5-flash", "gemini-2.5-flash", true},
17 {"gemini-2.5-pro", "gemini-2.5-pro", true},
18}
19
20var vertexTestModels = []testModel{
21 {"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
22 {"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
23 {"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
24}
25
26func TestGoogleCommon(t *testing.T) {
27 var pairs []builderPair
28 for _, m := range geminiTestModels {
29 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil})
30 }
31 for _, m := range vertexTestModels {
32 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil})
33 }
34 testCommon(t, pairs)
35}
36
37func TestGoogleThinking(t *testing.T) {
38 opts := ai.ProviderOptions{
39 google.Name: &google.ProviderOptions{
40 ThinkingConfig: &google.ThinkingConfig{
41 ThinkingBudget: ai.IntOption(100),
42 IncludeThoughts: ai.BoolOption(true),
43 },
44 },
45 }
46
47 var pairs []builderPair
48 for _, m := range geminiTestModels {
49 if !m.reasoning {
50 continue
51 }
52 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts})
53 }
54 testThinking(t, pairs, testGoogleThinking)
55}
56
57func testGoogleThinking(t *testing.T, result *ai.AgentResult) {
58 reasoningContentCount := 0
59 // Test if we got the signature
60 for _, step := range result.Steps {
61 for _, msg := range step.Messages {
62 for _, content := range msg.Content {
63 if content.GetType() == ai.ContentTypeReasoning {
64 reasoningContentCount += 1
65 }
66 }
67 }
68 }
69 require.Greater(t, reasoningContentCount, 0)
70}
71
72func geminiBuilder(model string) builderFunc {
73 return func(r *recorder.Recorder) (ai.LanguageModel, error) {
74 provider := google.New(
75 google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
76 google.WithHTTPClient(&http.Client{Transport: r}),
77 )
78 return provider.LanguageModel(model)
79 }
80}
81
82func vertexBuilder(model string) builderFunc {
83 return func(r *recorder.Recorder) (ai.LanguageModel, error) {
84 provider := google.New(
85 google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
86 google.WithHTTPClient(&http.Client{Transport: r}),
87 google.WithSkipAuth(!r.IsRecording()),
88 )
89 return provider.LanguageModel(model)
90 }
91}