providers.go

  1package providers
  2
import (
	_ "embed"
	"encoding/json"
	"log"
	"sort"

	"github.com/charmbracelet/fur/pkg/provider"
)
 10
// openAIConfig holds the contents of configs/openai.json, embedded at build time.
//
//go:embed configs/openai.json
var openAIConfig []byte

// anthropicConfig holds the contents of configs/anthropic.json, embedded at build time.
//
//go:embed configs/anthropic.json
var anthropicConfig []byte

// geminiConfig holds the contents of configs/gemini.json, embedded at build time.
//
//go:embed configs/gemini.json
var geminiConfig []byte

// openRouterConfig holds the contents of configs/openrouter.json, embedded at build time.
//
//go:embed configs/openrouter.json
var openRouterConfig []byte

// azureConfig holds the contents of configs/azure.json, embedded at build time.
//
//go:embed configs/azure.json
var azureConfig []byte

// vertexAIConfig holds the contents of configs/vertexai.json, embedded at build time.
//
//go:embed configs/vertexai.json
var vertexAIConfig []byte

// xAIConfig holds the contents of configs/xai.json, embedded at build time.
//
//go:embed configs/xai.json
var xAIConfig []byte

// bedrockConfig holds the contents of configs/bedrock.json, embedded at build time.
//
//go:embed configs/bedrock.json
var bedrockConfig []byte
 34
// ProviderFunc is a function that builds and returns a provider.Provider.
// providerRegistry stores one ProviderFunc per provider ID.
type ProviderFunc func() provider.Provider
 37
// providerRegistry maps each registered inference provider ID to the
// ProviderFunc that builds its Provider. Being a map, it has no defined
// iteration order.
var providerRegistry = map[provider.InferenceProvider]ProviderFunc{
	provider.InferenceProviderOpenAI:     openAIProvider,
	provider.InferenceProviderAnthropic:  anthropicProvider,
	provider.InferenceProviderGemini:     geminiProvider,
	provider.InferenceProviderAzure:      azureProvider,
	provider.InferenceProviderBedrock:    bedrockProvider,
	provider.InferenceProviderVertexAI:   vertexAIProvider,
	provider.InferenceProviderXAI:        xAIProvider,
	provider.InferenceProviderOpenRouter: openRouterProvider,
}
 48
 49func GetAll() []provider.Provider {
 50	providers := make([]provider.Provider, 0, len(providerRegistry))
 51	for _, providerFunc := range providerRegistry {
 52		providers = append(providers, providerFunc())
 53	}
 54	return providers
 55}
 56
 57func GetByID(id provider.InferenceProvider) (provider.Provider, bool) {
 58	providerFunc, exists := providerRegistry[id]
 59	if !exists {
 60		return provider.Provider{}, false
 61	}
 62	return providerFunc(), true
 63}
 64
 65func GetAvailableIDs() []provider.InferenceProvider {
 66	ids := make([]provider.InferenceProvider, 0, len(providerRegistry))
 67	for id := range providerRegistry {
 68		ids = append(ids, id)
 69	}
 70	return ids
 71}
 72
 73func loadProviderFromConfig(configData []byte) provider.Provider {
 74	var p provider.Provider
 75	if err := json.Unmarshal(configData, &p); err != nil {
 76		log.Printf("Error loading provider config: %v", err)
 77		return provider.Provider{}
 78	}
 79	return p
 80}
 81
// openAIProvider returns the Provider decoded from the embedded
// configs/openai.json data.
func openAIProvider() provider.Provider {
	return loadProviderFromConfig(openAIConfig)
}
 85
// anthropicProvider returns the Provider decoded from the embedded
// configs/anthropic.json data.
func anthropicProvider() provider.Provider {
	return loadProviderFromConfig(anthropicConfig)
}
 89
// geminiProvider returns the Provider decoded from the embedded
// configs/gemini.json data.
func geminiProvider() provider.Provider {
	return loadProviderFromConfig(geminiConfig)
}
 93
// azureProvider returns the Provider decoded from the embedded
// configs/azure.json data.
func azureProvider() provider.Provider {
	return loadProviderFromConfig(azureConfig)
}
 97
// bedrockProvider returns the Provider decoded from the embedded
// configs/bedrock.json data.
func bedrockProvider() provider.Provider {
	return loadProviderFromConfig(bedrockConfig)
}
101
// vertexAIProvider returns the Provider decoded from the embedded
// configs/vertexai.json data.
func vertexAIProvider() provider.Provider {
	return loadProviderFromConfig(vertexAIConfig)
}
105
// xAIProvider returns the Provider decoded from the embedded
// configs/xai.json data.
func xAIProvider() provider.Provider {
	return loadProviderFromConfig(xAIConfig)
}
109
// openRouterProvider returns the Provider decoded from the embedded
// configs/openrouter.json data.
func openRouterProvider() provider.Provider {
	return loadProviderFromConfig(openRouterConfig)
}