providers.go

// Package providers provides a registry of inference providers.
  2package providers
  3
import (
	_ "embed"
	"encoding/json"
	"log"
	"sort"

	"github.com/charmbracelet/fur/pkg/provider"
)
 11
 12//go:embed configs/openai.json
 13var openAIConfig []byte
 14
 15//go:embed configs/anthropic.json
 16var anthropicConfig []byte
 17
 18//go:embed configs/gemini.json
 19var geminiConfig []byte
 20
 21//go:embed configs/openrouter.json
 22var openRouterConfig []byte
 23
 24//go:embed configs/azure.json
 25var azureConfig []byte
 26
 27//go:embed configs/vertexai.json
 28var vertexAIConfig []byte
 29
 30//go:embed configs/xai.json
 31var xAIConfig []byte
 32
 33//go:embed configs/bedrock.json
 34var bedrockConfig []byte
 35
// ProviderFunc is a function that returns a Provider.
//
// The values stored in providerRegistry are ProviderFuncs. The ones defined in
// this file decode an embedded JSON config on every call, so each call returns
// a freshly decoded Provider value.
type ProviderFunc func() provider.Provider
 38
 39var providerRegistry = map[provider.InferenceProvider]ProviderFunc{
 40	provider.InferenceProviderOpenAI:     openAIProvider,
 41	provider.InferenceProviderAnthropic:  anthropicProvider,
 42	provider.InferenceProviderGemini:     geminiProvider,
 43	provider.InferenceProviderAzure:      azureProvider,
 44	provider.InferenceProviderBedrock:    bedrockProvider,
 45	provider.InferenceProviderVertexAI:   vertexAIProvider,
 46	provider.InferenceProviderXAI:        xAIProvider,
 47	provider.InferenceProviderOpenRouter: openRouterProvider,
 48}
 49
 50// GetAll returns all registered providers.
 51func GetAll() []provider.Provider {
 52	providers := make([]provider.Provider, 0, len(providerRegistry))
 53	for _, providerFunc := range providerRegistry {
 54		providers = append(providers, providerFunc())
 55	}
 56	return providers
 57}
 58
 59// GetByID returns a provider by its ID.
 60func GetByID(id provider.InferenceProvider) (provider.Provider, bool) {
 61	providerFunc, exists := providerRegistry[id]
 62	if !exists {
 63		return provider.Provider{}, false
 64	}
 65	return providerFunc(), true
 66}
 67
 68// GetAvailableIDs returns a slice of all available provider IDs.
 69func GetAvailableIDs() []provider.InferenceProvider {
 70	ids := make([]provider.InferenceProvider, 0, len(providerRegistry))
 71	for id := range providerRegistry {
 72		ids = append(ids, id)
 73	}
 74	return ids
 75}
 76
 77func loadProviderFromConfig(configData []byte) provider.Provider {
 78	var p provider.Provider
 79	if err := json.Unmarshal(configData, &p); err != nil {
 80		log.Printf("Error loading provider config: %v", err)
 81		return provider.Provider{}
 82	}
 83	return p
 84}
 85
 86func openAIProvider() provider.Provider {
 87	return loadProviderFromConfig(openAIConfig)
 88}
 89
 90func anthropicProvider() provider.Provider {
 91	return loadProviderFromConfig(anthropicConfig)
 92}
 93
 94func geminiProvider() provider.Provider {
 95	return loadProviderFromConfig(geminiConfig)
 96}
 97
 98func azureProvider() provider.Provider {
 99	return loadProviderFromConfig(azureConfig)
100}
101
102func bedrockProvider() provider.Provider {
103	return loadProviderFromConfig(bedrockConfig)
104}
105
106func vertexAIProvider() provider.Provider {
107	return loadProviderFromConfig(vertexAIConfig)
108}
109
110func xAIProvider() provider.Provider {
111	return loadProviderFromConfig(xAIConfig)
112}
113
114func openRouterProvider() provider.Provider {
115	return loadProviderFromConfig(openRouterConfig)
116}