// Package main provides a command-line tool to fetch models from Venice
// and generate a configuration file for the provider.
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math"
	"net/http"
	"os"
	"slices"
	"strings"
	"time"

	"charm.land/catwalk/pkg/catwalk"
)

type (
	// ModelsResponse is the decoded body of the Venice "/models" endpoint.
	// Every catalogue entry is listed under the "data" key.
	ModelsResponse struct {
		Data []VeniceModel `json:"data"`
	}

	// VeniceModel is one catalogue entry: its identifier, its kind
	// (for example "text"), and its detailed specification.
	VeniceModel struct {
		ID        string          `json:"id"`
		ModelSpec VeniceModelSpec `json:"model_spec"`
		Type      string          `json:"type"`
	}

	// VeniceModelSpec carries the limits, capabilities, sampling constraints,
	// availability flags and pricing reported for a single model.
	VeniceModelSpec struct {
		AvailableContextTokens int64                   `json:"availableContextTokens"`
		MaxCompletionTokens    int64                   `json:"maxCompletionTokens"`
		Capabilities           VeniceModelCapabilities `json:"capabilities"`
		Constraints            VeniceModelConstraints  `json:"constraints"`
		Name                   string                  `json:"name"`
		Offline                bool                    `json:"offline"`
		Pricing                VeniceModelPricing      `json:"pricing"`
		Beta                   bool                    `json:"beta"`
	}

	// VeniceModelCapabilities lists the feature flags advertised by a model.
	// A flag that is missing from the JSON decodes as false.
	VeniceModelCapabilities struct {
		OptimizedForCode        bool `json:"optimizedForCode"`
		SupportsFunctionCalling bool `json:"supportsFunctionCalling"`
		SupportsReasoning       bool `json:"supportsReasoning"`
		SupportsReasoningEffort bool `json:"supportsReasoningEffort"`
		SupportsVision          bool `json:"supportsVision"`
	}

	// VeniceModelConstraints holds optional sampling-parameter constraints.
	// A nil pointer means the API did not report that constraint.
	VeniceModelConstraints struct {
		Temperature *VeniceDefaultFloat `json:"temperature"`
		TopP        *VeniceDefaultFloat `json:"top_p"`
	}

	// VeniceDefaultFloat wraps the "default" value of a numeric constraint.
	VeniceDefaultFloat struct {
		Default float64 `json:"default"`
	}

	// VeniceModelPricing holds the input and output token prices of a model.
	VeniceModelPricing struct {
		Input  VeniceModelPricingValue `json:"input"`
		Output VeniceModelPricingValue `json:"output"`
	}

	// VeniceModelPricingValue is a single price expressed in US dollars.
	VeniceModelPricingValue struct {
		USD float64 `json:"usd"`
	}
)

// fetchVeniceModels downloads the model catalogue from apiEndpoint + "/models"
// and decodes it into a ModelsResponse. Trailing slashes on apiEndpoint are
// trimmed before the path is appended.
//
// If VENICE_API_KEY holds a real value, it is sent as a Bearer token. Values
// that are empty or still an unexpanded "$..." placeholder are not sent. The
// request uses a 30-second client timeout.
//
// Any status other than 200 OK is returned as an error. The error carries the
// status code and at most maxErrBody bytes of the response body. A malformed
// endpoint URL is returned as an error instead of panicking.
func fetchVeniceModels(apiEndpoint string) (*ModelsResponse, error) {
	// Upper bound on how much of an error response is copied into the error.
	const maxErrBody = 64 << 10

	client := &http.Client{Timeout: 30 * time.Second}
	modelsURL := strings.TrimRight(apiEndpoint, "/") + "/models"

	// NewRequestWithContext rejects malformed URLs. Ignoring that error would
	// leave req nil, and the Header.Set call below would then panic.
	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsURL, nil)
	if err != nil {
		return nil, fmt.Errorf("building request for %q: %w", modelsURL, err)
	}
	req.Header.Set("User-Agent", "Crush-Client/1.0")

	if apiKey := strings.TrimSpace(os.Getenv("VENICE_API_KEY")); apiKey != "" && !strings.HasPrefix(apiKey, "$") {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err //nolint:wrapcheck
	}
	defer resp.Body.Close() //nolint:errcheck

	if resp.StatusCode != http.StatusOK {
		// Best-effort read: the body only enriches the error message, so a
		// read failure is deliberately ignored and the size is capped.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
	}

	var mr ModelsResponse
	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
		return nil, fmt.Errorf("decoding models response: %w", err)
	}
	return &mr, nil
}

// bestLargeModelID picks the most expensive model in models and returns its ID.
// Price is the sum of CostPer1MIn and CostPer1MOut. When two models cost the
// same, the one with the larger ContextWindow wins. On a complete tie, the
// model that appears first is kept. An empty or nil slice yields "".
func bestLargeModelID(models []catwalk.Model) string {
	if len(models) == 0 {
		return ""
	}

	winner := 0
	for idx := 1; idx < len(models); idx++ {
		challenger, current := &models[idx], &models[winner]
		challengerCost := challenger.CostPer1MIn + challenger.CostPer1MOut
		currentCost := current.CostPer1MIn + current.CostPer1MOut

		switch {
		case challengerCost > currentCost:
			winner = idx
		case challengerCost == currentCost && challenger.ContextWindow > current.ContextWindow:
			winner = idx
		}
	}
	return models[winner].ID
}

