From c152f2d87f673ee578fb83355c535f48896fe507 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 23 Sep 2025 15:52:16 +0200 Subject: [PATCH] feat: support hugging face inference router (#59) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support hugging face inference router * chore: lint * chore: add huggingface script to nightly update * fix: remove HF prefix from types in huggingface package Remove redundant HF prefix from type names since they are already in the huggingface package namespace. 💖 Generated with Crush Co-Authored-By: Crush * chore: small changes * chore: only add models that seem to be working * chore: lint --------- Co-authored-by: Crush --- .github/workflows/update.yml | 2 + cmd/huggingface/main.go | 203 +++++++++++++ internal/providers/configs/huggingface.json | 314 ++++++++++++++++++++ internal/providers/providers.go | 8 + pkg/catwalk/provider.go | 30 +- 5 files changed, 543 insertions(+), 14 deletions(-) create mode 100644 cmd/huggingface/main.go create mode 100644 internal/providers/configs/huggingface.json diff --git a/.github/workflows/update.yml b/.github/workflows/update.yml index 6dba46976b386179392731a6e2fde1ee96028402..89b1b20a399f1cf023271dad68bc2cfb143f7035 100644 --- a/.github/workflows/update.yml +++ b/.github/workflows/update.yml @@ -18,6 +18,8 @@ jobs: with: go-version-file: go.mod - run: go run ./cmd/openrouter/main.go + # we need to add this back when we know that the providers/models all work + # - run: go run ./cmd/huggingface/main.go - uses: stefanzweifel/git-auto-commit-action@778341af668090896ca464160c2def5d1d1a3eb0 # v5 with: commit_message: "chore: auto-update generated files" diff --git a/cmd/huggingface/main.go b/cmd/huggingface/main.go new file mode 100644 index 0000000000000000000000000000000000000000..583081afd0348fdceb2298f97ab132218e021dae --- /dev/null +++ b/cmd/huggingface/main.go @@ -0,0 +1,203 @@ +// Package main provides a command-line tool to fetch 
// models from Hugging Face Router
// and generate a configuration file for the provider.
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
)

// SupportedProviders defines which providers we want to support.
// Add or remove providers from this slice to control which ones are included.
//
// main skips every provider entry whose name is not listed here, so a
// commented-out name below is excluded from the generated configuration.
var SupportedProviders = []string{
	// "together", // Multiple issues
	"fireworks-ai",
	//"nebius",
	// "novita", // Usage report is wrong
	"groq",
	"cerebras",
	// "hyperbolic",
	// "nscale",
	// "sambanova",
	// "cohere",
	"hf-inference",
}

// Model represents a model from the Hugging Face Router API.
//
// Each field is decoded from the JSON key named in its struct tag.
type Model struct {
	// ID is the model identifier. main joins it with a provider name as
	// "<id>:<provider>" to build the generated model ID.
	ID string `json:"id"`
	// Object is the raw "object" value of the entry.
	Object string `json:"object"`
	// Created is the raw "created" value (presumably a Unix timestamp —
	// TODO confirm against the router API).
	Created int64 `json:"created"`
	// OwnedBy is the raw "owned_by" value of the entry.
	OwnedBy string `json:"owned_by"`
	// Providers lists the provider entries reported for this model.
	Providers []Provider `json:"providers"`
}

// Provider represents a provider configuration for a model.
type Provider struct {
	// Provider is the provider name; it is matched against SupportedProviders.
	Provider string `json:"provider"`
	// Status is the provider status; main only keeps entries equal to "live".
	Status string `json:"status"`
	// ContextLength is the context window reported by this provider. A zero
	// value is treated as "not reported" and replaced by a fallback taken
	// from another provider of the same model.
	ContextLength int64 `json:"context_length,omitempty"`
	// Pricing is nil when the provider reports no pricing; main then emits
	// zero costs for the model.
	Pricing *Pricing `json:"pricing,omitempty"`
	// SupportsTools reports tool support; main skips providers where it is false.
	SupportsTools bool `json:"supports_tools"`
	// SupportsStructuredOutput is decoded from "supports_structured_output".
	SupportsStructuredOutput bool `json:"supports_structured_output"`
}

// Pricing contains the pricing information for a provider.
//
// NOTE(review): the unit is not visible in this file. main copies both values
// unchanged into the CostPer1MIn / CostPer1MOut fields, so they are presumably
// already expressed per 1M tokens — confirm against the router API.
type Pricing struct {
	// Input is the price for input tokens.
	Input float64 `json:"input"`
	// Output is the price for output tokens.
	Output float64 `json:"output"`
}

