From 3e5b8e3ae154202e49ff8792eb2a00535af4e397 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 25 Jun 2025 13:00:26 +0200 Subject: [PATCH] feat: make the models available in pkg --- cmd/openrouter/main.go | 10 +-- internal/providers/providers.go | 66 +++++++++++++++---- .../types.go => pkg/provider/provider.go | 44 ++----------- 3 files changed, 63 insertions(+), 57 deletions(-) rename internal/providers/types.go => pkg/provider/provider.go (64%) diff --git a/cmd/openrouter/main.go b/cmd/openrouter/main.go index 25529fb46adb0f0a812d28768ec51c837fa6a613..c5a1162d50528a4158f49284ac6d51b7989801a1 100644 --- a/cmd/openrouter/main.go +++ b/cmd/openrouter/main.go @@ -11,7 +11,7 @@ import ( "strconv" "time" - "github.com/charmbracelet/fur/internal/providers" + "github.com/charmbracelet/fur/pkg/provider" ) // Model represents the complete model configuration @@ -120,14 +120,14 @@ func main() { log.Fatal("Error fetching OpenRouter models:", err) } - openRouterProvider := providers.Provider{ + openRouterProvider := provider.Provider{ Name: "OpenRouter", ID: "openrouter", APIKey: "$OPENROUTER_API_KEY", APIEndpoint: "https://openrouter.ai/api/v1", - Type: providers.ProviderTypeOpenAI, + Type: provider.ProviderTypeOpenAI, DefaultModelID: "anthropic/claude-sonnet-4", - Models: []providers.Model{}, + Models: []provider.Model{}, } for _, model := range modelsResp.Data { @@ -142,7 +142,7 @@ func main() { canReason := slices.Contains(model.SupportedParams, "reasoning") supportsImages := slices.Contains(model.Architecture.InputModalities, "image") - m := providers.Model{ + m := provider.Model{ ID: model.ID, Name: model.Name, CostPer1MIn: pricing.CostPer1MIn, diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 2123df488968876994303a3346061e762f9cb628..b3f9c9b37193b1ee1f517e6de19de7e3988dc465 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -4,6 +4,8 @@ import ( _ "embed" "encoding/json" "log" + + "github.com/charmbracelet/fur/pkg/provider" ) //go:embed configs/openai.json @@ -30,43 +32,81 @@ var xAIConfig []byte //go:embed configs/bedrock.json var bedrockConfig []byte -func loadProviderFromConfig(configData []byte) Provider { - var provider Provider - if err := json.Unmarshal(configData, &provider); err != nil { +// ProviderFunc is a function that returns a Provider +type ProviderFunc func() provider.Provider + +var providerRegistry = map[provider.InferenceProvider]ProviderFunc{ + provider.InferenceProviderOpenAI: openAIProvider, + provider.InferenceProviderAnthropic: anthropicProvider, + provider.InferenceProviderGemini: geminiProvider, + provider.InferenceProviderAzure: azureProvider, + provider.InferenceProviderBedrock: bedrockProvider, + provider.InferenceProviderVertexAI: vertexAIProvider, + provider.InferenceProviderXAI: xAIProvider, + provider.InferenceProviderOpenRouter: openRouterProvider, +} + +func GetAll() []provider.Provider { + providers := make([]provider.Provider, 0, len(providerRegistry)) + for _, providerFunc := range providerRegistry { + providers = append(providers, providerFunc()) + } + return providers +} + +func GetByID(id provider.InferenceProvider) (provider.Provider, bool) { + providerFunc, exists := providerRegistry[id] + if !exists { + return provider.Provider{}, false + } + return providerFunc(), true +} + +func GetAvailableIDs() []provider.InferenceProvider { + ids := make([]provider.InferenceProvider, 0, len(providerRegistry)) + for id := range providerRegistry { + ids = append(ids, id) + } + return ids +} + +func loadProviderFromConfig(configData []byte) provider.Provider { + var p provider.Provider + if err := json.Unmarshal(configData, &p); err != nil { log.Printf("Error loading provider config: %v", err) - return Provider{} + return provider.Provider{} } - return provider + return p } -func openAIProvider() Provider { +func openAIProvider() provider.Provider { return loadProviderFromConfig(openAIConfig) } -func anthropicProvider() Provider { +func anthropicProvider() provider.Provider { return loadProviderFromConfig(anthropicConfig) } -func geminiProvider() Provider { +func geminiProvider() provider.Provider { return loadProviderFromConfig(geminiConfig) } -func azureProvider() Provider { +func azureProvider() provider.Provider { return loadProviderFromConfig(azureConfig) } -func bedrockProvider() Provider { +func bedrockProvider() provider.Provider { return loadProviderFromConfig(bedrockConfig) } -func vertexAIProvider() Provider { +func vertexAIProvider() provider.Provider { return loadProviderFromConfig(vertexAIConfig) } -func xAIProvider() Provider { +func xAIProvider() provider.Provider { return loadProviderFromConfig(xAIConfig) } -func openRouterProvider() Provider { +func openRouterProvider() provider.Provider { return loadProviderFromConfig(openRouterConfig) } diff --git a/internal/providers/types.go b/pkg/provider/provider.go similarity index 64% rename from internal/providers/types.go rename to pkg/provider/provider.go index e5e158d239d3041f5a1694820e49da564308131d..883a28d9a311b68fed17b24813eb4b99d7aafd3c 100644 --- a/internal/providers/types.go +++ b/pkg/provider/provider.go @@ -1,5 +1,6 @@ -package providers +package provider +// ProviderType represents the type of AI provider type ProviderType string const ( @@ -13,6 +14,7 @@ const ( ProviderTypeOpenRouter ProviderType = "openrouter" ) +// InferenceProvider represents the inference provider identifier type InferenceProvider string const ( @@ -26,6 +28,7 @@ const ( InferenceProviderOpenRouter InferenceProvider = "openrouter" ) +// Provider represents an AI provider configuration type Provider struct { Name string `json:"name"` ID InferenceProvider `json:"id"` @@ -36,6 +39,7 @@ type Provider struct { Models []Model `json:"models,omitempty"` } +// Model represents an AI model configuration type Model struct { ID string `json:"id"` Name string `json:"model"` @@ -48,41 +52,3 @@ type Model struct { CanReason bool `json:"can_reason"` SupportsImages bool `json:"supports_attachments"` } - -type ProviderFunc func() Provider - -var providerRegistry = map[InferenceProvider]ProviderFunc{ - InferenceProviderOpenAI: openAIProvider, - InferenceProviderAnthropic: anthropicProvider, - InferenceProviderGemini: geminiProvider, - InferenceProviderAzure: azureProvider, - InferenceProviderBedrock: bedrockProvider, - InferenceProviderVertexAI: vertexAIProvider, - InferenceProviderXAI: xAIProvider, - InferenceProviderOpenRouter: openRouterProvider, -} - -func GetAll() []Provider { - providers := make([]Provider, 0, len(providerRegistry)) - for _, providerFunc := range providerRegistry { - providers = append(providers, providerFunc()) - } - return providers -} - -func GetByID(id InferenceProvider) (Provider, bool) { - providerFunc, exists := providerRegistry[id] - if !exists { - return Provider{}, false - } - return providerFunc(), true -} - -func GetAvailableIDs() []InferenceProvider { - ids := make([]InferenceProvider, 0, len(providerRegistry)) - for id := range providerRegistry { - ids = append(ids, id) - } - return ids -} -