google_test.go

 1package providertests
 2
 3import (
 4	"cmp"
 5	"net/http"
 6	"os"
 7	"testing"
 8
 9	"github.com/charmbracelet/fantasy/ai"
10	"github.com/charmbracelet/fantasy/google"
11	"github.com/stretchr/testify/require"
12	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
15func TestGoogleCommon(t *testing.T) {
16	testCommon(t, []builderPair{
17		{"gemini-2.5-flash", builderGoogleGemini25Flash, nil},
18		{"gemini-2.5-pro", builderGoogleGemini25Pro, nil},
19	})
20	opts := ai.ProviderOptions{
21		google.Name: &google.ProviderOptions{
22			ThinkingConfig: &google.ThinkingConfig{
23				ThinkingBudget:  ai.IntOption(100),
24				IncludeThoughts: ai.BoolOption(true),
25			},
26		},
27	}
28	testThinking(t, []builderPair{
29		{"gemini-2.5-flash", builderGoogleGemini25Flash, opts},
30		{"gemini-2.5-pro", builderGoogleGemini25Pro, opts},
31	}, testGoogleThinking)
32}
33
34func testGoogleThinking(t *testing.T, result *ai.AgentResult) {
35	reasoningContentCount := 0
36	// Test if we got the signature
37	for _, step := range result.Steps {
38		for _, msg := range step.Messages {
39			for _, content := range msg.Content {
40				if content.GetType() == ai.ContentTypeReasoning {
41					reasoningContentCount += 1
42				}
43			}
44		}
45	}
46	require.Greater(t, reasoningContentCount, 0)
47}
48
49func builderGoogleGemini25Flash(r *recorder.Recorder) (ai.LanguageModel, error) {
50	provider := google.New(
51		google.WithAPIKey(cmp.Or(os.Getenv("GEMINI_API_KEY"), "(missing)")),
52		google.WithHTTPClient(&http.Client{Transport: r}),
53	)
54	return provider.LanguageModel("gemini-2.5-flash")
55}
56
57func builderGoogleGemini25Pro(r *recorder.Recorder) (ai.LanguageModel, error) {
58	provider := google.New(
59		google.WithAPIKey(cmp.Or(os.Getenv("GEMINI_API_KEY"), "(missing)")),
60		google.WithHTTPClient(&http.Client{Transport: r}),
61	)
62	return provider.LanguageModel("gemini-2.5-pro")
63}