google_test.go

  1package providertests
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"charm.land/fantasy/providers/google"
 12	"github.com/stretchr/testify/require"
 13	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 14)
 15
 16var geminiTestModels = []testModel{
 17	{"gemini-2.5-flash", "gemini-2.5-flash", true},
 18	{"gemini-2.5-pro", "gemini-2.5-pro", true},
 19}
 20
 21var vertexTestModels = []testModel{
 22	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
 23	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
 24	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
 25}
 26
 27func TestGoogleCommon(t *testing.T) {
 28	var pairs []builderPair
 29	for _, m := range geminiTestModels {
 30		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
 31	}
 32	for _, m := range vertexTestModels {
 33		// TODO: fixme
 34		continue
 35		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
 36	}
 37	testCommon(t, pairs)
 38}
 39
 40func TestGoogleThinking(t *testing.T) {
 41	opts := fantasy.ProviderOptions{
 42		google.Name: &google.ProviderOptions{
 43			ThinkingConfig: &google.ThinkingConfig{
 44				ThinkingBudget:  fantasy.Opt(int64(100)),
 45				IncludeThoughts: fantasy.Opt(true),
 46			},
 47		},
 48	}
 49
 50	var pairs []builderPair
 51	for _, m := range geminiTestModels {
 52		if !m.reasoning {
 53			continue
 54		}
 55		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
 56	}
 57	testThinking(t, pairs, testGoogleThinking)
 58}
 59
 60func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
 61	reasoningContentCount := 0
 62	// Test if we got the signature
 63	for _, step := range result.Steps {
 64		for _, msg := range step.Messages {
 65			for _, content := range msg.Content {
 66				if content.GetType() == fantasy.ContentTypeReasoning {
 67					reasoningContentCount += 1
 68				}
 69			}
 70		}
 71	}
 72	require.Greater(t, reasoningContentCount, 0)
 73}
 74
 75func generateIDMock() google.ToolCallIDFunc {
 76	id := 0
 77	return func() string {
 78		id++
 79		return fmt.Sprintf("%d", id)
 80	}
 81}
 82
 83func geminiBuilder(model string) builderFunc {
 84	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 85		provider, err := google.New(
 86			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
 87			google.WithHTTPClient(&http.Client{Transport: r}),
 88			google.WithToolCallIDFunc(generateIDMock()),
 89		)
 90		if err != nil {
 91			return nil, err
 92		}
 93		return provider.LanguageModel(t.Context(), model)
 94	}
 95}
 96
 97func vertexBuilder(model string) builderFunc {
 98	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 99		provider, err := google.New(
100			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
101			google.WithHTTPClient(&http.Client{Transport: r}),
102			google.WithSkipAuth(!r.IsRecording()),
103			google.WithToolCallIDFunc(generateIDMock()),
104		)
105		if err != nil {
106			return nil, err
107		}
108		return provider.LanguageModel(t.Context(), model)
109	}
110}