feat: support hugging face inference router (#59)

Kujtim Hoxha and Crush created

* feat: support hugging face inference router

* chore: lint

* chore: add huggingface script to nightly update

* fix: remove HF prefix from types in huggingface package

Remove redundant HF prefix from type names since they are already in the huggingface package namespace.

💖 Generated with Crush
Co-Authored-By: Crush <crush@charm.land>

* chore: small changes

* chore: only add models that seem to be working

* chore: lint

---------

Co-authored-by: Crush <crush@charm.land>

Change summary

.github/workflows/update.yml                |   2 
cmd/huggingface/main.go                     | 203 ++++++++++++++
internal/providers/configs/huggingface.json | 314 +++++++++++++++++++++++
internal/providers/providers.go             |   8 
pkg/catwalk/provider.go                     |  30 +-
5 files changed, 543 insertions(+), 14 deletions(-)

Detailed changes

.github/workflows/update.yml 🔗

@@ -18,6 +18,8 @@ jobs:
         with:
           go-version-file: go.mod
       - run: go run ./cmd/openrouter/main.go
+      # we need to add this back when we know that the providers/models all work
+      # - run: go run ./cmd/huggingface/main.go
       - uses: stefanzweifel/git-auto-commit-action@778341af668090896ca464160c2def5d1d1a3eb0 # v5
         with:
           commit_message: "chore: auto-update generated files"

cmd/huggingface/main.go 🔗

@@ -0,0 +1,203 @@
+// Package main provides a command-line tool to fetch models from Hugging Face Router
+// and generate a configuration file for the provider.
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"slices"
+	"strings"
+	"time"
+
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
+)
+
+// SupportedProviders defines which providers we want to support.
+// Add or remove providers from this slice to control which ones are included.
+var SupportedProviders = []string{
+	// "together", // Multiple issues
+	"fireworks-ai",
+	//"nebius",
+	// "novita", // Usage report is wrong
+	"groq",
+	"cerebras",
+	// "hyperbolic",
+	// "nscale",
+	// "sambanova",
+	// "cohere",
+	"hf-inference",
+}
+
+// Model represents a model from the Hugging Face Router API.
+type Model struct {
+	ID        string     `json:"id"`
+	Object    string     `json:"object"`
+	Created   int64      `json:"created"`
+	OwnedBy   string     `json:"owned_by"`
+	Providers []Provider `json:"providers"`
+}
+
+// Provider represents a provider configuration for a model.
+type Provider struct {
+	Provider                 string   `json:"provider"`
+	Status                   string   `json:"status"`
+	ContextLength            int64    `json:"context_length,omitempty"`
+	Pricing                  *Pricing `json:"pricing,omitempty"`
+	SupportsTools            bool     `json:"supports_tools"`
+	SupportsStructuredOutput bool     `json:"supports_structured_output"`
+}
+
+// Pricing contains the pricing information for a provider.
+type Pricing struct {
+	Input  float64 `json:"input"`
+	Output float64 `json:"output"`
+}
+
+// ModelsResponse is the response structure for the Hugging Face Router models API.
+type ModelsResponse struct {
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
+}
+
+func fetchHuggingFaceModels() (*ModelsResponse, error) {
+	client := &http.Client{Timeout: 30 * time.Second}
+	req, _ := http.NewRequestWithContext(
+		context.Background(),
+		"GET",
+		"https://router.huggingface.co/v1/models",
+		nil,
+	)
+	req.Header.Set("User-Agent", "Crush-Client/1.0")
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err //nolint:wrapcheck
+	}
+	defer resp.Body.Close() //nolint:errcheck
+	if resp.StatusCode != 200 {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
+	}
+	var mr ModelsResponse
+	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
+		return nil, err //nolint:wrapcheck
+	}
+	return &mr, nil
+}
+
+// findContextWindow looks for a context window from any provider for the given model.
+func findContextWindow(model Model) int64 {
+	for _, provider := range model.Providers {
+		if provider.ContextLength > 0 {
+			return provider.ContextLength
+		}
+	}
+	return 0
+}
+
+// WARN: DO NOT USE
+// for now we have a subset list of models we use.
+func main() {
+	modelsResp, err := fetchHuggingFaceModels()
+	if err != nil {
+		log.Fatal("Error fetching Hugging Face models:", err)
+	}
+
+	hfProvider := catwalk.Provider{
+		Name:                "Hugging Face",
+		ID:                  catwalk.InferenceProviderHuggingFace,
+		APIKey:              "$HF_TOKEN",
+		APIEndpoint:         "https://router.huggingface.co/v1",
+		Type:                catwalk.TypeOpenAI,
+		DefaultLargeModelID: "moonshotai/Kimi-K2-Instruct-0905:groq",
+		DefaultSmallModelID: "openai/gpt-oss-20b",
+		Models:              []catwalk.Model{},
+		DefaultHeaders: map[string]string{
+			"HTTP-Referer": "https://charm.land",
+			"X-Title":      "Crush",
+		},
+	}
+
+	for _, model := range modelsResp.Data {
+		// Find context window from any provider for this model
+		fallbackContextLength := findContextWindow(model)
+		if fallbackContextLength == 0 {
+			fmt.Printf("Skipping model %s - no context window found in any provider\n", model.ID)
+			continue
+		}
+
+		for _, provider := range model.Providers {
+			// Skip unsupported providers
+			if !slices.Contains(SupportedProviders, provider.Provider) {
+				continue
+			}
+
+			// Skip providers that don't support tools
+			if !provider.SupportsTools {
+				continue
+			}
+
+			// Skip non-live providers
+			if provider.Status != "live" {
+				continue
+			}
+
+			// Create model with provider-specific ID and name
+			modelID := fmt.Sprintf("%s:%s", model.ID, provider.Provider)
+			modelName := fmt.Sprintf("%s (%s)", model.ID, provider.Provider)
+
+			// Use provider's context length, or fallback if not available
+			contextLength := provider.ContextLength
+			if contextLength == 0 {
+				contextLength = fallbackContextLength
+			}
+
+			// Calculate pricing (convert from per-token to per-1M tokens)
+			var costPer1MIn, costPer1MOut float64
+			if provider.Pricing != nil {
+				costPer1MIn = provider.Pricing.Input
+				costPer1MOut = provider.Pricing.Output
+			}
+
+			// Set default max tokens (conservative estimate)
+			defaultMaxTokens := min(contextLength/4, 8192)
+
+			m := catwalk.Model{
+				ID:                 modelID,
+				Name:               modelName,
+				CostPer1MIn:        costPer1MIn,
+				CostPer1MOut:       costPer1MOut,
+				CostPer1MInCached:  0, // Not provided by HF Router
+				CostPer1MOutCached: 0, // Not provided by HF Router
+				ContextWindow:      contextLength,
+				DefaultMaxTokens:   defaultMaxTokens,
+				CanReason:          false, // Not provided by HF Router
+				SupportsImages:     false, // Not provided by HF Router
+			}
+
+			hfProvider.Models = append(hfProvider.Models, m)
+			fmt.Printf("Added model %s with context window %d from provider %s\n",
+				modelID, contextLength, provider.Provider)
+		}
+	}
+
+	slices.SortFunc(hfProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
+		return strings.Compare(a.Name, b.Name)
+	})
+
+	// Save the JSON in internal/providers/configs/huggingface.json
+	data, err := json.MarshalIndent(hfProvider, "", "  ")
+	if err != nil {
+		log.Fatal("Error marshaling Hugging Face provider:", err)
+	}
+
+	if err := os.WriteFile("internal/providers/configs/huggingface.json", data, 0o600); err != nil {
+		log.Fatal("Error writing Hugging Face provider config:", err)
+	}
+
+	fmt.Printf("Generated huggingface.json with %d models\n", len(hfProvider.Models))
+}

