1package providertests
2
3import (
4 "cmp"
5 "fmt"
6 "net/http"
7 "os"
8 "strings"
9 "testing"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/google"
13 "charm.land/x/vcr"
14 "github.com/stretchr/testify/require"
15)
16
17var geminiTestModels = []testModel{
18 {"gemini-3-pro-preview", "gemini-3-pro-preview", true},
19 {"gemini-2.5-flash", "gemini-2.5-flash", true},
20 {"gemini-2.5-pro", "gemini-2.5-pro", true},
21}
22
23var vertexTestModels = []testModel{
24 {"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
25 {"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
26 {"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
27}
28
29func TestGoogleCommon(t *testing.T) {
30 var pairs []builderPair
31 for _, m := range geminiTestModels {
32 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
33 }
34 for _, m := range vertexTestModels {
35 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
36 }
37 testCommon(t, pairs)
38}
39
40func TestGoogleThinking(t *testing.T) {
41 gemini2Opts := fantasy.ProviderOptions{
42 google.Name: &google.ProviderOptions{
43 ThinkingConfig: &google.ThinkingConfig{
44 ThinkingBudget: fantasy.Opt(int64(100)),
45 IncludeThoughts: fantasy.Opt(true),
46 },
47 },
48 }
49 gemini3Opts := fantasy.ProviderOptions{
50 google.Name: &google.ProviderOptions{
51 ThinkingConfig: &google.ThinkingConfig{
52 ThinkingLevel: fantasy.Opt(google.ThinkingLevelHigh),
53 IncludeThoughts: fantasy.Opt(true),
54 },
55 },
56 }
57
58 var pairs []builderPair
59 for _, m := range geminiTestModels {
60 if !m.reasoning {
61 continue
62 }
63 opts := gemini3Opts
64 if strings.HasPrefix(m.model, "gemini-2") {
65 opts = gemini2Opts
66 }
67 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
68 }
69 testThinking(t, pairs, testGoogleThinking)
70}
71
72func TestGoogleObjectGeneration(t *testing.T) {
73 var pairs []builderPair
74 for _, m := range geminiTestModels {
75 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
76 }
77 testObjectGeneration(t, pairs)
78}
79
80func TestGoogleVertexObjectGeneration(t *testing.T) {
81 var pairs []builderPair
82 for _, m := range vertexTestModels {
83 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
84 }
85 testObjectGeneration(t, pairs)
86}
87
88func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
89 reasoningContentCount := 0
90 // Test if we got the signature
91 for _, step := range result.Steps {
92 for _, msg := range step.Messages {
93 for _, content := range msg.Content {
94 if content.GetType() == fantasy.ContentTypeReasoning {
95 reasoningContentCount += 1
96 }
97 }
98 }
99 }
100 require.Greater(t, reasoningContentCount, 0)
101}
102
103func generateIDMock() google.ToolCallIDFunc {
104 id := 0
105 return func() string {
106 id++
107 return fmt.Sprintf("%d", id)
108 }
109}
110
111func geminiBuilder(model string) builderFunc {
112 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
113 provider, err := google.New(
114 google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
115 google.WithHTTPClient(&http.Client{Transport: r}),
116 google.WithToolCallIDFunc(generateIDMock()),
117 )
118 if err != nil {
119 return nil, err
120 }
121 return provider.LanguageModel(t.Context(), model)
122 }
123}
124
125func vertexBuilder(model string) builderFunc {
126 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
127 provider, err := google.New(
128 google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
129 google.WithHTTPClient(&http.Client{Transport: r}),
130 google.WithSkipAuth(!r.IsRecording()),
131 google.WithToolCallIDFunc(generateIDMock()),
132 )
133 if err != nil {
134 return nil, err
135 }
136 return provider.LanguageModel(t.Context(), model)
137 }
138}