// Package main provides a command-line tool that fetches the model list from the
// Hugging Face Router and generates the catwalk provider configuration file
// (huggingface.json) for it.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "net/http"
12 "os"
13 "slices"
14 "strings"
15 "time"
16
17 "github.com/charmbracelet/catwalk/pkg/catwalk"
18)
19
// SupportedProviders lists the inference providers, identified by the Hugging
// Face Router's "provider" field, whose models are written to the generated
// config. Edit this list to control which providers are included.
//
// Currently disabled:
//   - together: multiple issues
//   - novita: usage report is wrong
//   - nebius, hyperbolic, nscale, sambanova, cohere: not enabled (no reason recorded)
var SupportedProviders = []string{"fireworks-ai", "groq", "cerebras", "hf-inference"}
35
// Model represents a model from the Hugging Face Router API, i.e. one entry of
// ModelsResponse.Data.
//
// Object, Created, and OwnedBy are decoded but not used by this tool.
type Model struct {
	// ID is the router's model identifier; the generator appends
	// ":<provider>" to it to form each generated model ID.
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Providers lists the providers offering this model. Only entries named
	// in SupportedProviders that are "live" and support tools are emitted.
	Providers []Provider `json:"providers"`
}
44
// Provider represents one provider's offering of a model, as listed in
// Model.Providers.
type Provider struct {
	// Provider is the provider name, matched against SupportedProviders.
	Provider string `json:"provider"`
	// Status must equal "live" for the entry to be emitted.
	Status string `json:"status"`
	// ContextLength is this provider's context window; 0 means it was not
	// reported, in which case the first positive value from any provider of
	// the same model is used instead.
	ContextLength int64 `json:"context_length,omitempty"`
	// Pricing is nil when the response omits it; costs then default to 0.
	Pricing *Pricing `json:"pricing,omitempty"`
	// SupportsTools must be true for the entry to be emitted.
	SupportsTools bool `json:"supports_tools"`
	// SupportsStructuredOutput is decoded but not currently used.
	SupportsStructuredOutput bool `json:"supports_structured_output"`
}
54
// Pricing contains the pricing information for a provider.
//
// The values are copied unchanged into catwalk's CostPer1MIn/CostPer1MOut, so
// they are presumably already expressed per 1M tokens.
// NOTE(review): confirm units against the HF Router API.
type Pricing struct {
	// Input is the input-token price.
	Input float64 `json:"input"`
	// Output is the output-token price.
	Output float64 `json:"output"`
}
60
// ModelsResponse is the response structure for the Hugging Face Router models
// API (GET https://router.huggingface.co/v1/models).
type ModelsResponse struct {
	// Object is the top-level "object" field; it is decoded but not used.
	Object string `json:"object"`
	// Data holds one entry per model.
	Data []Model `json:"data"`
}
66
67func fetchHuggingFaceModels() (*ModelsResponse, error) {
68 client := &http.Client{Timeout: 30 * time.Second}
69 req, _ := http.NewRequestWithContext(
70 context.Background(),
71 "GET",
72 "https://router.huggingface.co/v1/models",
73 nil,
74 )
75 req.Header.Set("User-Agent", "Crush-Client/1.0")
76 resp, err := client.Do(req)
77 if err != nil {
78 return nil, err //nolint:wrapcheck
79 }
80 defer resp.Body.Close() //nolint:errcheck
81 if resp.StatusCode != 200 {
82 body, _ := io.ReadAll(resp.Body)
83 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
84 }
85 var mr ModelsResponse
86 if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
87 return nil, err //nolint:wrapcheck
88 }
89 return &mr, nil
90}
91
92// findContextWindow looks for a context window from any provider for the given model.
93func findContextWindow(model Model) int64 {
94 for _, provider := range model.Providers {
95 if provider.ContextLength > 0 {
96 return provider.ContextLength
97 }
98 }
99 return 0
100}
101
102// WARN: DO NOT USE
103// for now we have a subset list of models we use.
104func main() {
105 modelsResp, err := fetchHuggingFaceModels()
106 if err != nil {
107 log.Fatal("Error fetching Hugging Face models:", err)
108 }
109
110 hfProvider := catwalk.Provider{
111 Name: "Hugging Face",
112 ID: catwalk.InferenceProviderHuggingFace,
113 APIKey: "$HF_TOKEN",
114 APIEndpoint: "https://router.huggingface.co/v1",
115 Type: catwalk.TypeOpenAI,
116 DefaultLargeModelID: "moonshotai/Kimi-K2-Instruct-0905:groq",
117 DefaultSmallModelID: "openai/gpt-oss-20b",
118 Models: []catwalk.Model{},
119 DefaultHeaders: map[string]string{
120 "HTTP-Referer": "https://charm.land",
121 "X-Title": "Crush",
122 },
123 }
124
125 for _, model := range modelsResp.Data {
126 // Find context window from any provider for this model
127 fallbackContextLength := findContextWindow(model)
128 if fallbackContextLength == 0 {
129 fmt.Printf("Skipping model %s - no context window found in any provider\n", model.ID)
130 continue
131 }
132
133 for _, provider := range model.Providers {
134 // Skip unsupported providers
135 if !slices.Contains(SupportedProviders, provider.Provider) {
136 continue
137 }
138
139 // Skip providers that don't support tools
140 if !provider.SupportsTools {
141 continue
142 }
143
144 // Skip non-live providers
145 if provider.Status != "live" {
146 continue
147 }
148
149 // Create model with provider-specific ID and name
150 modelID := fmt.Sprintf("%s:%s", model.ID, provider.Provider)
151 modelName := fmt.Sprintf("%s (%s)", model.ID, provider.Provider)
152
153 // Use provider's context length, or fallback if not available
154 contextLength := provider.ContextLength
155 if contextLength == 0 {
156 contextLength = fallbackContextLength
157 }
158
159 // Calculate pricing (convert from per-token to per-1M tokens)
160 var costPer1MIn, costPer1MOut float64
161 if provider.Pricing != nil {
162 costPer1MIn = provider.Pricing.Input
163 costPer1MOut = provider.Pricing.Output
164 }
165
166 // Set default max tokens (conservative estimate)
167 defaultMaxTokens := min(contextLength/4, 8192)
168
169 m := catwalk.Model{
170 ID: modelID,
171 Name: modelName,
172 CostPer1MIn: costPer1MIn,
173 CostPer1MOut: costPer1MOut,
174 CostPer1MInCached: 0, // Not provided by HF Router
175 CostPer1MOutCached: 0, // Not provided by HF Router
176 ContextWindow: contextLength,
177 DefaultMaxTokens: defaultMaxTokens,
178 CanReason: false, // Not provided by HF Router
179 SupportsImages: false, // Not provided by HF Router
180 }
181
182 hfProvider.Models = append(hfProvider.Models, m)
183 fmt.Printf("Added model %s with context window %d from provider %s\n",
184 modelID, contextLength, provider.Provider)
185 }
186 }
187
188 slices.SortFunc(hfProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
189 return strings.Compare(a.Name, b.Name)
190 })
191
192 // Save the JSON in internal/providers/configs/huggingface.json
193 data, err := json.MarshalIndent(hfProvider, "", " ")
194 if err != nil {
195 log.Fatal("Error marshaling Hugging Face provider:", err)
196 }
197
198 if err := os.WriteFile("internal/providers/configs/huggingface.json", data, 0o600); err != nil {
199 log.Fatal("Error writing Hugging Face provider config:", err)
200 }
201
202 fmt.Printf("Generated huggingface.json with %d models\n", len(hfProvider.Models))
203}