diff --git a/cspell.json b/cspell.json index 5b0877dc174821537da9b67ea67d638965f9a9f7..34247df510135ec19d9129f3cd4f388437454299 100644 --- a/cspell.json +++ b/cspell.json @@ -1 +1 @@ -{"language":"en","version":"0.2","flagWords":[],"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema","preconfigured","jsons","qjebbs","LOCALAPPDATA","USERPROFILE","stretchr"]} \ No newline at end of file +{"words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai","jsonschema","preconfigured","jsons","qjebbs","LOCALAPPDATA","USERPROFILE","stretchr","cursorrules","VERTEXAI","VERTEXAI"],"flagWords":[],"language":"en","version":"0.2"} \ No newline at end of file diff --git a/pkg/config/config.go b/pkg/config/config.go index 476c356e95868ce1ac7679c309abd939050b8b6a..de1d2abe08c91e23e5ac6c10c8b45e0a72efa2ba 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -29,6 +29,13 @@ var defaultContextPaths = []string{ "CRUSH.local.md", } +type SelectedModelType string + +const ( + SelectedModelTypeLarge SelectedModelType = "large" + SelectedModelTypeSmall SelectedModelType = "small" +) + type SelectedModel struct { // The model id as used by the provider API. // Required. @@ -48,6 +55,8 @@ type SelectedModel struct { } type ProviderConfig struct { + // The provider's id. + ID string `json:"id,omitempty"` // The provider's API endpoint. BaseURL string `json:"base_url,omitempty"` // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai. @@ -153,7 +162,7 @@ func (l LSPs) Sorted() []LSP { // Config holds the configuration for crush. type Config struct { // We currently only support large/small as values here. - Models map[string]SelectedModel `json:"models,omitempty"` + Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` // The providers that are configured Providers map[string]ProviderConfig `json:"providers,omitempty"` @@ -171,3 +180,18 @@ type Config struct { func (c *Config) WorkingDir() string { return c.workingDir } + +func (c *Config) EnabledProviders() []ProviderConfig { + enabled := make([]ProviderConfig, 0, len(c.Providers)) + for _, p := range c.Providers { + if !p.Disable { + enabled = append(enabled, p) + } + } + return enabled +} + +// IsConfigured return true if at least one provider is configured +func (c *Config) IsConfigured() bool { + return len(c.EnabledProviders()) > 0 +} diff --git a/pkg/config/load.go b/pkg/config/load.go index 7b996aa9c99867ade245a92fe4b783f34148fe66..8294edef917b8e9ab9f190676181eaa90f62b8b9 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -14,6 +14,7 @@ import ( "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/pkg/env" "github.com/charmbracelet/crush/pkg/log" + "golang.org/x/exp/slog" ) // LoadReader config via io.Reader. @@ -55,15 +56,12 @@ func Load(workingDir string, debug bool) (*Config, error) { if err != nil { return nil, fmt.Errorf("failed to load config: %w", err) } - // TODO: maybe add a validation step here right after loading - // e.x validate the models - // e.x validate provider config cfg.setDefaults(workingDir) // Load known providers, this loads the config from fur providers, err := LoadProviders(client.New()) - if err != nil { + if err != nil || len(providers) == 0 { return nil, fmt.Errorf("failed to load providers: %w", err) } @@ -74,15 +72,35 @@ func Load(workingDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to configure providers: %w", err) } + if !cfg.IsConfigured() { + slog.Warn("No providers configured") + return cfg, nil + } + + // largeModel, ok := cfg.Models[SelectedModelTypeLarge] + // if !ok { + // // set default + // } + // smallModel, ok := cfg.Models[SelectedModelTypeSmall] + // if !ok { + // // set default + // } + // return cfg, nil } func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error { + knownProviderNames := make(map[string]bool) for _, p := range knownProviders { - - config, ok := cfg.Providers[string(p.ID)] + knownProviderNames[string(p.ID)] = true + config, configExists := cfg.Providers[string(p.ID)] // if the user configured a known provider we need to allow it to override a couple of parameters - if ok { + if configExists { + if config.Disable { + slog.Debug("Skipping provider due to disable flag", "provider", p.ID) + delete(cfg.Providers, string(p.ID)) + continue + } if config.BaseURL != "" { p.APIEndpoint = config.BaseURL } @@ -112,6 +130,7 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn } } prepared := ProviderConfig{ + ID: string(p.ID), BaseURL: p.APIEndpoint, APIKey: p.APIKey, Type: p.Type, @@ -125,12 +144,20 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn // Handle specific providers that require additional configuration case provider.InferenceProviderVertexAI: if !hasVertexCredentials(env) { + if configExists { + slog.Warn("Skipping Vertex AI provider due to missing credentials") + delete(cfg.Providers, string(p.ID)) + } continue } prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT") prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION") case provider.InferenceProviderBedrock: if !hasAWSCredentials(env) { + if configExists { + slog.Warn("Skipping Bedrock provider due to missing AWS credentials") + delete(cfg.Providers, string(p.ID)) + } continue } for _, model := range p.Models { @@ -142,15 +169,70 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn // if the provider api or endpoint are missing we skip them v, err := resolver.ResolveValue(p.APIKey) if v == "" || err != nil { - continue - } - v, err = resolver.ResolveValue(p.APIEndpoint) - if v == "" || err != nil { + if configExists { + slog.Warn("Skipping provider due to missing API key", "provider", p.ID) + delete(cfg.Providers, string(p.ID)) + } continue } } cfg.Providers[string(p.ID)] = prepared } + + // validate the custom providers + for id, providerConfig := range cfg.Providers { + if knownProviderNames[id] { + continue + } + + // Make sure the provider ID is set + providerConfig.ID = id + // default to OpenAI if not set + if providerConfig.Type == "" { + providerConfig.Type = provider.TypeOpenAI + } + + if providerConfig.Disable { + slog.Debug("Skipping custom provider due to disable flag", "provider", id) + delete(cfg.Providers, id) + continue + } + if providerConfig.APIKey == "" { + slog.Warn("Skipping custom provider due to missing API key", "provider", id) + delete(cfg.Providers, id) + continue + } + if providerConfig.BaseURL == "" { + slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id) + delete(cfg.Providers, id) + continue + } + if len(providerConfig.Models) == 0 { + slog.Warn("Skipping custom provider because the provider has no models", "provider", id) + delete(cfg.Providers, id) + continue + } + if providerConfig.Type != provider.TypeOpenAI { + slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) + delete(cfg.Providers, id) + continue + } + + apiKey, err := resolver.ResolveValue(providerConfig.APIKey) + if apiKey == "" || err != nil { + slog.Warn("Skipping custom provider due to missing API key", "provider", id, "error", err) + delete(cfg.Providers, id) + continue + } + baseURL, err := resolver.ResolveValue(providerConfig.BaseURL) + if baseURL == "" || err != nil { + slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err) + delete(cfg.Providers, id) + continue + } + + cfg.Providers[id] = providerConfig + } return nil } @@ -200,7 +282,7 @@ func (cfg *Config) setDefaults(workingDir string) { cfg.Providers = make(map[string]ProviderConfig) } if cfg.Models == nil { - cfg.Models = make(map[string]SelectedModel) + cfg.Models = make(map[SelectedModelType]SelectedModel) } if cfg.MCP == nil { cfg.MCP = make(map[string]MCPConfig) diff --git a/pkg/config/load_test.go b/pkg/config/load_test.go index e2dd943d58b24f55b433bbd7050d95c9bda18277..911b203de887b255fdc80c25d4433fd5de52a419 100644 --- a/pkg/config/load_test.go +++ b/pkg/config/load_test.go @@ -2,6 +2,8 @@ package config import ( "io" + "log/slog" + "os" "strings" "testing" @@ -10,6 +12,13 @@ import ( "github.com/stretchr/testify/assert" ) +func TestMain(m *testing.M) { + slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil))) + + exitVal := m.Run() + os.Exit(exitVal) +} + func TestConfig_LoadFromReaders(t *testing.T) { data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`) data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`) @@ -152,6 +161,8 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { // We want to make sure that we keep the configured API key as a placeholder assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey) + // Make sure we set the ID correctly + assert.Equal(t, "custom", cfg.Providers["custom"].ID) assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL) assert.Len(t, cfg.Providers["custom"].Models, 1) @@ -315,3 +326,437 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { // Provider should not be configured without project assert.Len(t, cfg.Providers, 0) } + +func TestConfig_configureProvidersSetProviderID(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "https://api.openai.com/v1", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + } + + cfg := &Config{} + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + assert.Len(t, cfg.Providers, 1) + + // Provider ID should be set + assert.Equal(t, "openai", cfg.Providers["openai"].ID) +} + +func TestConfig_EnabledProviders(t *testing.T) { + t.Run("all providers enabled", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + ID: "openai", + APIKey: "key1", + Disable: false, + }, + "anthropic": { + ID: "anthropic", + APIKey: "key2", + Disable: false, + }, + }, + } + + enabled := cfg.EnabledProviders() + assert.Len(t, enabled, 2) + }) + + t.Run("some providers disabled", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + ID: "openai", + APIKey: "key1", + Disable: false, + }, + "anthropic": { + ID: "anthropic", + APIKey: "key2", + Disable: true, + }, + }, + } + + enabled := cfg.EnabledProviders() + assert.Len(t, enabled, 1) + assert.Equal(t, "openai", enabled[0].ID) + }) + + t.Run("empty providers map", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{}, + } + + enabled := cfg.EnabledProviders() + assert.Len(t, enabled, 0) + }) +} + +func TestConfig_IsConfigured(t *testing.T) { + t.Run("returns true when at least one provider is enabled", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + ID: "openai", + APIKey: "key1", + Disable: false, + }, + }, + } + + assert.True(t, cfg.IsConfigured()) + }) + + t.Run("returns false when no providers are configured", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{}, + } + + assert.False(t, cfg.IsConfigured()) + }) + + t.Run("returns false when all providers are disabled", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + ID: "openai", + APIKey: "key1", + Disable: true, + }, + "anthropic": { + ID: "anthropic", + APIKey: "key2", + Disable: true, + }, + }, + } + + assert.False(t, cfg.IsConfigured()) + }) +} + +func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "https://api.openai.com/v1", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + Disable: true, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + // Provider should be removed from config when disabled + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["openai"] + assert.False(t, exists) +} + +func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { + t.Run("custom provider with missing API key is removed", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + BaseURL: "https://api.custom.com/v1", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []provider.Provider{}) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["custom"] + assert.False(t, exists) + }) + + t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []provider.Provider{}) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["custom"] + assert.False(t, exists) + }) + + t.Run("custom provider with no models is removed", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + BaseURL: "https://api.custom.com/v1", + Models: []provider.Model{}, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []provider.Provider{}) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["custom"] + assert.False(t, exists) + }) + + t.Run("custom provider with unsupported type is removed", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + BaseURL: "https://api.custom.com/v1", + Type: "unsupported", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []provider.Provider{}) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["custom"] + assert.False(t, exists) + }) + + t.Run("valid custom provider is kept and ID is set", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + BaseURL: "https://api.custom.com/v1", + Type: provider.TypeOpenAI, + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []provider.Provider{}) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 1) + customProvider, exists := cfg.Providers["custom"] + assert.True(t, exists) + assert.Equal(t, "custom", customProvider.ID) + assert.Equal(t, "test-key", customProvider.APIKey) + assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL) + }) + + t.Run("disabled custom provider is removed", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + BaseURL: "https://api.custom.com/v1", + Type: provider.TypeOpenAI, + Disable: true, + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []provider.Provider{}) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["custom"] + assert.False(t, exists) + }) +} + +func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { + t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderVertexAI, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "gemini-pro", + }}, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "vertexai": { + BaseURL: "custom-url", + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{ + "GOOGLE_GENAI_USE_VERTEXAI": "false", + }) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["vertexai"] + assert.False(t, exists) + }) + + t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: provider.InferenceProviderBedrock, + APIKey: "", + APIEndpoint: "", + Models: []provider.Model{{ + ID: "anthropic.claude-sonnet-4-20250514-v1:0", + }}, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "bedrock": { + BaseURL: "custom-url", + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["bedrock"] + assert.False(t, exists) + }) + + t.Run("provider removed when API key missing with existing config", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$MISSING_API_KEY", + APIEndpoint: "https://api.openai.com/v1", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + BaseURL: "custom-url", + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 0) + _, exists := cfg.Providers["openai"] + assert.False(t, exists) + }) + + t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$OPENAI_API_KEY", + APIEndpoint: "$MISSING_ENDPOINT", + Models: []provider.Model{{ + ID: "test-model", + }}, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "openai": { + APIKey: "test-key", + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 1) + _, exists := cfg.Providers["openai"] + assert.True(t, exists) + }) +}