1package providertests
2
3import (
4 "cmp"
5 "fmt"
6 "net/http"
7 "os"
8 "testing"
9
10 "charm.land/fantasy"
11 "charm.land/fantasy/providers/google"
12 "github.com/stretchr/testify/require"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16var geminiTestModels = []testModel{
17 {"gemini-2.5-flash", "gemini-2.5-flash", true},
18 {"gemini-2.5-pro", "gemini-2.5-pro", true},
19}
20
21var vertexTestModels = []testModel{
22 {"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
23 {"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
24 {"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
25}
26
27func TestGoogleCommon(t *testing.T) {
28 var pairs []builderPair
29 for _, m := range geminiTestModels {
30 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
31 }
32 for _, m := range vertexTestModels {
33 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
34 }
35 testCommon(t, pairs)
36}
37
38func TestGoogleThinking(t *testing.T) {
39 opts := fantasy.ProviderOptions{
40 google.Name: &google.ProviderOptions{
41 ThinkingConfig: &google.ThinkingConfig{
42 ThinkingBudget: fantasy.Opt(int64(100)),
43 IncludeThoughts: fantasy.Opt(true),
44 },
45 },
46 }
47
48 var pairs []builderPair
49 for _, m := range geminiTestModels {
50 if !m.reasoning {
51 continue
52 }
53 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
54 }
55 testThinking(t, pairs, testGoogleThinking)
56}
57
58func TestGoogleObjectGeneration(t *testing.T) {
59 var pairs []builderPair
60 for _, m := range geminiTestModels {
61 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
62 }
63 testObjectGeneration(t, pairs)
64}
65
66func TestGoogleVertexObjectGeneration(t *testing.T) {
67 var pairs []builderPair
68 for _, m := range vertexTestModels {
69 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
70 }
71 testObjectGeneration(t, pairs)
72}
73
74func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
75 reasoningContentCount := 0
76 // Test if we got the signature
77 for _, step := range result.Steps {
78 for _, msg := range step.Messages {
79 for _, content := range msg.Content {
80 if content.GetType() == fantasy.ContentTypeReasoning {
81 reasoningContentCount += 1
82 }
83 }
84 }
85 }
86 require.Greater(t, reasoningContentCount, 0)
87}
88
89func generateIDMock() google.ToolCallIDFunc {
90 id := 0
91 return func() string {
92 id++
93 return fmt.Sprintf("%d", id)
94 }
95}
96
97func geminiBuilder(model string) builderFunc {
98 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
99 provider, err := google.New(
100 google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
101 google.WithHTTPClient(&http.Client{Transport: r}),
102 google.WithToolCallIDFunc(generateIDMock()),
103 )
104 if err != nil {
105 return nil, err
106 }
107 return provider.LanguageModel(t.Context(), model)
108 }
109}
110
111func vertexBuilder(model string) builderFunc {
112 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
113 provider, err := google.New(
114 google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
115 google.WithHTTPClient(&http.Client{Transport: r}),
116 google.WithSkipAuth(!r.IsRecording()),
117 google.WithToolCallIDFunc(generateIDMock()),
118 )
119 if err != nil {
120 return nil, err
121 }
122 return provider.LanguageModel(t.Context(), model)
123 }
124}