// Package providers provides a registry of inference providers whose
// definitions are loaded from JSON configuration files embedded at build time.
2package providers
3
4import (
5 _ "embed"
6 "encoding/json"
7 "log"
8
9 "github.com/charmbracelet/fur/pkg/provider"
10)
11
// openAIConfig holds the raw JSON definition of the OpenAI provider, embedded
// from configs/openai.json at build time.
//
//go:embed configs/openai.json
var openAIConfig []byte

// anthropicConfig holds the raw JSON definition of the Anthropic provider,
// embedded from configs/anthropic.json at build time.
//
//go:embed configs/anthropic.json
var anthropicConfig []byte

// geminiConfig holds the raw JSON definition of the Gemini provider, embedded
// from configs/gemini.json at build time.
//
//go:embed configs/gemini.json
var geminiConfig []byte

// openRouterConfig holds the raw JSON definition of the OpenRouter provider,
// embedded from configs/openrouter.json at build time.
//
//go:embed configs/openrouter.json
var openRouterConfig []byte

// azureConfig holds the raw JSON definition of the Azure provider, embedded
// from configs/azure.json at build time.
//
//go:embed configs/azure.json
var azureConfig []byte

// vertexAIConfig holds the raw JSON definition of the Vertex AI provider,
// embedded from configs/vertexai.json at build time.
//
//go:embed configs/vertexai.json
var vertexAIConfig []byte

// xAIConfig holds the raw JSON definition of the xAI provider, embedded from
// configs/xai.json at build time.
//
//go:embed configs/xai.json
var xAIConfig []byte

// bedrockConfig holds the raw JSON definition of the Bedrock provider,
// embedded from configs/bedrock.json at build time.
//
//go:embed configs/bedrock.json
var bedrockConfig []byte
35
36// ProviderFunc is a function that returns a Provider.
37type ProviderFunc func() provider.Provider
38
39var providerRegistry = map[provider.InferenceProvider]ProviderFunc{
40 provider.InferenceProviderOpenAI: openAIProvider,
41 provider.InferenceProviderAnthropic: anthropicProvider,
42 provider.InferenceProviderGemini: geminiProvider,
43 provider.InferenceProviderAzure: azureProvider,
44 provider.InferenceProviderBedrock: bedrockProvider,
45 provider.InferenceProviderVertexAI: vertexAIProvider,
46 provider.InferenceProviderXAI: xAIProvider,
47 provider.InferenceProviderOpenRouter: openRouterProvider,
48}
49
50// GetAll returns all registered providers.
51func GetAll() []provider.Provider {
52 providers := make([]provider.Provider, 0, len(providerRegistry))
53 for _, providerFunc := range providerRegistry {
54 providers = append(providers, providerFunc())
55 }
56 return providers
57}
58
59// GetByID returns a provider by its ID.
60func GetByID(id provider.InferenceProvider) (provider.Provider, bool) {
61 providerFunc, exists := providerRegistry[id]
62 if !exists {
63 return provider.Provider{}, false
64 }
65 return providerFunc(), true
66}
67
68// GetAvailableIDs returns a slice of all available provider IDs.
69func GetAvailableIDs() []provider.InferenceProvider {
70 ids := make([]provider.InferenceProvider, 0, len(providerRegistry))
71 for id := range providerRegistry {
72 ids = append(ids, id)
73 }
74 return ids
75}
76
77func loadProviderFromConfig(configData []byte) provider.Provider {
78 var p provider.Provider
79 if err := json.Unmarshal(configData, &p); err != nil {
80 log.Printf("Error loading provider config: %v", err)
81 return provider.Provider{}
82 }
83 return p
84}
85
86func openAIProvider() provider.Provider {
87 return loadProviderFromConfig(openAIConfig)
88}
89
90func anthropicProvider() provider.Provider {
91 return loadProviderFromConfig(anthropicConfig)
92}
93
94func geminiProvider() provider.Provider {
95 return loadProviderFromConfig(geminiConfig)
96}
97
98func azureProvider() provider.Provider {
99 return loadProviderFromConfig(azureConfig)
100}
101
102func bedrockProvider() provider.Provider {
103 return loadProviderFromConfig(bedrockConfig)
104}
105
106func vertexAIProvider() provider.Provider {
107 return loadProviderFromConfig(vertexAIConfig)
108}
109
110func xAIProvider() provider.Provider {
111 return loadProviderFromConfig(xAIConfig)
112}
113
114func openRouterProvider() provider.Provider {
115 return loadProviderFromConfig(openRouterConfig)
116}