1package providertests
2
3import (
4 "cmp"
5 "fmt"
6 "net/http"
7 "os"
8 "testing"
9
10 "charm.land/fantasy"
11 "charm.land/fantasy/providers/google"
12 "github.com/stretchr/testify/require"
13 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
14)
15
16var geminiTestModels = []testModel{
17 {"gemini-2.5-flash", "gemini-2.5-flash", true},
18 {"gemini-2.5-pro", "gemini-2.5-pro", true},
19}
20
21var vertexTestModels = []testModel{
22 {"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
23 {"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
24 {"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
25}
26
27func TestGoogleCommon(t *testing.T) {
28 var pairs []builderPair
29 for _, m := range geminiTestModels {
30 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
31 }
32 for _, m := range vertexTestModels {
33 // TODO: fixme
34 continue
35 pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
36 }
37 testCommon(t, pairs)
38}
39
40func TestGoogleThinking(t *testing.T) {
41 opts := fantasy.ProviderOptions{
42 google.Name: &google.ProviderOptions{
43 ThinkingConfig: &google.ThinkingConfig{
44 ThinkingBudget: fantasy.Opt(int64(100)),
45 IncludeThoughts: fantasy.Opt(true),
46 },
47 },
48 }
49
50 var pairs []builderPair
51 for _, m := range geminiTestModels {
52 if !m.reasoning {
53 continue
54 }
55 pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
56 }
57 testThinking(t, pairs, testGoogleThinking)
58}
59
60func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
61 reasoningContentCount := 0
62 // Test if we got the signature
63 for _, step := range result.Steps {
64 for _, msg := range step.Messages {
65 for _, content := range msg.Content {
66 if content.GetType() == fantasy.ContentTypeReasoning {
67 reasoningContentCount += 1
68 }
69 }
70 }
71 }
72 require.Greater(t, reasoningContentCount, 0)
73}
74
75func generateIDMock() google.ToolCallIDFunc {
76 id := 0
77 return func() string {
78 id++
79 return fmt.Sprintf("%d", id)
80 }
81}
82
83func geminiBuilder(model string) builderFunc {
84 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
85 provider, err := google.New(
86 google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
87 google.WithHTTPClient(&http.Client{Transport: r}),
88 google.WithToolCallIDFunc(generateIDMock()),
89 )
90 if err != nil {
91 return nil, err
92 }
93 return provider.LanguageModel(t.Context(), model)
94 }
95}
96
97func vertexBuilder(model string) builderFunc {
98 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
99 provider, err := google.New(
100 google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
101 google.WithHTTPClient(&http.Client{Transport: r}),
102 google.WithSkipAuth(!r.IsRecording()),
103 google.WithToolCallIDFunc(generateIDMock()),
104 )
105 if err != nil {
106 return nil, err
107 }
108 return provider.LanguageModel(t.Context(), model)
109 }
110}