// ModelsResponse is the response structure for the Hugging Face Router models API.
+type ModelsResponse struct { + Object string `json:"object"` + Data []Model `json:"data"` +} + +func fetchHuggingFaceModels() (*ModelsResponse, error) { + client := &http.Client{Timeout: 30 * time.Second} + req, _ := http.NewRequestWithContext( + context.Background(), + "GET", + "https://router.huggingface.co/v1/models", + nil, + ) + req.Header.Set("User-Agent", "Crush-Client/1.0") + resp, err := client.Do(req) + if err != nil { + return nil, err //nolint:wrapcheck + } + defer resp.Body.Close() //nolint:errcheck + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body) + } + var mr ModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil { + return nil, err //nolint:wrapcheck + } + return &mr, nil +} + +// findContextWindow looks for a context window from any provider for the given model. +func findContextWindow(model Model) int64 { + for _, provider := range model.Providers { + if provider.ContextLength > 0 { + return provider.ContextLength + } + } + return 0 +} + +// WARN: DO NOT USE +// for now we have a subset list of models we use. 
+func main() { + modelsResp, err := fetchHuggingFaceModels() + if err != nil { + log.Fatal("Error fetching Hugging Face models:", err) + } + + hfProvider := catwalk.Provider{ + Name: "Hugging Face", + ID: catwalk.InferenceProviderHuggingFace, + APIKey: "$HF_TOKEN", + APIEndpoint: "https://router.huggingface.co/v1", + Type: catwalk.TypeOpenAI, + DefaultLargeModelID: "moonshotai/Kimi-K2-Instruct-0905:groq", + DefaultSmallModelID: "openai/gpt-oss-20b", + Models: []catwalk.Model{}, + DefaultHeaders: map[string]string{ + "HTTP-Referer": "https://charm.land", + "X-Title": "Crush", + }, + } + + for _, model := range modelsResp.Data { + // Find context window from any provider for this model + fallbackContextLength := findContextWindow(model) + if fallbackContextLength == 0 { + fmt.Printf("Skipping model %s - no context window found in any provider\n", model.ID) + continue + } + + for _, provider := range model.Providers { + // Skip unsupported providers + if !slices.Contains(SupportedProviders, provider.Provider) { + continue + } + + // Skip providers that don't support tools + if !provider.SupportsTools { + continue + } + + // Skip non-live providers + if provider.Status != "live" { + continue + } + + // Create model with provider-specific ID and name + modelID := fmt.Sprintf("%s:%s", model.ID, provider.Provider) + modelName := fmt.Sprintf("%s (%s)", model.ID, provider.Provider) + + // Use provider's context length, or fallback if not available + contextLength := provider.ContextLength + if contextLength == 0 { + contextLength = fallbackContextLength + } + + // Calculate pricing (convert from per-token to per-1M tokens) + var costPer1MIn, costPer1MOut float64 + if provider.Pricing != nil { + costPer1MIn = provider.Pricing.Input + costPer1MOut = provider.Pricing.Output + } + + // Set default max tokens (conservative estimate) + defaultMaxTokens := min(contextLength/4, 8192) + + m := catwalk.Model{ + ID: modelID, + Name: modelName, + CostPer1MIn: costPer1MIn, + 
CostPer1MOut: costPer1MOut, + CostPer1MInCached: 0, // Not provided by HF Router + CostPer1MOutCached: 0, // Not provided by HF Router + ContextWindow: contextLength, + DefaultMaxTokens: defaultMaxTokens, + CanReason: false, // Not provided by HF Router + SupportsImages: false, // Not provided by HF Router + } + + hfProvider.Models = append(hfProvider.Models, m) + fmt.Printf("Added model %s with context window %d from provider %s\n", + modelID, contextLength, provider.Provider) + } + } + + slices.SortFunc(hfProvider.Models, func(a catwalk.Model, b catwalk.Model) int { + return strings.Compare(a.Name, b.Name) + }) + + // Save the JSON in internal/providers/configs/huggingface.json + data, err := json.MarshalIndent(hfProvider, "", " ") + if err != nil { + log.Fatal("Error marshaling Hugging Face provider:", err) + } + + if err := os.WriteFile("internal/providers/configs/huggingface.json", data, 0o600); err != nil { + log.Fatal("Error writing Hugging Face provider config:", err) + } + + fmt.Printf("Generated huggingface.json with %d models\n", len(hfProvider.Models)) +} diff --git a/internal/providers/configs/huggingface.json b/internal/providers/configs/huggingface.json new file mode 100644 index 0000000000000000000000000000000000000000..b65c5bce228c6cf89d5923ca2787de4c52d97cde --- /dev/null +++ b/internal/providers/configs/huggingface.json @@ -0,0 +1,314 @@ +{ + "name": "Hugging Face", + "id": "huggingface", + "api_key": "$HF_TOKEN", + "api_endpoint": "https://router.huggingface.co/v1", + "type": "openai", + "default_large_model_id": "moonshotai/Kimi-K2-Instruct-0905:groq", + "default_small_model_id": "openai/gpt-oss-20b", + "models": [ + { + "id": "Qwen/Qwen3-235B-A22B:fireworks-ai", + "name": "Qwen/Qwen3-235B-A22B (fireworks-ai)", + "cost_per_1m_in": 0.22, + "cost_per_1m_out": 0.88, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + 
"supports_attachments": false + }, + { + "id": "Qwen/Qwen3-235B-A22B-Instruct-2507:fireworks-ai", + "name": "Qwen/Qwen3-235B-A22B-Instruct-2507 (fireworks-ai)", + "cost_per_1m_in": 0.22, + "cost_per_1m_out": 0.88, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 262144, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "Qwen/Qwen3-235B-A22B-Thinking-2507:fireworks-ai", + "name": "Qwen/Qwen3-235B-A22B-Thinking-2507 (fireworks-ai)", + "cost_per_1m_in": 0.22, + "cost_per_1m_out": 0.88, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 262144, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "Qwen/Qwen3-30B-A3B:fireworks-ai", + "name": "Qwen/Qwen3-30B-A3B (fireworks-ai)", + "cost_per_1m_in": 0.15, + "cost_per_1m_out": 0.6, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "Qwen/Qwen3-Coder-480B-A35B-Instruct:cerebras", + "name": "Qwen/Qwen3-Coder-480B-A35B-Instruct (cerebras)", + "cost_per_1m_in": 2, + "cost_per_1m_out": 2, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 262144, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "Qwen/Qwen3-Coder-480B-A35B-Instruct:fireworks-ai", + "name": "Qwen/Qwen3-Coder-480B-A35B-Instruct (fireworks-ai)", + "cost_per_1m_in": 0.45, + "cost_per_1m_out": 1.8, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 262144, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": 
"deepseek-ai/DeepSeek-V3-0324:fireworks-ai", + "name": "deepseek-ai/DeepSeek-V3-0324 (fireworks-ai)", + "cost_per_1m_in": 0.9, + "cost_per_1m_out": 0.9, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 163840, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "deepseek-ai/DeepSeek-V3.1:fireworks-ai", + "name": "deepseek-ai/DeepSeek-V3.1 (fireworks-ai)", + "cost_per_1m_in": 0, + "cost_per_1m_out": 0, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 163840, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "meta-llama/Llama-3.1-70B-Instruct:fireworks-ai", + "name": "meta-llama/Llama-3.1-70B-Instruct (fireworks-ai)", + "cost_per_1m_in": 0.9, + "cost_per_1m_out": 0.9, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "meta-llama/Llama-3.3-70B-Instruct:cerebras", + "name": "meta-llama/Llama-3.3-70B-Instruct (cerebras)", + "cost_per_1m_in": 0.85, + "cost_per_1m_out": 1.2, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "meta-llama/Llama-3.3-70B-Instruct:groq", + "name": "meta-llama/Llama-3.3-70B-Instruct (groq)", + "cost_per_1m_in": 0.59, + "cost_per_1m_out": 0.79, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct:fireworks-ai", + "name": 
"meta-llama/Llama-4-Maverick-17B-128E-Instruct (fireworks-ai)", + "cost_per_1m_in": 0.22, + "cost_per_1m_out": 0.88, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 1048576, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct:groq", + "name": "meta-llama/Llama-4-Maverick-17B-128E-Instruct (groq)", + "cost_per_1m_in": 0.2, + "cost_per_1m_out": 0.6, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "meta-llama/Llama-4-Scout-17B-16E-Instruct:groq", + "name": "meta-llama/Llama-4-Scout-17B-16E-Instruct (groq)", + "cost_per_1m_in": 0.11, + "cost_per_1m_out": 0.34, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "moonshotai/Kimi-K2-Instruct:fireworks-ai", + "name": "moonshotai/Kimi-K2-Instruct (fireworks-ai)", + "cost_per_1m_in": 0.6, + "cost_per_1m_out": 2.5, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "moonshotai/Kimi-K2-Instruct-0905:groq", + "name": "moonshotai/Kimi-K2-Instruct-0905 (groq)", + "cost_per_1m_in": 0, + "cost_per_1m_out": 0, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 262144, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "openai/gpt-oss-120b:cerebras", + "name": "openai/gpt-oss-120b (cerebras)", + "cost_per_1m_in": 0.25, + "cost_per_1m_out": 0.69, 
+ "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "openai/gpt-oss-120b:fireworks-ai", + "name": "openai/gpt-oss-120b (fireworks-ai)", + "cost_per_1m_in": 0.15, + "cost_per_1m_out": 0.6, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "openai/gpt-oss-120b:groq", + "name": "openai/gpt-oss-120b (groq)", + "cost_per_1m_in": 0.15, + "cost_per_1m_out": 0.75, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "openai/gpt-oss-20b:fireworks-ai", + "name": "openai/gpt-oss-20b (fireworks-ai)", + "cost_per_1m_in": 0.05, + "cost_per_1m_out": 0.2, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "openai/gpt-oss-20b:groq", + "name": "openai/gpt-oss-20b (groq)", + "cost_per_1m_in": 0.1, + "cost_per_1m_out": 0.5, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": "zai-org/GLM-4.5:fireworks-ai", + "name": "zai-org/GLM-4.5 (fireworks-ai)", + "cost_per_1m_in": 0.55, + "cost_per_1m_out": 2.19, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + }, + { + "id": 
"zai-org/GLM-4.5-Air:fireworks-ai", + "name": "zai-org/GLM-4.5-Air (fireworks-ai)", + "cost_per_1m_in": 0.22, + "cost_per_1m_out": 0.88, + "cost_per_1m_in_cached": 0, + "cost_per_1m_out_cached": 0, + "context_window": 131072, + "default_max_tokens": 8192, + "can_reason": false, + "has_reasoning_efforts": false, + "supports_attachments": false + } + ], + "default_headers": { + "HTTP-Referer": "https://charm.land", + "X-Title": "Crush" + } +} diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 76a1200f64e9d544629e5ef74d2f2eb1f73ef07d..ee81d37e29b073d981e83a2d18a04bd27b0679e5 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -54,6 +54,9 @@ var chutesConfig []byte //go:embed configs/deepseek.json var deepSeekConfig []byte +//go:embed configs/huggingface.json +var huggingFaceConfig []byte + // ProviderFunc is a function that returns a Provider. type ProviderFunc func() catwalk.Provider @@ -73,6 +76,7 @@ var providerRegistry = []ProviderFunc{ veniceProvider, chutesProvider, deepSeekProvider, + huggingFaceProvider, } // GetAll returns all registered providers. @@ -152,3 +156,7 @@ func chutesProvider() catwalk.Provider { func deepSeekProvider() catwalk.Provider { return loadProviderFromConfig(deepSeekConfig) } + +func huggingFaceProvider() catwalk.Provider { + return loadProviderFromConfig(huggingFaceConfig) +} diff --git a/pkg/catwalk/provider.go b/pkg/catwalk/provider.go index 589d0ff6704ca1f0c0f2fde57bffb832e47198d9..a4f1aaa8da3ade25614c04d615be6cf631496a88 100644 --- a/pkg/catwalk/provider.go +++ b/pkg/catwalk/provider.go @@ -18,20 +18,21 @@ type InferenceProvider string // All the inference providers supported by the system. 
const ( - InferenceProviderOpenAI InferenceProvider = "openai" - InferenceProviderAnthropic InferenceProvider = "anthropic" - InferenceProviderGemini InferenceProvider = "gemini" - InferenceProviderAzure InferenceProvider = "azure" - InferenceProviderBedrock InferenceProvider = "bedrock" - InferenceProviderVertexAI InferenceProvider = "vertexai" - InferenceProviderXAI InferenceProvider = "xai" - InferenceProviderZAI InferenceProvider = "zai" - InferenceProviderGROQ InferenceProvider = "groq" - InferenceProviderOpenRouter InferenceProvider = "openrouter" - InferenceProviderLambda InferenceProvider = "lambda" - InferenceProviderCerebras InferenceProvider = "cerebras" - InferenceProviderVenice InferenceProvider = "venice" - InferenceProviderChutes InferenceProvider = "chutes" + InferenceProviderOpenAI InferenceProvider = "openai" + InferenceProviderAnthropic InferenceProvider = "anthropic" + InferenceProviderGemini InferenceProvider = "gemini" + InferenceProviderAzure InferenceProvider = "azure" + InferenceProviderBedrock InferenceProvider = "bedrock" + InferenceProviderVertexAI InferenceProvider = "vertexai" + InferenceProviderXAI InferenceProvider = "xai" + InferenceProviderZAI InferenceProvider = "zai" + InferenceProviderGROQ InferenceProvider = "groq" + InferenceProviderOpenRouter InferenceProvider = "openrouter" + InferenceProviderLambda InferenceProvider = "lambda" + InferenceProviderCerebras InferenceProvider = "cerebras" + InferenceProviderVenice InferenceProvider = "venice" + InferenceProviderChutes InferenceProvider = "chutes" + InferenceProviderHuggingFace InferenceProvider = "huggingface" ) // Provider represents an AI provider configuration. @@ -80,5 +81,6 @@ func KnownProviders() []InferenceProvider { InferenceProviderCerebras, InferenceProviderVenice, InferenceProviderChutes, + InferenceProviderHuggingFace, } }