internal/providers/configs/huggingface.json 🔗

@@ -0,0 +1,314 @@
+{
+  "name": "Hugging Face",
+  "id": "huggingface",
+  "api_key": "$HF_TOKEN",
+  "api_endpoint": "https://router.huggingface.co/v1",
+  "type": "openai",
+  "default_large_model_id": "moonshotai/Kimi-K2-Instruct-0905:groq",
+  "default_small_model_id": "openai/gpt-oss-20b",
+  "models": [
+    {
+      "id": "Qwen/Qwen3-235B-A22B:fireworks-ai",
+      "name": "Qwen/Qwen3-235B-A22B (fireworks-ai)",
+      "cost_per_1m_in": 0.22,
+      "cost_per_1m_out": 0.88,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "Qwen/Qwen3-235B-A22B-Instruct-2507:fireworks-ai",
+      "name": "Qwen/Qwen3-235B-A22B-Instruct-2507 (fireworks-ai)",
+      "cost_per_1m_in": 0.22,
+      "cost_per_1m_out": 0.88,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 262144,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "Qwen/Qwen3-235B-A22B-Thinking-2507:fireworks-ai",
+      "name": "Qwen/Qwen3-235B-A22B-Thinking-2507 (fireworks-ai)",
+      "cost_per_1m_in": 0.22,
+      "cost_per_1m_out": 0.88,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 262144,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "Qwen/Qwen3-30B-A3B:fireworks-ai",
+      "name": "Qwen/Qwen3-30B-A3B (fireworks-ai)",
+      "cost_per_1m_in": 0.15,
+      "cost_per_1m_out": 0.6,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "Qwen/Qwen3-Coder-480B-A35B-Instruct:cerebras",
+      "name": "Qwen/Qwen3-Coder-480B-A35B-Instruct (cerebras)",
+      "cost_per_1m_in": 2,
+      "cost_per_1m_out": 2,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 262144,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "Qwen/Qwen3-Coder-480B-A35B-Instruct:fireworks-ai",
+      "name": "Qwen/Qwen3-Coder-480B-A35B-Instruct (fireworks-ai)",
+      "cost_per_1m_in": 0.45,
+      "cost_per_1m_out": 1.8,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 262144,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "deepseek-ai/DeepSeek-V3-0324:fireworks-ai",
+      "name": "deepseek-ai/DeepSeek-V3-0324 (fireworks-ai)",
+      "cost_per_1m_in": 0.9,
+      "cost_per_1m_out": 0.9,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 163840,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "deepseek-ai/DeepSeek-V3.1:fireworks-ai",
+      "name": "deepseek-ai/DeepSeek-V3.1 (fireworks-ai)",
+      "cost_per_1m_in": 0,
+      "cost_per_1m_out": 0,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 163840,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "meta-llama/Llama-3.1-70B-Instruct:fireworks-ai",
+      "name": "meta-llama/Llama-3.1-70B-Instruct (fireworks-ai)",
+      "cost_per_1m_in": 0.9,
+      "cost_per_1m_out": 0.9,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "meta-llama/Llama-3.3-70B-Instruct:cerebras",
+      "name": "meta-llama/Llama-3.3-70B-Instruct (cerebras)",
+      "cost_per_1m_in": 0.85,
+      "cost_per_1m_out": 1.2,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "meta-llama/Llama-3.3-70B-Instruct:groq",
+      "name": "meta-llama/Llama-3.3-70B-Instruct (groq)",
+      "cost_per_1m_in": 0.59,
+      "cost_per_1m_out": 0.79,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct:fireworks-ai",
+      "name": "meta-llama/Llama-4-Maverick-17B-128E-Instruct (fireworks-ai)",
+      "cost_per_1m_in": 0.22,
+      "cost_per_1m_out": 0.88,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 1048576,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct:groq",
+      "name": "meta-llama/Llama-4-Maverick-17B-128E-Instruct (groq)",
+      "cost_per_1m_in": 0.2,
+      "cost_per_1m_out": 0.6,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "meta-llama/Llama-4-Scout-17B-16E-Instruct:groq",
+      "name": "meta-llama/Llama-4-Scout-17B-16E-Instruct (groq)",
+      "cost_per_1m_in": 0.11,
+      "cost_per_1m_out": 0.34,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "moonshotai/Kimi-K2-Instruct:fireworks-ai",
+      "name": "moonshotai/Kimi-K2-Instruct (fireworks-ai)",
+      "cost_per_1m_in": 0.6,
+      "cost_per_1m_out": 2.5,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "moonshotai/Kimi-K2-Instruct-0905:groq",
+      "name": "moonshotai/Kimi-K2-Instruct-0905 (groq)",
+      "cost_per_1m_in": 0,
+      "cost_per_1m_out": 0,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 262144,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "openai/gpt-oss-120b:cerebras",
+      "name": "openai/gpt-oss-120b (cerebras)",
+      "cost_per_1m_in": 0.25,
+      "cost_per_1m_out": 0.69,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "openai/gpt-oss-120b:fireworks-ai",
+      "name": "openai/gpt-oss-120b (fireworks-ai)",
+      "cost_per_1m_in": 0.15,
+      "cost_per_1m_out": 0.6,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "openai/gpt-oss-120b:groq",
+      "name": "openai/gpt-oss-120b (groq)",
+      "cost_per_1m_in": 0.15,
+      "cost_per_1m_out": 0.75,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "openai/gpt-oss-20b:fireworks-ai",
+      "name": "openai/gpt-oss-20b (fireworks-ai)",
+      "cost_per_1m_in": 0.05,
+      "cost_per_1m_out": 0.2,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "openai/gpt-oss-20b:groq",
+      "name": "openai/gpt-oss-20b (groq)",
+      "cost_per_1m_in": 0.1,
+      "cost_per_1m_out": 0.5,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "zai-org/GLM-4.5:fireworks-ai",
+      "name": "zai-org/GLM-4.5 (fireworks-ai)",
+      "cost_per_1m_in": 0.55,
+      "cost_per_1m_out": 2.19,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    },
+    {
+      "id": "zai-org/GLM-4.5-Air:fireworks-ai",
+      "name": "zai-org/GLM-4.5-Air (fireworks-ai)",
+      "cost_per_1m_in": 0.22,
+      "cost_per_1m_out": 0.88,
+      "cost_per_1m_in_cached": 0,
+      "cost_per_1m_out_cached": 0,
+      "context_window": 131072,
+      "default_max_tokens": 8192,
+      "can_reason": false,
+      "has_reasoning_efforts": false,
+      "supports_attachments": false
+    }
+  ],
+  "default_headers": {
+    "HTTP-Referer": "https://charm.land",
+    "X-Title": "Crush"
+  }
+}

