google_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"charm.land/fantasy/providers/google"
 12	"charm.land/x/vcr"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16var geminiTestModels = []testModel{
 17	{"gemini-2.5-flash", "gemini-2.5-flash", true},
 18	{"gemini-2.5-pro", "gemini-2.5-pro", true},
 19}
 20
 21var vertexTestModels = []testModel{
 22	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
 23	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
 24	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
 25}
 26
 27func TestGoogleCommon(t *testing.T) {
 28	var pairs []builderPair
 29	for _, m := range geminiTestModels {
 30		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 31	}
 32	for _, m := range vertexTestModels {
 33		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 34	}
 35	testCommon(t, pairs)
 36}
 37
 38func TestGoogleThinking(t *testing.T) {
 39	opts := fantasy.ProviderOptions{
 40		google.Name: &google.ProviderOptions{
 41			ThinkingConfig: &google.ThinkingConfig{
 42				ThinkingBudget:  fantasy.Opt(int64(100)),
 43				IncludeThoughts: fantasy.Opt(true),
 44			},
 45		},
 46	}
 47
 48	var pairs []builderPair
 49	for _, m := range geminiTestModels {
 50		if !m.reasoning {
 51			continue
 52		}
 53		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
 54	}
 55	testThinking(t, pairs, testGoogleThinking)
 56}
 57
 58func TestGoogleObjectGeneration(t *testing.T) {
 59	var pairs []builderPair
 60	for _, m := range geminiTestModels {
 61		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 62	}
 63	testObjectGeneration(t, pairs)
 64}
 65
 66func TestGoogleVertexObjectGeneration(t *testing.T) {
 67	var pairs []builderPair
 68	for _, m := range vertexTestModels {
 69		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 70	}
 71	testObjectGeneration(t, pairs)
 72}
 73
 74func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
 75	reasoningContentCount := 0
 76	// Test if we got the signature
 77	for _, step := range result.Steps {
 78		for _, msg := range step.Messages {
 79			for _, content := range msg.Content {
 80				if content.GetType() == fantasy.ContentTypeReasoning {
 81					reasoningContentCount += 1
 82				}
 83			}
 84		}
 85	}
 86	require.Greater(t, reasoningContentCount, 0)
 87}
 88
 89func generateIDMock() google.ToolCallIDFunc {
 90	id := 0
 91	return func() string {
 92		id++
 93		return fmt.Sprintf("%d", id)
 94	}
 95}
 96
 97func geminiBuilder(model string) builderFunc {
 98	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 99		provider, err := google.New(
100			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
101			google.WithHTTPClient(&http.Client{Transport: r}),
102			google.WithToolCallIDFunc(generateIDMock()),
103		)
104		if err != nil {
105			return nil, err
106		}
107		return provider.LanguageModel(t.Context(), model)
108	}
109}
110
111func vertexBuilder(model string) builderFunc {
112	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
113		provider, err := google.New(
114			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
115			google.WithHTTPClient(&http.Client{Transport: r}),
116			google.WithSkipAuth(!r.IsRecording()),
117			google.WithToolCallIDFunc(generateIDMock()),
118		)
119		if err != nil {
120			return nil, err
121		}
122		return provider.LanguageModel(t.Context(), model)
123	}
124}