feat: make the models available in pkg

Kujtim Hoxha created

Change summary

cmd/openrouter/main.go          | 10 ++--
internal/providers/providers.go | 66 ++++++++++++++++++++++++++++------
pkg/provider/provider.go        | 44 ++--------------------
3 files changed, 63 insertions(+), 57 deletions(-)

Detailed changes

cmd/openrouter/main.go 🔗

@@ -11,7 +11,7 @@ import (
 	"strconv"
 	"time"
 
-	"github.com/charmbracelet/fur/internal/providers"
+	"github.com/charmbracelet/fur/pkg/provider"
 )
 
 // Model represents the complete model configuration
@@ -120,14 +120,14 @@ func main() {
 		log.Fatal("Error fetching OpenRouter models:", err)
 	}
 
-	openRouterProvider := providers.Provider{
+	openRouterProvider := provider.Provider{
 		Name:           "OpenRouter",
 		ID:             "openrouter",
 		APIKey:         "$OPENROUTER_API_KEY",
 		APIEndpoint:    "https://openrouter.ai/api/v1",
-		Type:           providers.ProviderTypeOpenAI,
+		Type:           provider.ProviderTypeOpenAI,
 		DefaultModelID: "anthropic/claude-sonnet-4",
-		Models:         []providers.Model{},
+		Models:         []provider.Model{},
 	}
 
 	for _, model := range modelsResp.Data {
@@ -142,7 +142,7 @@ func main() {
 		canReason := slices.Contains(model.SupportedParams, "reasoning")
 		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
 
-		m := providers.Model{
+		m := provider.Model{
 			ID:                 model.ID,
 			Name:               model.Name,
 			CostPer1MIn:        pricing.CostPer1MIn,

internal/providers/providers.go 🔗

@@ -4,6 +4,8 @@ import (
 	_ "embed"
 	"encoding/json"
 	"log"
+
+	"github.com/charmbracelet/fur/pkg/provider"
 )
 
 //go:embed configs/openai.json
@@ -30,43 +32,81 @@ var xAIConfig []byte
 //go:embed configs/bedrock.json
 var bedrockConfig []byte
 
-func loadProviderFromConfig(configData []byte) Provider {
-	var provider Provider
-	if err := json.Unmarshal(configData, &provider); err != nil {
+// ProviderFunc is a function that returns a Provider
+type ProviderFunc func() provider.Provider
+
+var providerRegistry = map[provider.InferenceProvider]ProviderFunc{
+	provider.InferenceProviderOpenAI:     openAIProvider,
+	provider.InferenceProviderAnthropic:  anthropicProvider,
+	provider.InferenceProviderGemini:     geminiProvider,
+	provider.InferenceProviderAzure:      azureProvider,
+	provider.InferenceProviderBedrock:    bedrockProvider,
+	provider.InferenceProviderVertexAI:   vertexAIProvider,
+	provider.InferenceProviderXAI:        xAIProvider,
+	provider.InferenceProviderOpenRouter: openRouterProvider,
+}
+
+func GetAll() []provider.Provider {
+	providers := make([]provider.Provider, 0, len(providerRegistry))
+	for _, providerFunc := range providerRegistry {
+		providers = append(providers, providerFunc())
+	}
+	return providers
+}
+
+func GetByID(id provider.InferenceProvider) (provider.Provider, bool) {
+	providerFunc, exists := providerRegistry[id]
+	if !exists {
+		return provider.Provider{}, false
+	}
+	return providerFunc(), true
+}
+
+func GetAvailableIDs() []provider.InferenceProvider {
+	ids := make([]provider.InferenceProvider, 0, len(providerRegistry))
+	for id := range providerRegistry {
+		ids = append(ids, id)
+	}
+	return ids
+}
+
+func loadProviderFromConfig(configData []byte) provider.Provider {
+	var p provider.Provider
+	if err := json.Unmarshal(configData, &p); err != nil {
 		log.Printf("Error loading provider config: %v", err)
-		return Provider{}
+		return provider.Provider{}
 	}
-	return provider
+	return p
 }
 
-func openAIProvider() Provider {
+func openAIProvider() provider.Provider {
 	return loadProviderFromConfig(openAIConfig)
 }
 
-func anthropicProvider() Provider {
+func anthropicProvider() provider.Provider {
 	return loadProviderFromConfig(anthropicConfig)
 }
 
-func geminiProvider() Provider {
+func geminiProvider() provider.Provider {
 	return loadProviderFromConfig(geminiConfig)
 }
 
-func azureProvider() Provider {
+func azureProvider() provider.Provider {
 	return loadProviderFromConfig(azureConfig)
 }
 
-func bedrockProvider() Provider {
+func bedrockProvider() provider.Provider {
 	return loadProviderFromConfig(bedrockConfig)
 }
 
-func vertexAIProvider() Provider {
+func vertexAIProvider() provider.Provider {
 	return loadProviderFromConfig(vertexAIConfig)
 }
 
-func xAIProvider() Provider {
+func xAIProvider() provider.Provider {
 	return loadProviderFromConfig(xAIConfig)
 }
 
-func openRouterProvider() Provider {
+func openRouterProvider() provider.Provider {
 	return loadProviderFromConfig(openRouterConfig)
 }

internal/providers/types.go → pkg/provider/provider.go 🔗

@@ -1,5 +1,6 @@
-package providers
+package provider
 
+// ProviderType represents the type of AI provider
 type ProviderType string
 
 const (
@@ -13,6 +14,7 @@ const (
 	ProviderTypeOpenRouter ProviderType = "openrouter"
 )
 
+// InferenceProvider represents the inference provider identifier
 type InferenceProvider string
 
 const (
@@ -26,6 +28,7 @@ const (
 	InferenceProviderOpenRouter InferenceProvider = "openrouter"
 )
 
+// Provider represents an AI provider configuration
 type Provider struct {
 	Name           string            `json:"name"`
 	ID             InferenceProvider `json:"id"`
@@ -36,6 +39,7 @@ type Provider struct {
 	Models         []Model           `json:"models,omitempty"`
 }
 
+// Model represents an AI model configuration
 type Model struct {
 	ID                 string  `json:"id"`
 	Name               string  `json:"model"`
@@ -48,41 +52,3 @@ type Model struct {
 	CanReason          bool    `json:"can_reason"`
 	SupportsImages     bool    `json:"supports_attachments"`
 }
-
-type ProviderFunc func() Provider
-
-var providerRegistry = map[InferenceProvider]ProviderFunc{
-	InferenceProviderOpenAI:     openAIProvider,
-	InferenceProviderAnthropic:  anthropicProvider,
-	InferenceProviderGemini:     geminiProvider,
-	InferenceProviderAzure:      azureProvider,
-	InferenceProviderBedrock:    bedrockProvider,
-	InferenceProviderVertexAI:   vertexAIProvider,
-	InferenceProviderXAI:        xAIProvider,
-	InferenceProviderOpenRouter: openRouterProvider,
-}
-
-func GetAll() []Provider {
-	providers := make([]Provider, 0, len(providerRegistry))
-	for _, providerFunc := range providerRegistry {
-		providers = append(providers, providerFunc())
-	}
-	return providers
-}
-
-func GetByID(id InferenceProvider) (Provider, bool) {
-	providerFunc, exists := providerRegistry[id]
-	if !exists {
-		return Provider{}, false
-	}
-	return providerFunc(), true
-}
-
-func GetAvailableIDs() []InferenceProvider {
-	ids := make([]InferenceProvider, 0, len(providerRegistry))
-	for id := range providerRegistry {
-		ids = append(ids, id)
-	}
-	return ids
-}
-