google_test.go

 1package providertests
 2
 3import (
 4	"cmp"
 5	"net/http"
 6	"os"
 7	"testing"
 8
 9	"charm.land/fantasy"
10	"charm.land/fantasy/providers/google"
11	"github.com/stretchr/testify/require"
12	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
15var geminiTestModels = []testModel{
16	{"gemini-2.5-flash", "gemini-2.5-flash", true},
17	{"gemini-2.5-pro", "gemini-2.5-pro", true},
18}
19
20var vertexTestModels = []testModel{
21	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
22	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
23	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
24}
25
26func TestGoogleCommon(t *testing.T) {
27	var pairs []builderPair
28	for _, m := range geminiTestModels {
29		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
30	}
31	for _, m := range vertexTestModels {
32		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
33	}
34	testCommon(t, pairs)
35}
36
37func TestGoogleThinking(t *testing.T) {
38	opts := fantasy.ProviderOptions{
39		google.Name: &google.ProviderOptions{
40			ThinkingConfig: &google.ThinkingConfig{
41				ThinkingBudget:  fantasy.Opt(int64(100)),
42				IncludeThoughts: fantasy.Opt(true),
43			},
44		},
45	}
46
47	var pairs []builderPair
48	for _, m := range geminiTestModels {
49		if !m.reasoning {
50			continue
51		}
52		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
53	}
54	testThinking(t, pairs, testGoogleThinking)
55}
56
57func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
58	reasoningContentCount := 0
59	// Test if we got the signature
60	for _, step := range result.Steps {
61		for _, msg := range step.Messages {
62			for _, content := range msg.Content {
63				if content.GetType() == fantasy.ContentTypeReasoning {
64					reasoningContentCount += 1
65				}
66			}
67		}
68	}
69	require.Greater(t, reasoningContentCount, 0)
70}
71
72func geminiBuilder(model string) builderFunc {
73	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
74		provider, err := google.New(
75			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
76			google.WithHTTPClient(&http.Client{Transport: r}),
77		)
78		if err != nil {
79			return nil, err
80		}
81		return provider.LanguageModel(t.Context(), model)
82	}
83}
84
85func vertexBuilder(model string) builderFunc {
86	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
87		provider, err := google.New(
88			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
89			google.WithHTTPClient(&http.Client{Transport: r}),
90			google.WithSkipAuth(!r.IsRecording()),
91		)
92		if err != nil {
93			return nil, err
94		}
95		return provider.LanguageModel(t.Context(), model)
96	}
97}