providers.go

// Package providers provides a registry of inference providers whose definitions are loaded from embedded JSON configuration files.
  2package providers
  3
  4import (
  5	_ "embed"
  6	"encoding/json"
  7	"log"
  8
  9	"github.com/charmbracelet/catwalk/pkg/provider"
 10)
 11
 12//go:embed configs/openai.json
 13var openAIConfig []byte
 14
 15//go:embed configs/anthropic.json
 16var anthropicConfig []byte
 17
 18//go:embed configs/gemini.json
 19var geminiConfig []byte
 20
 21//go:embed configs/openrouter.json
 22var openRouterConfig []byte
 23
 24//go:embed configs/azure.json
 25var azureConfig []byte
 26
 27//go:embed configs/vertexai.json
 28var vertexAIConfig []byte
 29
 30//go:embed configs/xai.json
 31var xAIConfig []byte
 32
 33//go:embed configs/bedrock.json
 34var bedrockConfig []byte
 35
 36//go:embed configs/groq.json
 37var groqConfig []byte
 38
// ProviderFunc is a function that returns a Provider. Every entry in
// providerRegistry has this type, and GetAll invokes each one to build
// its result.
type ProviderFunc func() provider.Provider
 41
// providerRegistry holds the constructor for every built-in provider.
// GetAll walks this slice front to back, so the order of entries here is
// the order in which providers are returned to callers.
var providerRegistry = []ProviderFunc{
	anthropicProvider,
	openAIProvider,
	geminiProvider,
	azureProvider,
	bedrockProvider,
	vertexAIProvider,
	xAIProvider,
	groqProvider,
	openRouterProvider,
}
 53
 54// GetAll returns all registered providers.
 55func GetAll() []provider.Provider {
 56	providers := make([]provider.Provider, 0, len(providerRegistry))
 57	for _, providerFunc := range providerRegistry {
 58		providers = append(providers, providerFunc())
 59	}
 60	return providers
 61}
 62
 63func loadProviderFromConfig(configData []byte) provider.Provider {
 64	var p provider.Provider
 65	if err := json.Unmarshal(configData, &p); err != nil {
 66		log.Printf("Error loading provider config: %v", err)
 67		return provider.Provider{}
 68	}
 69	return p
 70}
 71
 72func openAIProvider() provider.Provider {
 73	return loadProviderFromConfig(openAIConfig)
 74}
 75
 76func anthropicProvider() provider.Provider {
 77	return loadProviderFromConfig(anthropicConfig)
 78}
 79
 80func geminiProvider() provider.Provider {
 81	return loadProviderFromConfig(geminiConfig)
 82}
 83
 84func azureProvider() provider.Provider {
 85	return loadProviderFromConfig(azureConfig)
 86}
 87
 88func bedrockProvider() provider.Provider {
 89	return loadProviderFromConfig(bedrockConfig)
 90}
 91
 92func vertexAIProvider() provider.Provider {
 93	return loadProviderFromConfig(vertexAIConfig)
 94}
 95
 96func xAIProvider() provider.Provider {
 97	return loadProviderFromConfig(xAIConfig)
 98}
 99
100func openRouterProvider() provider.Provider {
101	return loadProviderFromConfig(openRouterConfig)
102}
103
104func groqProvider() provider.Provider {
105	return loadProviderFromConfig(groqConfig)
106}