google_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"charm.land/fantasy/providers/google"
 12	"github.com/stretchr/testify/require"
 13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 14)
 15
 16var geminiTestModels = []testModel{
 17	{"gemini-2.5-flash", "gemini-2.5-flash", true},
 18	{"gemini-2.5-pro", "gemini-2.5-pro", true},
 19}
 20
 21var vertexTestModels = []testModel{
 22	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
 23	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
 24	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
 25}
 26
 27func TestGoogleCommon(t *testing.T) {
 28	var pairs []builderPair
 29	for _, m := range geminiTestModels {
 30		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 31	}
 32	for _, m := range vertexTestModels {
 33		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 34	}
 35	testCommon(t, pairs)
 36}
 37
 38func TestGoogleThinking(t *testing.T) {
 39	opts := fantasy.ProviderOptions{
 40		google.Name: &google.ProviderOptions{
 41			ThinkingConfig: &google.ThinkingConfig{
 42				ThinkingBudget:  fantasy.Opt(int64(100)),
 43				IncludeThoughts: fantasy.Opt(true),
 44			},
 45		},
 46	}
 47
 48	var pairs []builderPair
 49	for _, m := range geminiTestModels {
 50		if !m.reasoning {
 51			continue
 52		}
 53		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
 54	}
 55	testThinking(t, pairs, testGoogleThinking)
 56}
 57
 58func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
 59	reasoningContentCount := 0
 60	// Test if we got the signature
 61	for _, step := range result.Steps {
 62		for _, msg := range step.Messages {
 63			for _, content := range msg.Content {
 64				if content.GetType() == fantasy.ContentTypeReasoning {
 65					reasoningContentCount += 1
 66				}
 67			}
 68		}
 69	}
 70	require.Greater(t, reasoningContentCount, 0)
 71}
 72
 73func generateIDMock() google.ToolCallIDFunc {
 74	id := 0
 75	return func() string {
 76		id++
 77		return fmt.Sprintf("%d", id)
 78	}
 79}
 80
 81func geminiBuilder(model string) builderFunc {
 82	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 83		provider, err := google.New(
 84			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
 85			google.WithHTTPClient(&http.Client{Transport: r}),
 86			google.WithToolCallIDFunc(generateIDMock()),
 87		)
 88		if err != nil {
 89			return nil, err
 90		}
 91		return provider.LanguageModel(t.Context(), model)
 92	}
 93}
 94
 95func vertexBuilder(model string) builderFunc {
 96	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 97		provider, err := google.New(
 98			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
 99			google.WithHTTPClient(&http.Client{Transport: r}),
100			google.WithToolCallIDFunc(generateIDMock()),
101		)
102		if err != nil {
103			return nil, err
104		}
105		return provider.LanguageModel(t.Context(), model)
106	}
107}