1package providertests
2
3import (
4 "cmp"
5 "net/http"
6 "os"
7 "testing"
8
9 "github.com/charmbracelet/fantasy/ai"
10 "github.com/charmbracelet/fantasy/google"
11 "github.com/stretchr/testify/require"
12 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
13)
14
15func TestGoogleCommon(t *testing.T) {
16 testCommon(t, []builderPair{
17 {"gemini-2.5-flash", builderGoogleGemini25Flash, nil},
18 {"gemini-2.5-pro", builderGoogleGemini25Pro, nil},
19 })
20 opts := ai.ProviderOptions{
21 google.Name: &google.ProviderOptions{
22 ThinkingConfig: &google.ThinkingConfig{
23 ThinkingBudget: ai.IntOption(100),
24 IncludeThoughts: ai.BoolOption(true),
25 },
26 },
27 }
28 testThinking(t, []builderPair{
29 {"gemini-2.5-flash", builderGoogleGemini25Flash, opts},
30 {"gemini-2.5-pro", builderGoogleGemini25Pro, opts},
31 }, testGoogleThinking)
32}
33
34func testGoogleThinking(t *testing.T, result *ai.AgentResult) {
35 reasoningContentCount := 0
36 // Test if we got the signature
37 for _, step := range result.Steps {
38 for _, msg := range step.Messages {
39 for _, content := range msg.Content {
40 if content.GetType() == ai.ContentTypeReasoning {
41 reasoningContentCount += 1
42 }
43 }
44 }
45 }
46 require.Greater(t, reasoningContentCount, 0)
47}
48
49func builderGoogleGemini25Flash(r *recorder.Recorder) (ai.LanguageModel, error) {
50 provider := google.New(
51 google.WithAPIKey(cmp.Or(os.Getenv("GEMINI_API_KEY"), "(missing)")),
52 google.WithHTTPClient(&http.Client{Transport: r}),
53 )
54 return provider.LanguageModel("gemini-2.5-flash")
55}
56
57func builderGoogleGemini25Pro(r *recorder.Recorder) (ai.LanguageModel, error) {
58 provider := google.New(
59 google.WithAPIKey(cmp.Or(os.Getenv("GEMINI_API_KEY"), "(missing)")),
60 google.WithHTTPClient(&http.Client{Transport: r}),
61 )
62 return provider.LanguageModel("gemini-2.5-pro")
63}