1package providertests
2
3import (
4 "cmp"
5 "fmt"
6 "net/http"
7 "os"
8 "testing"
9
10 "charm.land/fantasy"
11 "charm.land/fantasy/providers/google"
12 "charm.land/x/vcr"
13 "github.com/stretchr/testify/require"
14)
15
16var geminiTestModels = []testModel{
17 {"gemini-3-pro-preview", "gemini-3-pro-preview", true},
18 {"gemini-2.5-flash", "gemini-2.5-flash", true},
19 {"gemini-2.5-pro", "gemini-2.5-pro", true},
20}
21
22var vertexTestModels = []testModel{
23 {"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
24 {"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
25 {"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
26}
27
28func TestGoogleCommon(t *testing.T) {
29 var pairs []builderPair
30 for _, m := range geminiTestModels {
31 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
32 }
33 for _, m := range vertexTestModels {
34 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
35 }
36 testCommon(t, pairs)
37}
38
39func TestGoogleThinking(t *testing.T) {
40 opts := fantasy.ProviderOptions{
41 google.Name: &google.ProviderOptions{
42 ThinkingConfig: &google.ThinkingConfig{
43 ThinkingBudget: fantasy.Opt(int64(100)),
44 IncludeThoughts: fantasy.Opt(true),
45 },
46 },
47 }
48
49 var pairs []builderPair
50 for _, m := range geminiTestModels {
51 if !m.reasoning {
52 continue
53 }
54 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
55 }
56 testThinking(t, pairs, testGoogleThinking)
57}
58
59func TestGoogleObjectGeneration(t *testing.T) {
60 var pairs []builderPair
61 for _, m := range geminiTestModels {
62 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
63 }
64 testObjectGeneration(t, pairs)
65}
66
67func TestGoogleVertexObjectGeneration(t *testing.T) {
68 var pairs []builderPair
69 for _, m := range vertexTestModels {
70 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
71 }
72 testObjectGeneration(t, pairs)
73}
74
75func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
76 reasoningContentCount := 0
77 // Test if we got the signature
78 for _, step := range result.Steps {
79 for _, msg := range step.Messages {
80 for _, content := range msg.Content {
81 if content.GetType() == fantasy.ContentTypeReasoning {
82 reasoningContentCount += 1
83 }
84 }
85 }
86 }
87 require.Greater(t, reasoningContentCount, 0)
88}
89
90func generateIDMock() google.ToolCallIDFunc {
91 id := 0
92 return func() string {
93 id++
94 return fmt.Sprintf("%d", id)
95 }
96}
97
98func geminiBuilder(model string) builderFunc {
99 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
100 provider, err := google.New(
101 google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
102 google.WithHTTPClient(&http.Client{Transport: r}),
103 google.WithToolCallIDFunc(generateIDMock()),
104 )
105 if err != nil {
106 return nil, err
107 }
108 return provider.LanguageModel(t.Context(), model)
109 }
110}
111
112func vertexBuilder(model string) builderFunc {
113 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
114 provider, err := google.New(
115 google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
116 google.WithHTTPClient(&http.Client{Transport: r}),
117 google.WithSkipAuth(!r.IsRecording()),
118 google.WithToolCallIDFunc(generateIDMock()),
119 )
120 if err != nil {
121 return nil, err
122 }
123 return provider.LanguageModel(t.Context(), model)
124 }
125}