1package providertests
2
3import (
4 "cmp"
5 "net/http"
6 "os"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/charmbracelet/fantasy/google"
11 "github.com/stretchr/testify/require"
12 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
15var googleTestModels = []testModel{
16 {"gemini-2.5-flash", "gemini-2.5-flash", true},
17 {"gemini-2.5-pro", "gemini-2.5-pro", true},
18}
19
20func TestGoogleCommon(t *testing.T) {
21 var pairs []builderPair
22 for _, m := range googleTestModels {
23 pairs = append(pairs, builderPair{m.name, googleBuilder(m.model), nil})
24 }
25 testCommon(t, pairs)
26}
27
28func TestGoogleThinking(t *testing.T) {
29 opts := ai.ProviderOptions{
30 google.Name: &google.ProviderOptions{
31 ThinkingConfig: &google.ThinkingConfig{
32 ThinkingBudget: ai.IntOption(100),
33 IncludeThoughts: ai.BoolOption(true),
34 },
35 },
36 }
37
38 var pairs []builderPair
39 for _, m := range googleTestModels {
40 if !m.reasoning {
41 continue
42 }
43 pairs = append(pairs, builderPair{m.name, googleBuilder(m.model), opts})
44 }
45 testThinking(t, pairs, testGoogleThinking)
46}
47
48func testGoogleThinking(t *testing.T, result *ai.AgentResult) {
49 reasoningContentCount := 0
50 // Test if we got the signature
51 for _, step := range result.Steps {
52 for _, msg := range step.Messages {
53 for _, content := range msg.Content {
54 if content.GetType() == ai.ContentTypeReasoning {
55 reasoningContentCount += 1
56 }
57 }
58 }
59 }
60 require.Greater(t, reasoningContentCount, 0)
61}
62
63func googleBuilder(model string) builderFunc {
64 return func(r *recorder.Recorder) (ai.LanguageModel, error) {
65 provider := google.New(
66 google.WithAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
67 google.WithHTTPClient(&http.Client{Transport: r}),
68 )
69 return provider.LanguageModel(model)
70 }
71}