google_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"charm.land/fantasy/providers/google"
 12	"charm.land/x/vcr"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16var geminiTestModels = []testModel{
 17	{"gemini-3-pro-preview", "gemini-3-pro-preview", true},
 18	{"gemini-2.5-flash", "gemini-2.5-flash", true},
 19	{"gemini-2.5-pro", "gemini-2.5-pro", true},
 20}
 21
 22var vertexTestModels = []testModel{
 23	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
 24	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
 25	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
 26}
 27
 28func TestGoogleCommon(t *testing.T) {
 29	var pairs []builderPair
 30	for _, m := range geminiTestModels {
 31		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 32	}
 33	for _, m := range vertexTestModels {
 34		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 35	}
 36	testCommon(t, pairs)
 37}
 38
 39func TestGoogleThinking(t *testing.T) {
 40	opts := fantasy.ProviderOptions{
 41		google.Name: &google.ProviderOptions{
 42			ThinkingConfig: &google.ThinkingConfig{
 43				ThinkingBudget:  fantasy.Opt(int64(100)),
 44				IncludeThoughts: fantasy.Opt(true),
 45			},
 46		},
 47	}
 48
 49	var pairs []builderPair
 50	for _, m := range geminiTestModels {
 51		if !m.reasoning {
 52			continue
 53		}
 54		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
 55	}
 56	testThinking(t, pairs, testGoogleThinking)
 57}
 58
 59func TestGoogleObjectGeneration(t *testing.T) {
 60	var pairs []builderPair
 61	for _, m := range geminiTestModels {
 62		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 63	}
 64	testObjectGeneration(t, pairs)
 65}
 66
 67func TestGoogleVertexObjectGeneration(t *testing.T) {
 68	var pairs []builderPair
 69	for _, m := range vertexTestModels {
 70		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 71	}
 72	testObjectGeneration(t, pairs)
 73}
 74
 75func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
 76	reasoningContentCount := 0
 77	// Test if we got the signature
 78	for _, step := range result.Steps {
 79		for _, msg := range step.Messages {
 80			for _, content := range msg.Content {
 81				if content.GetType() == fantasy.ContentTypeReasoning {
 82					reasoningContentCount += 1
 83				}
 84			}
 85		}
 86	}
 87	require.Greater(t, reasoningContentCount, 0)
 88}
 89
 90func generateIDMock() google.ToolCallIDFunc {
 91	id := 0
 92	return func() string {
 93		id++
 94		return fmt.Sprintf("%d", id)
 95	}
 96}
 97
 98func geminiBuilder(model string) builderFunc {
 99	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
100		provider, err := google.New(
101			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
102			google.WithHTTPClient(&http.Client{Transport: r}),
103			google.WithToolCallIDFunc(generateIDMock()),
104		)
105		if err != nil {
106			return nil, err
107		}
108		return provider.LanguageModel(t.Context(), model)
109	}
110}
111
112func vertexBuilder(model string) builderFunc {
113	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
114		provider, err := google.New(
115			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
116			google.WithHTTPClient(&http.Client{Transport: r}),
117			google.WithSkipAuth(!r.IsRecording()),
118			google.WithToolCallIDFunc(generateIDMock()),
119		)
120		if err != nil {
121			return nil, err
122		}
123		return provider.LanguageModel(t.Context(), model)
124	}
125}