1package providertests
2
3import (
4 "cmp"
5 "fmt"
6 "net/http"
7 "os"
8 "testing"
9
10 "charm.land/fantasy"
11 "charm.land/fantasy/providers/google"
12 "github.com/stretchr/testify/require"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16var geminiTestModels = []testModel{
17 {"gemini-2.5-flash", "gemini-2.5-flash", true},
18 {"gemini-2.5-pro", "gemini-2.5-pro", true},
19}
20
21var vertexTestModels = []testModel{
22 {"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
23 {"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
24 {"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
25}
26
27func TestGoogleCommon(t *testing.T) {
28 var pairs []builderPair
29 for _, m := range geminiTestModels {
30 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
31 }
32 for _, m := range vertexTestModels {
33 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
34 }
35 testCommon(t, pairs)
36}
37
38func TestGoogleThinking(t *testing.T) {
39 opts := fantasy.ProviderOptions{
40 google.Name: &google.ProviderOptions{
41 ThinkingConfig: &google.ThinkingConfig{
42 ThinkingBudget: fantasy.Opt(int64(100)),
43 IncludeThoughts: fantasy.Opt(true),
44 },
45 },
46 }
47
48 var pairs []builderPair
49 for _, m := range geminiTestModels {
50 if !m.reasoning {
51 continue
52 }
53 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
54 }
55 testThinking(t, pairs, testGoogleThinking)
56}
57
58func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
59 reasoningContentCount := 0
60 // Test if we got the signature
61 for _, step := range result.Steps {
62 for _, msg := range step.Messages {
63 for _, content := range msg.Content {
64 if content.GetType() == fantasy.ContentTypeReasoning {
65 reasoningContentCount += 1
66 }
67 }
68 }
69 }
70 require.Greater(t, reasoningContentCount, 0)
71}
72
73func generateIDMock() google.ToolCallIDFunc {
74 id := 0
75 return func() string {
76 id++
77 return fmt.Sprintf("%d", id)
78 }
79}
80
81func geminiBuilder(model string) builderFunc {
82 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
83 provider, err := google.New(
84 google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
85 google.WithHTTPClient(&http.Client{Transport: r}),
86 google.WithToolCallIDFunc(generateIDMock()),
87 )
88 if err != nil {
89 return nil, err
90 }
91 return provider.LanguageModel(t.Context(), model)
92 }
93}
94
95func vertexBuilder(model string) builderFunc {
96 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
97 provider, err := google.New(
98 google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
99 google.WithHTTPClient(&http.Client{Transport: r}),
100 google.WithToolCallIDFunc(generateIDMock()),
101 )
102 if err != nil {
103 return nil, err
104 }
105 return provider.LanguageModel(t.Context(), model)
106 }
107}