google_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"strings"
  9	"testing"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/google"
 13	"charm.land/x/vcr"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17var geminiTestModels = []testModel{
 18	{"gemini-3-pro-preview", "gemini-3-pro-preview", true},
 19	{"gemini-2.5-flash", "gemini-2.5-flash", true},
 20	{"gemini-2.5-pro", "gemini-2.5-pro", true},
 21}
 22
 23var vertexTestModels = []testModel{
 24	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
 25	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
 26	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
 27}
 28
 29func TestGoogleCommon(t *testing.T) {
 30	var pairs []builderPair
 31	for _, m := range geminiTestModels {
 32		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 33	}
 34	for _, m := range vertexTestModels {
 35		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 36	}
 37	testCommon(t, pairs)
 38}
 39
 40func TestGoogleThinking(t *testing.T) {
 41	gemini2Opts := fantasy.ProviderOptions{
 42		google.Name: &google.ProviderOptions{
 43			ThinkingConfig: &google.ThinkingConfig{
 44				ThinkingBudget:  fantasy.Opt(int64(100)),
 45				IncludeThoughts: fantasy.Opt(true),
 46			},
 47		},
 48	}
 49	gemini3Opts := fantasy.ProviderOptions{
 50		google.Name: &google.ProviderOptions{
 51			ThinkingConfig: &google.ThinkingConfig{
 52				ThinkingLevel:   fantasy.Opt(google.ThinkingLevelHigh),
 53				IncludeThoughts: fantasy.Opt(true),
 54			},
 55		},
 56	}
 57
 58	var pairs []builderPair
 59	for _, m := range geminiTestModels {
 60		if !m.reasoning {
 61			continue
 62		}
 63		opts := gemini3Opts
 64		if strings.HasPrefix(m.model, "gemini-2") {
 65			opts = gemini2Opts
 66		}
 67		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
 68	}
 69	testThinking(t, pairs, testGoogleThinking)
 70}
 71
 72func TestGoogleObjectGeneration(t *testing.T) {
 73	var pairs []builderPair
 74	for _, m := range geminiTestModels {
 75		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 76	}
 77	testObjectGeneration(t, pairs)
 78}
 79
 80func TestGoogleVertexObjectGeneration(t *testing.T) {
 81	var pairs []builderPair
 82	for _, m := range vertexTestModels {
 83		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 84	}
 85	testObjectGeneration(t, pairs)
 86}
 87
 88func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
 89	reasoningContentCount := 0
 90	// Test if we got the signature
 91	for _, step := range result.Steps {
 92		for _, msg := range step.Messages {
 93			for _, content := range msg.Content {
 94				if content.GetType() == fantasy.ContentTypeReasoning {
 95					reasoningContentCount += 1
 96				}
 97			}
 98		}
 99	}
100	require.Greater(t, reasoningContentCount, 0)
101}
102
103func generateIDMock() google.ToolCallIDFunc {
104	id := 0
105	return func() string {
106		id++
107		return fmt.Sprintf("%d", id)
108	}
109}
110
111func geminiBuilder(model string) builderFunc {
112	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
113		provider, err := google.New(
114			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
115			google.WithHTTPClient(&http.Client{Transport: r}),
116			google.WithToolCallIDFunc(generateIDMock()),
117		)
118		if err != nil {
119			return nil, err
120		}
121		return provider.LanguageModel(t.Context(), model)
122	}
123}
124
125func vertexBuilder(model string) builderFunc {
126	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
127		provider, err := google.New(
128			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
129			google.WithHTTPClient(&http.Client{Transport: r}),
130			google.WithSkipAuth(!r.IsRecording()),
131			google.WithToolCallIDFunc(generateIDMock()),
132		)
133		if err != nil {
134			return nil, err
135		}
136		return provider.LanguageModel(t.Context(), model)
137	}
138}