@@ -1 +1 @@
-{"language":"en","version":"0.2","flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema","preconfigured","jsons","qjebbs","LOCALAPPDATA","USERPROFILE","stretchr"]}
+{"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema","preconfigured","jsons","qjebbs","LOCALAPPDATA","USERPROFILE","stretchr","cursorrules","VERTEXAI","VERTEXAI"],"flagWords":[],"language":"en","version":"0.2"}
@@ -14,6 +14,7 @@ import (
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/pkg/env"
"github.com/charmbracelet/crush/pkg/log"
+ "golang.org/x/exp/slog"
)
// LoadReader config via io.Reader.
@@ -55,15 +56,12 @@ func Load(workingDir string, debug bool) (*Config, error) {
if err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
- // TODO: maybe add a validation step here right after loading
- // e.x validate the models
- // e.x validate provider config
cfg.setDefaults(workingDir)
// Load known providers, this loads the config from fur
providers, err := LoadProviders(client.New())
- if err != nil {
+ if err != nil || len(providers) == 0 {
return nil, fmt.Errorf("failed to load providers: %w", err)
}
@@ -74,15 +72,35 @@ func Load(workingDir string, debug bool) (*Config, error) {
return nil, fmt.Errorf("failed to configure providers: %w", err)
}
+ if !cfg.IsConfigured() {
+ slog.Warn("No providers configured")
+ return cfg, nil
+ }
+
+ // largeModel, ok := cfg.Models[SelectedModelTypeLarge]
+ // if !ok {
+ // // set default
+ // }
+ // smallModel, ok := cfg.Models[SelectedModelTypeSmall]
+ // if !ok {
+ // // set default
+ // }
+ //
return cfg, nil
}
func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
+ knownProviderNames := make(map[string]bool)
for _, p := range knownProviders {
-
- config, ok := cfg.Providers[string(p.ID)]
+ knownProviderNames[string(p.ID)] = true
+ config, configExists := cfg.Providers[string(p.ID)]
// if the user configured a known provider we need to allow it to override a couple of parameters
- if ok {
+ if configExists {
+ if config.Disable {
+ slog.Debug("Skipping provider due to disable flag", "provider", p.ID)
+ delete(cfg.Providers, string(p.ID))
+ continue
+ }
if config.BaseURL != "" {
p.APIEndpoint = config.BaseURL
}
@@ -112,6 +130,7 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
}
}
prepared := ProviderConfig{
+ ID: string(p.ID),
BaseURL: p.APIEndpoint,
APIKey: p.APIKey,
Type: p.Type,
@@ -125,12 +144,20 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
// Handle specific providers that require additional configuration
case provider.InferenceProviderVertexAI:
if !hasVertexCredentials(env) {
+ if configExists {
+ slog.Warn("Skipping Vertex AI provider due to missing credentials")
+ delete(cfg.Providers, string(p.ID))
+ }
continue
}
prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
case provider.InferenceProviderBedrock:
if !hasAWSCredentials(env) {
+ if configExists {
+ slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
+ delete(cfg.Providers, string(p.ID))
+ }
continue
}
for _, model := range p.Models {
@@ -142,15 +169,70 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
// if the provider api or endpoint are missing we skip them
v, err := resolver.ResolveValue(p.APIKey)
if v == "" || err != nil {
- continue
- }
- v, err = resolver.ResolveValue(p.APIEndpoint)
- if v == "" || err != nil {
+ if configExists {
+ slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
+ delete(cfg.Providers, string(p.ID))
+ }
continue
}
}
cfg.Providers[string(p.ID)] = prepared
}
+
+ // validate the custom providers
+ for id, providerConfig := range cfg.Providers {
+ if knownProviderNames[id] {
+ continue
+ }
+
+ // Make sure the provider ID is set
+ providerConfig.ID = id
+ // default to OpenAI if not set
+ if providerConfig.Type == "" {
+ providerConfig.Type = provider.TypeOpenAI
+ }
+
+ if providerConfig.Disable {
+ slog.Debug("Skipping custom provider due to disable flag", "provider", id)
+ delete(cfg.Providers, id)
+ continue
+ }
+ if providerConfig.APIKey == "" {
+ slog.Warn("Skipping custom provider due to missing API key", "provider", id)
+ delete(cfg.Providers, id)
+ continue
+ }
+ if providerConfig.BaseURL == "" {
+ slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
+ delete(cfg.Providers, id)
+ continue
+ }
+ if len(providerConfig.Models) == 0 {
+ slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
+ delete(cfg.Providers, id)
+ continue
+ }
+ if providerConfig.Type != provider.TypeOpenAI {
+ slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
+ delete(cfg.Providers, id)
+ continue
+ }
+
+ apiKey, err := resolver.ResolveValue(providerConfig.APIKey)
+ if apiKey == "" || err != nil {
+ slog.Warn("Skipping custom provider due to missing API key", "provider", id, "error", err)
+ delete(cfg.Providers, id)
+ continue
+ }
+ baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
+ if baseURL == "" || err != nil {
+ slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
+ delete(cfg.Providers, id)
+ continue
+ }
+
+ cfg.Providers[id] = providerConfig
+ }
return nil
}
@@ -200,7 +282,7 @@ func (cfg *Config) setDefaults(workingDir string) {
cfg.Providers = make(map[string]ProviderConfig)
}
if cfg.Models == nil {
- cfg.Models = make(map[string]SelectedModel)
+ cfg.Models = make(map[SelectedModelType]SelectedModel)
}
if cfg.MCP == nil {
cfg.MCP = make(map[string]MCPConfig)
@@ -2,6 +2,8 @@ package config
import (
"io"
+ "log/slog"
+ "os"
"strings"
"testing"
@@ -10,6 +12,13 @@ import (
"github.com/stretchr/testify/assert"
)
+func TestMain(m *testing.M) {
+ slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
+
+ exitVal := m.Run()
+ os.Exit(exitVal)
+}
+
func TestConfig_LoadFromReaders(t *testing.T) {
data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
@@ -152,6 +161,8 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
// We want to make sure that we keep the configured API key as a placeholder
assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey)
+ // Make sure we set the ID correctly
+ assert.Equal(t, "custom", cfg.Providers["custom"].ID)
assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL)
assert.Len(t, cfg.Providers["custom"].Models, 1)
@@ -315,3 +326,437 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
// Provider should not be configured without project
assert.Len(t, cfg.Providers, 0)
}
+
+func TestConfig_configureProvidersSetProviderID(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ }
+
+ cfg := &Config{}
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+ assert.Len(t, cfg.Providers, 1)
+
+ // Provider ID should be set
+ assert.Equal(t, "openai", cfg.Providers["openai"].ID)
+}
+
+func TestConfig_EnabledProviders(t *testing.T) {
+ t.Run("all providers enabled", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ ID: "openai",
+ APIKey: "key1",
+ Disable: false,
+ },
+ "anthropic": {
+ ID: "anthropic",
+ APIKey: "key2",
+ Disable: false,
+ },
+ },
+ }
+
+ enabled := cfg.EnabledProviders()
+ assert.Len(t, enabled, 2)
+ })
+
+ t.Run("some providers disabled", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ ID: "openai",
+ APIKey: "key1",
+ Disable: false,
+ },
+ "anthropic": {
+ ID: "anthropic",
+ APIKey: "key2",
+ Disable: true,
+ },
+ },
+ }
+
+ enabled := cfg.EnabledProviders()
+ assert.Len(t, enabled, 1)
+ assert.Equal(t, "openai", enabled[0].ID)
+ })
+
+ t.Run("empty providers map", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{},
+ }
+
+ enabled := cfg.EnabledProviders()
+ assert.Len(t, enabled, 0)
+ })
+}
+
+func TestConfig_IsConfigured(t *testing.T) {
+ t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ ID: "openai",
+ APIKey: "key1",
+ Disable: false,
+ },
+ },
+ }
+
+ assert.True(t, cfg.IsConfigured())
+ })
+
+ t.Run("returns false when no providers are configured", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{},
+ }
+
+ assert.False(t, cfg.IsConfigured())
+ })
+
+ t.Run("returns false when all providers are disabled", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ ID: "openai",
+ APIKey: "key1",
+ Disable: true,
+ },
+ "anthropic": {
+ ID: "anthropic",
+ APIKey: "key2",
+ Disable: true,
+ },
+ },
+ }
+
+ assert.False(t, cfg.IsConfigured())
+ })
+}
+
+func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ Disable: true,
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ // Provider should be removed from config when disabled
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["openai"]
+ assert.False(t, exists)
+}
+
+func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
+ t.Run("custom provider with missing API key is removed", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ BaseURL: "https://api.custom.com/v1",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []provider.Provider{})
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["custom"]
+ assert.False(t, exists)
+ })
+
+ t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []provider.Provider{})
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["custom"]
+ assert.False(t, exists)
+ })
+
+ t.Run("custom provider with no models is removed", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ BaseURL: "https://api.custom.com/v1",
+ Models: []provider.Model{},
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []provider.Provider{})
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["custom"]
+ assert.False(t, exists)
+ })
+
+ t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ BaseURL: "https://api.custom.com/v1",
+ Type: "unsupported",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []provider.Provider{})
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["custom"]
+ assert.False(t, exists)
+ })
+
+ t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ BaseURL: "https://api.custom.com/v1",
+ Type: provider.TypeOpenAI,
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []provider.Provider{})
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 1)
+ customProvider, exists := cfg.Providers["custom"]
+ assert.True(t, exists)
+ assert.Equal(t, "custom", customProvider.ID)
+ assert.Equal(t, "test-key", customProvider.APIKey)
+ assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
+ })
+
+ t.Run("disabled custom provider is removed", func(t *testing.T) {
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ BaseURL: "https://api.custom.com/v1",
+ Type: provider.TypeOpenAI,
+ Disable: true,
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, []provider.Provider{})
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["custom"]
+ assert.False(t, exists)
+ })
+}
+
+func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
+ t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderVertexAI,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "gemini-pro",
+ }},
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "vertexai": {
+ BaseURL: "custom-url",
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{
+ "GOOGLE_GENAI_USE_VERTEXAI": "false",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["vertexai"]
+ assert.False(t, exists)
+ })
+
+ t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: provider.InferenceProviderBedrock,
+ APIKey: "",
+ APIEndpoint: "",
+ Models: []provider.Model{{
+ ID: "anthropic.claude-sonnet-4-20250514-v1:0",
+ }},
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "bedrock": {
+ BaseURL: "custom-url",
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["bedrock"]
+ assert.False(t, exists)
+ })
+
+ t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$MISSING_API_KEY",
+ APIEndpoint: "https://api.openai.com/v1",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ BaseURL: "custom-url",
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 0)
+ _, exists := cfg.Providers["openai"]
+ assert.False(t, exists)
+ })
+
+ t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$OPENAI_API_KEY",
+ APIEndpoint: "$MISSING_ENDPOINT",
+ Models: []provider.Model{{
+ ID: "test-model",
+ }},
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "openai": {
+ APIKey: "test-key",
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+
+ env := env.NewFromMap(map[string]string{
+ "OPENAI_API_KEY": "test-key",
+ })
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ assert.Len(t, cfg.Providers, 1)
+ _, exists := cfg.Providers["openai"]
+ assert.True(t, exists)
+ })
+}