google_test.go

 1package providertests
 2
 3import (
 4	"cmp"
 5	"net/http"
 6	"os"
 7	"testing"
 8
 9	"github.com/charmbracelet/fantasy/ai"
10	"github.com/charmbracelet/fantasy/google"
11	"github.com/stretchr/testify/require"
12	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
15var geminiTestModels = []testModel{
16	{"gemini-2.5-flash", "gemini-2.5-flash", true},
17	{"gemini-2.5-pro", "gemini-2.5-pro", true},
18}
19
20var vertexTestModels = []testModel{
21	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
22	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
23	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
24}
25
26func TestGoogleCommon(t *testing.T) {
27	var pairs []builderPair
28	for _, m := range geminiTestModels {
29		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil})
30	}
31	for _, m := range vertexTestModels {
32		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil})
33	}
34	testCommon(t, pairs)
35}
36
37func TestGoogleThinking(t *testing.T) {
38	opts := ai.ProviderOptions{
39		google.Name: &google.ProviderOptions{
40			ThinkingConfig: &google.ThinkingConfig{
41				ThinkingBudget:  ai.Opt(int64(100)),
42				IncludeThoughts: ai.Opt(true),
43			},
44		},
45	}
46
47	var pairs []builderPair
48	for _, m := range geminiTestModels {
49		if !m.reasoning {
50			continue
51		}
52		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts})
53	}
54	testThinking(t, pairs, testGoogleThinking)
55}
56
57func testGoogleThinking(t *testing.T, result *ai.AgentResult) {
58	reasoningContentCount := 0
59	// Test if we got the signature
60	for _, step := range result.Steps {
61		for _, msg := range step.Messages {
62			for _, content := range msg.Content {
63				if content.GetType() == ai.ContentTypeReasoning {
64					reasoningContentCount += 1
65				}
66			}
67		}
68	}
69	require.Greater(t, reasoningContentCount, 0)
70}
71
72func geminiBuilder(model string) builderFunc {
73	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
74		provider := google.New(
75			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
76			google.WithHTTPClient(&http.Client{Transport: r}),
77		)
78		return provider.LanguageModel(model)
79	}
80}
81
82func vertexBuilder(model string) builderFunc {
83	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
84		provider := google.New(
85			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
86			google.WithHTTPClient(&http.Client{Transport: r}),
87			google.WithSkipAuth(!r.IsRecording()),
88		)
89		return provider.LanguageModel(model)
90	}
91}