main.go

  1// Package main provides a command-line tool to fetch models from Hugging Face Router
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strings"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19)
 20
 21// SupportedProviders defines which providers we want to support.
 22// Add or remove providers from this slice to control which ones are included.
 23var SupportedProviders = []string{
 24	// "together", // Multiple issues
 25	"fireworks-ai",
 26	//"nebius",
 27	// "novita", // Usage report is wrong
 28	"groq",
 29	"cerebras",
 30	// "hyperbolic",
 31	// "nscale",
 32	// "sambanova",
 33	// "cohere",
 34	"hf-inference",
 35}
 36
 37// Model represents a model from the Hugging Face Router API.
 38type Model struct {
 39	ID        string     `json:"id"`
 40	Object    string     `json:"object"`
 41	Created   int64      `json:"created"`
 42	OwnedBy   string     `json:"owned_by"`
 43	Providers []Provider `json:"providers"`
 44}
 45
 46// Provider represents a provider configuration for a model.
 47type Provider struct {
 48	Provider                 string   `json:"provider"`
 49	Status                   string   `json:"status"`
 50	ContextLength            int64    `json:"context_length,omitempty"`
 51	Pricing                  *Pricing `json:"pricing,omitempty"`
 52	SupportsTools            bool     `json:"supports_tools"`
 53	SupportsStructuredOutput bool     `json:"supports_structured_output"`
 54}
 55
 56// Pricing contains the pricing information for a provider.
 57type Pricing struct {
 58	Input  float64 `json:"input"`
 59	Output float64 `json:"output"`
 60}
 61
 62// ModelsResponse is the response structure for the Hugging Face Router models API.
 63type ModelsResponse struct {
 64	Object string  `json:"object"`
 65	Data   []Model `json:"data"`
 66}
 67
 68func fetchHuggingFaceModels() (*ModelsResponse, error) {
 69	client := &http.Client{Timeout: 30 * time.Second}
 70	req, _ := http.NewRequestWithContext(
 71		context.Background(),
 72		"GET",
 73		"https://router.huggingface.co/v1/models",
 74		nil,
 75	)
 76	req.Header.Set("User-Agent", "Crush-Client/1.0")
 77	resp, err := client.Do(req)
 78	if err != nil {
 79		return nil, err //nolint:wrapcheck
 80	}
 81	defer resp.Body.Close() //nolint:errcheck
 82	if resp.StatusCode != 200 {
 83		body, _ := io.ReadAll(resp.Body)
 84		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
 85	}
 86	var mr ModelsResponse
 87	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
 88		return nil, err //nolint:wrapcheck
 89	}
 90	return &mr, nil
 91}
 92
 93// findContextWindow looks for a context window from any provider for the given model.
 94func findContextWindow(model Model) int64 {
 95	for _, provider := range model.Providers {
 96		if provider.ContextLength > 0 {
 97			return provider.ContextLength
 98		}
 99	}
100	return 0
101}
102
103// WARN: DO NOT USE
104// for now we have a subset list of models we use.
105func main() {
106	modelsResp, err := fetchHuggingFaceModels()
107	if err != nil {
108		log.Fatal("Error fetching Hugging Face models:", err)
109	}
110
111	hfProvider := catwalk.Provider{
112		Name:                "Hugging Face",
113		ID:                  catwalk.InferenceProviderHuggingFace,
114		APIKey:              "$HF_TOKEN",
115		APIEndpoint:         "https://router.huggingface.co/v1",
116		Type:                catwalk.TypeOpenAICompat,
117		DefaultLargeModelID: "moonshotai/Kimi-K2-Instruct-0905:groq",
118		DefaultSmallModelID: "openai/gpt-oss-20b:groq",
119		Models:              []catwalk.Model{},
120		DefaultHeaders: map[string]string{
121			"HTTP-Referer": "https://charm.land",
122			"X-Title":      "Crush",
123		},
124	}
125
126	for _, model := range modelsResp.Data {
127		// Find context window from any provider for this model
128		fallbackContextLength := findContextWindow(model)
129		if fallbackContextLength == 0 {
130			fmt.Printf("Skipping model %s - no context window found in any provider\n", model.ID)
131			continue
132		}
133
134		for _, provider := range model.Providers {
135			// Skip unsupported providers
136			if !slices.Contains(SupportedProviders, provider.Provider) {
137				continue
138			}
139
140			// Skip providers that don't support tools
141			if !provider.SupportsTools {
142				continue
143			}
144
145			// Skip non-live providers
146			if provider.Status != "live" {
147				continue
148			}
149
150			// Create model with provider-specific ID and name
151			modelID := fmt.Sprintf("%s:%s", model.ID, provider.Provider)
152			modelName := fmt.Sprintf("%s (%s)", model.ID, provider.Provider)
153
154			// Use provider's context length, or fallback if not available
155			contextLength := provider.ContextLength
156			if contextLength == 0 {
157				contextLength = fallbackContextLength
158			}
159
160			// Calculate pricing (convert from per-token to per-1M tokens)
161			var costPer1MIn, costPer1MOut float64
162			if provider.Pricing != nil {
163				costPer1MIn = math.Round(provider.Pricing.Input*1e5) / 1e5
164				costPer1MOut = math.Round(provider.Pricing.Output*1e5) / 1e5
165			}
166
167			// Set default max tokens (conservative estimate)
168			defaultMaxTokens := min(contextLength/4, 8192)
169
170			m := catwalk.Model{
171				ID:                 modelID,
172				Name:               modelName,
173				CostPer1MIn:        costPer1MIn,
174				CostPer1MOut:       costPer1MOut,
175				CostPer1MInCached:  0, // Not provided by HF Router
176				CostPer1MOutCached: 0, // Not provided by HF Router
177				ContextWindow:      contextLength,
178				DefaultMaxTokens:   defaultMaxTokens,
179				CanReason:          false, // Not provided by HF Router
180				SupportsImages:     false, // Not provided by HF Router
181			}
182
183			hfProvider.Models = append(hfProvider.Models, m)
184			fmt.Printf("Added model %s with context window %d from provider %s\n",
185				modelID, contextLength, provider.Provider)
186		}
187	}
188
189	slices.SortFunc(hfProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
190		return strings.Compare(a.Name, b.Name)
191	})
192
193	// Save the JSON in internal/providers/configs/huggingface.json
194	data, err := json.MarshalIndent(hfProvider, "", "  ")
195	if err != nil {
196		log.Fatal("Error marshaling Hugging Face provider:", err)
197	}
198
199	if err := os.WriteFile("internal/providers/configs/huggingface.json", data, 0o600); err != nil {
200		log.Fatal("Error writing Hugging Face provider config:", err)
201	}
202
203	fmt.Printf("Generated huggingface.json with %d models\n", len(hfProvider.Models))
204}