// bestSmallModelID picks the cheapest model in models and returns its ID.
// Price is the sum of CostPer1MIn and CostPer1MOut. When two models cost the
// same, the one with the smaller ContextWindow wins. On a complete tie, the
// model that appears first is kept. An empty or nil slice yields "".
func bestSmallModelID(models []catwalk.Model) string {
	if len(models) == 0 {
		return ""
	}

	pick := 0
	for i := 1; i < len(models); i++ {
		cand, held := &models[i], &models[pick]
		candCost, heldCost := cand.CostPer1MIn+cand.CostPer1MOut, held.CostPer1MIn+held.CostPer1MOut
		cheaper := candCost < heldCost
		sameCostSmallerCtx := candCost == heldCost && cand.ContextWindow < held.ContextWindow
		if cheaper || sameCostSmallerCtx {
			pick = i
		}
	}
	return models[pick].ID
}

func main() {
	veniceProvider := catwalk.Provider{
		Name:        "Venice AI",
		ID:          catwalk.InferenceProviderVenice,
		APIKey:      "$VENICE_API_KEY",
		APIEndpoint: "https://api.venice.ai/api/v1",
		Type:        catwalk.TypeOpenAICompat,
		Models:      []catwalk.Model{},
	}

	var codeOptimizedModels []catwalk.Model

	modelsResp, err := fetchVeniceModels(veniceProvider.APIEndpoint)
	if err != nil {
		log.Fatal("Error fetching Venice models:", err)
	}

	for _, model := range modelsResp.Data {
		if strings.ToLower(model.Type) != "text" {
			continue
		}
		if model.ModelSpec.Offline {
			continue
		}
		if !model.ModelSpec.Capabilities.SupportsFunctionCalling {
			continue
		}

		if model.ModelSpec.Beta {
			continue
		}

		contextWindow := model.ModelSpec.AvailableContextTokens
		if contextWindow <= 0 {
			continue
		}

		var (
			canReason            = model.ModelSpec.Capabilities.SupportsReasoning
			supportsReasonEffort = model.ModelSpec.Capabilities.SupportsReasoningEffort
		)
		var reasoningLevels []string
		var defaultReasoning string
		if canReason && supportsReasonEffort {
			reasoningLevels = []string{"low", "medium", "high"}
			defaultReasoning = "medium"
		}

		options := catwalk.ModelOptions{}
		if model.ModelSpec.Constraints.Temperature != nil {
			v := model.ModelSpec.Constraints.Temperature.Default
			if !math.IsNaN(v) {
				options.Temperature = &v
			}
		}
		if model.ModelSpec.Constraints.TopP != nil {
			v := model.ModelSpec.Constraints.TopP.Default
			if !math.IsNaN(v) {
				options.TopP = &v
			}
		}

		roundCost := func(v float64) float64 { return math.Round(v*1e5) / 1e5 }
		m := catwalk.Model{
			ID:                     model.ID,
			Name:                   model.ModelSpec.Name,
			CostPer1MIn:            roundCost(model.ModelSpec.Pricing.Input.USD),
			CostPer1MOut:           roundCost(model.ModelSpec.Pricing.Output.USD),
			CostPer1MInCached:      0,
			CostPer1MOutCached:     0,
			ContextWindow:          contextWindow,
			DefaultMaxTokens:       model.ModelSpec.MaxCompletionTokens,
			CanReason:              canReason,
			ReasoningLevels:        reasoningLevels,
			DefaultReasoningEffort: defaultReasoning,
			SupportsImages:         model.ModelSpec.Capabilities.SupportsVision,
			Options:                options,
		}

		veniceProvider.Models = append(veniceProvider.Models, m)
		if model.ModelSpec.Capabilities.OptimizedForCode {
			codeOptimizedModels = append(codeOptimizedModels, m)
		}
	}

	candidateModels := veniceProvider.Models
	if len(codeOptimizedModels) > 0 {
		candidateModels = codeOptimizedModels
	}

	veniceProvider.DefaultLargeModelID = bestLargeModelID(candidateModels)
	veniceProvider.DefaultSmallModelID = bestSmallModelID(candidateModels)

	slices.SortFunc(veniceProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
		return strings.Compare(a.Name, b.Name)
	})

	data, err := json.MarshalIndent(veniceProvider, "", "  ")
	if err != nil {
		log.Fatal("Error marshaling Venice provider:", err)
	}
	data = append(data, '\n')

	if err := os.WriteFile("internal/providers/configs/venice.json", data, 0o600); err != nil {
		log.Fatal("Error writing Venice provider config:", err)
	}

	fmt.Printf("Generated venice.json with %d models\n", len(veniceProvider.Models))
}
