1// Package main provides a command-line tool to fetch models from Hugging Face Router
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strings"
 15	"time"
 16
 17	"github.com/charmbracelet/catwalk/pkg/catwalk"
 18)
 19
 20// SupportedProviders defines which providers we want to support.
 21// Add or remove providers from this slice to control which ones are included.
 22var SupportedProviders = []string{
 23	// "together", // Multiple issues
 24	"fireworks-ai",
 25	//"nebius",
 26	// "novita", // Usage report is wrong
 27	"groq",
 28	"cerebras",
 29	// "hyperbolic",
 30	// "nscale",
 31	// "sambanova",
 32	// "cohere",
 33	"hf-inference",
 34}
 35
// Model represents a model from the Hugging Face Router API: one element of
// the "data" array in the models response. ID is the model identifier used as
// the prefix of generated catwalk model IDs, and Providers holds the
// per-provider entries reported for this model.
type Model struct {
	ID        string     `json:"id"`
	Object    string     `json:"object"`
	Created   int64      `json:"created"`
	OwnedBy   string     `json:"owned_by"`
	Providers []Provider `json:"providers"`
}
 44
// Provider represents a provider configuration for a model, as decoded from
// one element of Model.Providers.
//
// Provider is the provider slug matched against SupportedProviders. Status is
// compared against "live" by the generator. ContextLength is left at 0 when the
// response omits it, and Pricing stays nil when no pricing object is present.
type Provider struct {
	Provider                 string   `json:"provider"`
	Status                   string   `json:"status"`
	ContextLength            int64    `json:"context_length,omitempty"`
	Pricing                  *Pricing `json:"pricing,omitempty"`
	SupportsTools            bool     `json:"supports_tools"`
	SupportsStructuredOutput bool     `json:"supports_structured_output"`
}
 54
// Pricing contains the pricing information for a provider. The generator
// copies Input and Output unchanged into catwalk's CostPer1MIn and
// CostPer1MOut, i.e. it treats them as prices per 1M tokens.
// NOTE(review): the currency and unit are not visible here — presumably USD
// per 1M tokens; confirm against the Hugging Face Router API.
type Pricing struct {
	Input  float64 `json:"input"`
	Output float64 `json:"output"`
}
 60
// ModelsResponse is the response structure for the Hugging Face Router models
// API (GET https://router.huggingface.co/v1/models), as decoded by
// fetchHuggingFaceModels. Data holds the listed models.
type ModelsResponse struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}
 66
 67func fetchHuggingFaceModels() (*ModelsResponse, error) {
 68	client := &http.Client{Timeout: 30 * time.Second}
 69	req, _ := http.NewRequestWithContext(
 70		context.Background(),
 71		"GET",
 72		"https://router.huggingface.co/v1/models",
 73		nil,
 74	)
 75	req.Header.Set("User-Agent", "Crush-Client/1.0")
 76	resp, err := client.Do(req)
 77	if err != nil {
 78		return nil, err //nolint:wrapcheck
 79	}
 80	defer resp.Body.Close() //nolint:errcheck
 81	if resp.StatusCode != 200 {
 82		body, _ := io.ReadAll(resp.Body)
 83		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
 84	}
 85	var mr ModelsResponse
 86	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
 87		return nil, err //nolint:wrapcheck
 88	}
 89	return &mr, nil
 90}
 91
 92// findContextWindow looks for a context window from any provider for the given model.
 93func findContextWindow(model Model) int64 {
 94	for _, provider := range model.Providers {
 95		if provider.ContextLength > 0 {
 96			return provider.ContextLength
 97		}
 98	}
 99	return 0
100}
101
102// WARN: DO NOT USE
103// for now we have a subset list of models we use.
104func main() {
105	modelsResp, err := fetchHuggingFaceModels()
106	if err != nil {
107		log.Fatal("Error fetching Hugging Face models:", err)
108	}
109
110	hfProvider := catwalk.Provider{
111		Name:                "Hugging Face",
112		ID:                  catwalk.InferenceProviderHuggingFace,
113		APIKey:              "$HF_TOKEN",
114		APIEndpoint:         "https://router.huggingface.co/v1",
115		Type:                catwalk.TypeOpenAI,
116		DefaultLargeModelID: "moonshotai/Kimi-K2-Instruct-0905:groq",
117		DefaultSmallModelID: "openai/gpt-oss-20b",
118		Models:              []catwalk.Model{},
119		DefaultHeaders: map[string]string{
120			"HTTP-Referer": "https://charm.land",
121			"X-Title":      "Crush",
122		},
123	}
124
125	for _, model := range modelsResp.Data {
126		// Find context window from any provider for this model
127		fallbackContextLength := findContextWindow(model)
128		if fallbackContextLength == 0 {
129			fmt.Printf("Skipping model %s - no context window found in any provider\n", model.ID)
130			continue
131		}
132
133		for _, provider := range model.Providers {
134			// Skip unsupported providers
135			if !slices.Contains(SupportedProviders, provider.Provider) {
136				continue
137			}
138
139			// Skip providers that don't support tools
140			if !provider.SupportsTools {
141				continue
142			}
143
144			// Skip non-live providers
145			if provider.Status != "live" {
146				continue
147			}
148
149			// Create model with provider-specific ID and name
150			modelID := fmt.Sprintf("%s:%s", model.ID, provider.Provider)
151			modelName := fmt.Sprintf("%s (%s)", model.ID, provider.Provider)
152
153			// Use provider's context length, or fallback if not available
154			contextLength := provider.ContextLength
155			if contextLength == 0 {
156				contextLength = fallbackContextLength
157			}
158
159			// Calculate pricing (convert from per-token to per-1M tokens)
160			var costPer1MIn, costPer1MOut float64
161			if provider.Pricing != nil {
162				costPer1MIn = provider.Pricing.Input
163				costPer1MOut = provider.Pricing.Output
164			}
165
166			// Set default max tokens (conservative estimate)
167			defaultMaxTokens := min(contextLength/4, 8192)
168
169			m := catwalk.Model{
170				ID:                 modelID,
171				Name:               modelName,
172				CostPer1MIn:        costPer1MIn,
173				CostPer1MOut:       costPer1MOut,
174				CostPer1MInCached:  0, // Not provided by HF Router
175				CostPer1MOutCached: 0, // Not provided by HF Router
176				ContextWindow:      contextLength,
177				DefaultMaxTokens:   defaultMaxTokens,
178				CanReason:          false, // Not provided by HF Router
179				SupportsImages:     false, // Not provided by HF Router
180			}
181
182			hfProvider.Models = append(hfProvider.Models, m)
183			fmt.Printf("Added model %s with context window %d from provider %s\n",
184				modelID, contextLength, provider.Provider)
185		}
186	}
187
188	slices.SortFunc(hfProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
189		return strings.Compare(a.Name, b.Name)
190	})
191
192	// Save the JSON in internal/providers/configs/huggingface.json
193	data, err := json.MarshalIndent(hfProvider, "", "  ")
194	if err != nil {
195		log.Fatal("Error marshaling Hugging Face provider:", err)
196	}
197
198	if err := os.WriteFile("internal/providers/configs/huggingface.json", data, 0o600); err != nil {
199		log.Fatal("Error writing Hugging Face provider config:", err)
200	}
201
202	fmt.Printf("Generated huggingface.json with %d models\n", len(hfProvider.Models))
203}