internal/providers/providers.go 🔗

@@ -54,6 +54,9 @@ var chutesConfig []byte
 //go:embed configs/deepseek.json
 var deepSeekConfig []byte
 
+//go:embed configs/huggingface.json
+var huggingFaceConfig []byte
+
 // ProviderFunc is a function that returns a Provider.
 type ProviderFunc func() catwalk.Provider
 
@@ -73,6 +76,7 @@ var providerRegistry = []ProviderFunc{
 	veniceProvider,
 	chutesProvider,
 	deepSeekProvider,
+	huggingFaceProvider,
 }
 
 // GetAll returns all registered providers.
@@ -152,3 +156,7 @@ func chutesProvider() catwalk.Provider {
 func deepSeekProvider() catwalk.Provider {
 	return loadProviderFromConfig(deepSeekConfig)
 }
+
+func huggingFaceProvider() catwalk.Provider {
+	return loadProviderFromConfig(huggingFaceConfig)
+}

pkg/catwalk/provider.go 🔗

@@ -18,20 +18,21 @@ type InferenceProvider string
 
 // All the inference providers supported by the system.
 const (
-	InferenceProviderOpenAI     InferenceProvider = "openai"
-	InferenceProviderAnthropic  InferenceProvider = "anthropic"
-	InferenceProviderGemini     InferenceProvider = "gemini"
-	InferenceProviderAzure      InferenceProvider = "azure"
-	InferenceProviderBedrock    InferenceProvider = "bedrock"
-	InferenceProviderVertexAI   InferenceProvider = "vertexai"
-	InferenceProviderXAI        InferenceProvider = "xai"
-	InferenceProviderZAI        InferenceProvider = "zai"
-	InferenceProviderGROQ       InferenceProvider = "groq"
-	InferenceProviderOpenRouter InferenceProvider = "openrouter"
-	InferenceProviderLambda     InferenceProvider = "lambda"
-	InferenceProviderCerebras   InferenceProvider = "cerebras"
-	InferenceProviderVenice     InferenceProvider = "venice"
-	InferenceProviderChutes     InferenceProvider = "chutes"
+	InferenceProviderOpenAI      InferenceProvider = "openai"
+	InferenceProviderAnthropic   InferenceProvider = "anthropic"
+	InferenceProviderGemini      InferenceProvider = "gemini"
+	InferenceProviderAzure       InferenceProvider = "azure"
+	InferenceProviderBedrock     InferenceProvider = "bedrock"
+	InferenceProviderVertexAI    InferenceProvider = "vertexai"
+	InferenceProviderXAI         InferenceProvider = "xai"
+	InferenceProviderZAI         InferenceProvider = "zai"
+	InferenceProviderGROQ        InferenceProvider = "groq"
+	InferenceProviderOpenRouter  InferenceProvider = "openrouter"
+	InferenceProviderLambda      InferenceProvider = "lambda"
+	InferenceProviderCerebras    InferenceProvider = "cerebras"
+	InferenceProviderVenice      InferenceProvider = "venice"
+	InferenceProviderChutes      InferenceProvider = "chutes"
+	InferenceProviderHuggingFace InferenceProvider = "huggingface"
 )
 
 // Provider represents an AI provider configuration.
@@ -80,5 +81,6 @@ func KnownProviders() []InferenceProvider {
 		InferenceProviderCerebras,
 		InferenceProviderVenice,
 		InferenceProviderChutes,
+		InferenceProviderHuggingFace,
 	}
 }