// Package providers provides a registry of inference providers.
2package providers
3
4import (
5 _ "embed"
6 "encoding/json"
7 "log"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10)
11
12//go:embed configs/openai.json
13var openAIConfig []byte
14
15//go:embed configs/anthropic.json
16var anthropicConfig []byte
17
18//go:embed configs/gemini.json
19var geminiConfig []byte
20
21//go:embed configs/openrouter.json
22var openRouterConfig []byte
23
24//go:embed configs/azure.json
25var azureConfig []byte
26
27//go:embed configs/vertexai.json
28var vertexAIConfig []byte
29
30//go:embed configs/xai.json
31var xAIConfig []byte
32
33//go:embed configs/bedrock.json
34var bedrockConfig []byte
35
36//go:embed configs/groq.json
37var groqConfig []byte
38
39//go:embed configs/lambda.json
40var lambdaConfig []byte
41
42//go:embed configs/cerebras.json
43var cerebrasConfig []byte
44
45// ProviderFunc is a function that returns a Provider.
46type ProviderFunc func() catwalk.Provider
47
48var providerRegistry = []ProviderFunc{
49 anthropicProvider,
50 openAIProvider,
51 geminiProvider,
52 azureProvider,
53 bedrockProvider,
54 vertexAIProvider,
55 xAIProvider,
56 groqProvider,
57 openRouterProvider,
58 lambdaProvider,
59 cerebrasProvider,
60}
61
62// GetAll returns all registered providers.
63func GetAll() []catwalk.Provider {
64 providers := make([]catwalk.Provider, 0, len(providerRegistry))
65 for _, providerFunc := range providerRegistry {
66 providers = append(providers, providerFunc())
67 }
68 return providers
69}
70
71func loadProviderFromConfig(configData []byte) catwalk.Provider {
72 var p catwalk.Provider
73 if err := json.Unmarshal(configData, &p); err != nil {
74 log.Printf("Error loading provider config: %v", err)
75 return catwalk.Provider{}
76 }
77 return p
78}
79
80func openAIProvider() catwalk.Provider {
81 return loadProviderFromConfig(openAIConfig)
82}
83
84func anthropicProvider() catwalk.Provider {
85 return loadProviderFromConfig(anthropicConfig)
86}
87
88func geminiProvider() catwalk.Provider {
89 return loadProviderFromConfig(geminiConfig)
90}
91
92func azureProvider() catwalk.Provider {
93 return loadProviderFromConfig(azureConfig)
94}
95
96func bedrockProvider() catwalk.Provider {
97 return loadProviderFromConfig(bedrockConfig)
98}
99
100func vertexAIProvider() catwalk.Provider {
101 return loadProviderFromConfig(vertexAIConfig)
102}
103
104func xAIProvider() catwalk.Provider {
105 return loadProviderFromConfig(xAIConfig)
106}
107
108func openRouterProvider() catwalk.Provider {
109 return loadProviderFromConfig(openRouterConfig)
110}
111
112func groqProvider() catwalk.Provider {
113 return loadProviderFromConfig(groqConfig)
114}
115
116func lambdaProvider() catwalk.Provider {
117 return loadProviderFromConfig(lambdaConfig)
118}
119
120func cerebrasProvider() catwalk.Provider {
121 return loadProviderFromConfig(cerebrasConfig)
122}