1package providers
2
import (
	_ "embed"
	"encoding/json"
	"log"
	"sort"

	"github.com/charmbracelet/fur/pkg/provider"
)
10
// Raw JSON provider definitions, compiled into the binary with go:embed so
// that no filesystem access is required at runtime. Each slice is decoded by
// loadProviderFromConfig.
var (
	//go:embed configs/openai.json
	openAIConfig []byte

	//go:embed configs/anthropic.json
	anthropicConfig []byte

	//go:embed configs/gemini.json
	geminiConfig []byte

	//go:embed configs/openrouter.json
	openRouterConfig []byte

	//go:embed configs/azure.json
	azureConfig []byte

	//go:embed configs/vertexai.json
	vertexAIConfig []byte

	//go:embed configs/xai.json
	xAIConfig []byte

	//go:embed configs/bedrock.json
	bedrockConfig []byte
)
34
35// ProviderFunc is a function that returns a Provider
36type ProviderFunc func() provider.Provider
37
38var providerRegistry = map[provider.InferenceProvider]ProviderFunc{
39 provider.InferenceProviderOpenAI: openAIProvider,
40 provider.InferenceProviderAnthropic: anthropicProvider,
41 provider.InferenceProviderGemini: geminiProvider,
42 provider.InferenceProviderAzure: azureProvider,
43 provider.InferenceProviderBedrock: bedrockProvider,
44 provider.InferenceProviderVertexAI: vertexAIProvider,
45 provider.InferenceProviderXAI: xAIProvider,
46 provider.InferenceProviderOpenRouter: openRouterProvider,
47}
48
49func GetAll() []provider.Provider {
50 providers := make([]provider.Provider, 0, len(providerRegistry))
51 for _, providerFunc := range providerRegistry {
52 providers = append(providers, providerFunc())
53 }
54 return providers
55}
56
57func GetByID(id provider.InferenceProvider) (provider.Provider, bool) {
58 providerFunc, exists := providerRegistry[id]
59 if !exists {
60 return provider.Provider{}, false
61 }
62 return providerFunc(), true
63}
64
65func GetAvailableIDs() []provider.InferenceProvider {
66 ids := make([]provider.InferenceProvider, 0, len(providerRegistry))
67 for id := range providerRegistry {
68 ids = append(ids, id)
69 }
70 return ids
71}
72
73func loadProviderFromConfig(configData []byte) provider.Provider {
74 var p provider.Provider
75 if err := json.Unmarshal(configData, &p); err != nil {
76 log.Printf("Error loading provider config: %v", err)
77 return provider.Provider{}
78 }
79 return p
80}
81
82func openAIProvider() provider.Provider {
83 return loadProviderFromConfig(openAIConfig)
84}
85
86func anthropicProvider() provider.Provider {
87 return loadProviderFromConfig(anthropicConfig)
88}
89
90func geminiProvider() provider.Provider {
91 return loadProviderFromConfig(geminiConfig)
92}
93
94func azureProvider() provider.Provider {
95 return loadProviderFromConfig(azureConfig)
96}
97
98func bedrockProvider() provider.Provider {
99 return loadProviderFromConfig(bedrockConfig)
100}
101
102func vertexAIProvider() provider.Provider {
103 return loadProviderFromConfig(vertexAIConfig)
104}
105
106func xAIProvider() provider.Provider {
107 return loadProviderFromConfig(xAIConfig)
108}
109
110func openRouterProvider() provider.Provider {
111 return loadProviderFromConfig(openRouterConfig)
112}