package providertests

import (
	"cmp"
	"fmt"
	"net/http"
	"os"
	"testing"

	"charm.land/fantasy"
	"charm.land/fantasy/providers/google"
	"charm.land/x/vcr"
	"github.com/stretchr/testify/require"
)

var geminiTestModels = []testModel{
	{"gemini-3-pro-preview", "gemini-3-pro-preview", true},
	{"gemini-2.5-flash", "gemini-2.5-flash", true},
	{"gemini-2.5-pro", "gemini-2.5-pro", true},
}

var vertexTestModels = []testModel{
	{"vertex-gemini-2-5-flash", "gemini-2.5-flash", true},
	{"vertex-gemini-2-5-pro", "gemini-2.5-pro", true},
	{"vertex-claude-3-7-sonnet", "claude-3-7-sonnet@20250219", true},
}

func TestGoogleCommon(t *testing.T) {
	var pairs []builderPair
	for _, m := range geminiTestModels {
		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
	}
	for _, m := range vertexTestModels {
		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
	}
	testCommon(t, pairs)
}

func TestGoogleThinking(t *testing.T) {
	opts := fantasy.ProviderOptions{
		google.Name: &google.ProviderOptions{
			ThinkingConfig: &google.ThinkingConfig{
				ThinkingBudget:  fantasy.Opt(int64(100)),
				IncludeThoughts: fantasy.Opt(true),
			},
		},
	}

	var pairs []builderPair
	for _, m := range geminiTestModels {
		if !m.reasoning {
			continue
		}
		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), opts, nil})
	}
	testThinking(t, pairs, testGoogleThinking)
}

func TestGoogleObjectGeneration(t *testing.T) {
	var pairs []builderPair
	for _, m := range geminiTestModels {
		pairs = append(pairs, builderPair{m.name, geminiBuilder(m.model), nil, nil})
	}
	testObjectGeneration(t, pairs)
}

func TestGoogleVertexObjectGeneration(t *testing.T) {
	var pairs []builderPair
	for _, m := range vertexTestModels {
		pairs = append(pairs, builderPair{m.name, vertexBuilder(m.model), nil, nil})
	}
	testObjectGeneration(t, pairs)
}

func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
	reasoningContentCount := 0
	// Test if we got the signature
	for _, step := range result.Steps {
		for _, msg := range step.Messages {
			for _, content := range msg.Content {
				if content.GetType() == fantasy.ContentTypeReasoning {
					reasoningContentCount += 1
				}
			}
		}
	}
	require.Greater(t, reasoningContentCount, 0)
}

func generateIDMock() google.ToolCallIDFunc {
	id := 0
	return func() string {
		id++
		return fmt.Sprintf("%d", id)
	}
}

func geminiBuilder(model string) builderFunc {
	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
		provider, err := google.New(
			google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
			google.WithHTTPClient(&http.Client{Transport: r}),
			google.WithToolCallIDFunc(generateIDMock()),
		)
		if err != nil {
			return nil, err
		}
		return provider.LanguageModel(t.Context(), model)
	}
}

func vertexBuilder(model string) builderFunc {
	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
		provider, err := google.New(
			google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
			google.WithHTTPClient(&http.Client{Transport: r}),
			google.WithSkipAuth(!r.IsRecording()),
			google.WithToolCallIDFunc(generateIDMock()),
		)
		if err != nil {
			return nil, err
		}
		return provider.LanguageModel(t.Context(), model)
	}
}
