google_test.go

 1package providertests
 2
 3import (
 4	"cmp"
 5	"net/http"
 6	"os"
 7	"testing"
 8
 9	"github.com/charmbracelet/fantasy/ai"
10	"github.com/charmbracelet/fantasy/google"
11	"github.com/stretchr/testify/require"
12	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
15var googleTestModels = []testModel{
16	{"gemini-2.5-flash", "gemini-2.5-flash", true},
17	{"gemini-2.5-pro", "gemini-2.5-pro", true},
18}
19
20func TestGoogleCommon(t *testing.T) {
21	var pairs []builderPair
22	for _, m := range googleTestModels {
23		pairs = append(pairs, builderPair{m.name, googleBuilder(m.model), nil})
24	}
25	testCommon(t, pairs)
26}
27
28func TestGoogleThinking(t *testing.T) {
29	opts := ai.ProviderOptions{
30		google.Name: &google.ProviderOptions{
31			ThinkingConfig: &google.ThinkingConfig{
32				ThinkingBudget:  ai.IntOption(100),
33				IncludeThoughts: ai.BoolOption(true),
34			},
35		},
36	}
37
38	var pairs []builderPair
39	for _, m := range googleTestModels {
40		if !m.reasoning {
41			continue
42		}
43		pairs = append(pairs, builderPair{m.name, googleBuilder(m.model), opts})
44	}
45	testThinking(t, pairs, testGoogleThinking)
46}
47
48func testGoogleThinking(t *testing.T, result *ai.AgentResult) {
49	reasoningContentCount := 0
50	// Test if we got the signature
51	for _, step := range result.Steps {
52		for _, msg := range step.Messages {
53			for _, content := range msg.Content {
54				if content.GetType() == ai.ContentTypeReasoning {
55					reasoningContentCount += 1
56				}
57			}
58		}
59	}
60	require.Greater(t, reasoningContentCount, 0)
61}
62
63func googleBuilder(model string) builderFunc {
64	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
65		provider := google.New(
66			google.WithAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
67			google.WithHTTPClient(&http.Client{Transport: r}),
68		)
69		return provider.LanguageModel(model)
70	}
71}