// Package main provides a command-line tool that fetches the model catalogue from
// the Hugging Face Router API and generates the Hugging Face provider configuration file.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "math"
12 "net/http"
13 "os"
14 "slices"
15 "strings"
16 "time"
17
18 "charm.land/catwalk/pkg/catwalk"
19)
20
// SupportedProviders is the allow-list of Hugging Face Router inference
// providers whose models are written to the generated config; model entries
// served by any other provider are ignored. Edit this list to control which
// providers are included.
//
// Currently disabled:
//   - together: multiple issues
//   - novita: usage report is wrong
//   - nebius, hyperbolic, nscale, sambanova, cohere
var SupportedProviders = []string{
	"fireworks-ai",
	"groq",
	"cerebras",
	"hf-inference",
}
36
// Model represents a model from the Hugging Face Router API. It is decoded
// from one element of the "data" array of the /v1/models response.
type Model struct {
	// ID is the model identifier; it prefixes the generated model's ID and name.
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Providers lists the inference providers serving this model; one
	// generated entry may be emitted per eligible provider.
	Providers []Provider `json:"providers"`
}
45
// Provider represents a provider configuration for a model, as listed in a
// Model's "providers" array.
type Provider struct {
	// Provider is the provider name; it is matched against SupportedProviders
	// and appended to the generated model ID and name.
	Provider string `json:"provider"`
	// Status must equal "live" for the provider to be exported.
	Status string `json:"status"`
	// ContextLength is the provider's context window; 0 (absent) means
	// unknown, in which case another provider's value is used as a fallback.
	ContextLength int64 `json:"context_length,omitempty"`
	// Pricing is nil when the API omits it; costs are then left at 0.
	Pricing *Pricing `json:"pricing,omitempty"`
	// SupportsTools must be true for the provider to be exported.
	SupportsTools bool `json:"supports_tools"`
	// SupportsStructuredOutput is decoded but not consulted by this generator.
	SupportsStructuredOutput bool `json:"supports_structured_output"`
}
55
// Pricing contains the pricing information for a provider.
//
// Input and Output are copied (rounded to 5 decimal places) into the generated
// model's per-1M-token input/output costs without unit conversion.
// NOTE(review): this assumes the router already quotes USD per 1M tokens — confirm.
type Pricing struct {
	Input  float64 `json:"input"`
	Output float64 `json:"output"`
}
61
// ModelsResponse is the response structure for the Hugging Face Router models
// API (GET https://router.huggingface.co/v1/models).
type ModelsResponse struct {
	Object string `json:"object"`
	// Data holds one entry per model in the router's catalogue.
	Data []Model `json:"data"`
}
67
68func fetchHuggingFaceModels() (*ModelsResponse, error) {
69 client := &http.Client{Timeout: 30 * time.Second}
70 req, _ := http.NewRequestWithContext(
71 context.Background(),
72 "GET",
73 "https://router.huggingface.co/v1/models",
74 nil,
75 )
76 req.Header.Set("User-Agent", "Crush-Client/1.0")
77 resp, err := client.Do(req)
78 if err != nil {
79 return nil, err //nolint:wrapcheck
80 }
81 defer resp.Body.Close() //nolint:errcheck
82 if resp.StatusCode != 200 {
83 body, _ := io.ReadAll(resp.Body)
84 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
85 }
86 var mr ModelsResponse
87 if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
88 return nil, err //nolint:wrapcheck
89 }
90 return &mr, nil
91}
92
93// findContextWindow looks for a context window from any provider for the given model.
94func findContextWindow(model Model) int64 {
95 for _, provider := range model.Providers {
96 if provider.ContextLength > 0 {
97 return provider.ContextLength
98 }
99 }
100 return 0
101}
102
103// WARN: DO NOT USE
104// for now we have a subset list of models we use.
105func main() {
106 modelsResp, err := fetchHuggingFaceModels()
107 if err != nil {
108 log.Fatal("Error fetching Hugging Face models:", err)
109 }
110
111 hfProvider := catwalk.Provider{
112 Name: "Hugging Face",
113 ID: catwalk.InferenceProviderHuggingFace,
114 APIKey: "$HF_TOKEN",
115 APIEndpoint: "https://router.huggingface.co/v1",
116 Type: catwalk.TypeOpenAICompat,
117 DefaultLargeModelID: "moonshotai/Kimi-K2.5:fireworks-ai",
118 DefaultSmallModelID: "openai/gpt-oss-20b:groq",
119 Models: []catwalk.Model{},
120 DefaultHeaders: map[string]string{
121 "HTTP-Referer": "https://charm.land",
122 "X-Title": "Crush",
123 },
124 }
125
126 for _, model := range modelsResp.Data {
127 // Find context window from any provider for this model
128 fallbackContextLength := findContextWindow(model)
129 if fallbackContextLength == 0 {
130 fmt.Printf("Skipping model %s - no context window found in any provider\n", model.ID)
131 continue
132 }
133
134 for _, provider := range model.Providers {
135 // Skip unsupported providers
136 if !slices.Contains(SupportedProviders, provider.Provider) {
137 continue
138 }
139
140 // Skip providers that don't support tools
141 if !provider.SupportsTools {
142 continue
143 }
144
145 // Skip non-live providers
146 if provider.Status != "live" {
147 continue
148 }
149
150 // Create model with provider-specific ID and name
151 modelID := fmt.Sprintf("%s:%s", model.ID, provider.Provider)
152 modelName := fmt.Sprintf("%s (%s)", model.ID, provider.Provider)
153
154 // Use provider's context length, or fallback if not available
155 contextLength := provider.ContextLength
156 if contextLength == 0 {
157 contextLength = fallbackContextLength
158 }
159
160 // Calculate pricing (convert from per-token to per-1M tokens)
161 var costPer1MIn, costPer1MOut float64
162 if provider.Pricing != nil {
163 costPer1MIn = math.Round(provider.Pricing.Input*1e5) / 1e5
164 costPer1MOut = math.Round(provider.Pricing.Output*1e5) / 1e5
165 }
166
167 // Set default max tokens (conservative estimate)
168 defaultMaxTokens := min(contextLength/4, 8192)
169
170 m := catwalk.Model{
171 ID: modelID,
172 Name: modelName,
173 CostPer1MIn: costPer1MIn,
174 CostPer1MOut: costPer1MOut,
175 CostPer1MInCached: 0, // Not provided by HF Router
176 CostPer1MOutCached: 0, // Not provided by HF Router
177 ContextWindow: contextLength,
178 DefaultMaxTokens: defaultMaxTokens,
179 CanReason: false, // Not provided by HF Router
180 SupportsImages: false, // Not provided by HF Router
181 }
182
183 hfProvider.Models = append(hfProvider.Models, m)
184 fmt.Printf("Added model %s with context window %d from provider %s\n",
185 modelID, contextLength, provider.Provider)
186 }
187 }
188
189 slices.SortFunc(hfProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
190 return strings.Compare(a.Name, b.Name)
191 })
192
193 // Save the JSON in internal/providers/configs/huggingface.json
194 data, err := json.MarshalIndent(hfProvider, "", " ")
195 if err != nil {
196 log.Fatal("Error marshaling Hugging Face provider:", err)
197 }
198
199 if err := os.WriteFile("internal/providers/configs/huggingface.json", data, 0o600); err != nil {
200 log.Fatal("Error writing Hugging Face provider config:", err)
201 }
202
203 fmt.Printf("Generated huggingface.json with %d models\n", len(hfProvider.Models))
204}