From 00aa3969e0131ca12afaf87cdcbaf64ed20a2933 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 6 Feb 2026 12:47:31 +0100 Subject: [PATCH] refactor: move provider configuration logic from Config to Service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move configureProviders, configureSelectedModels, and defaultModelSelection from Config receiver to Service receiver. Load() now calls svc.configureProviders() and svc.configureSelectedModels(). Tests use a serviceFor() helper. 🐾 Generated with Crush Assisted-by: Claude Opus 4.6 via Crush --- internal/config/load.go | 17 ++++--- internal/config/load_test.go | 94 +++++++++++++++++++----------------- 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/internal/config/load.go b/internal/config/load.go index 84530dbff179689102ee5e4e1863714674485e63..96b4e06ce97a3ae0ee0d2bcac92d3152792a51da 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -89,7 +89,7 @@ func Load(workingDir, dataDir string, debug bool) (*Service, error) { valueResolver := NewShellVariableResolver(env) svc.resolver = valueResolver cfg.resolver = valueResolver - if err := cfg.configureProviders(env, valueResolver, svc.knownProviders); err != nil { + if err := svc.configureProviders(env, valueResolver, svc.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } @@ -98,7 +98,7 @@ func Load(workingDir, dataDir string, debug bool) (*Service, error) { return svc, nil } - if err := cfg.configureSelectedModels(svc.knownProviders); err != nil { + if err := svc.configureSelectedModels(svc.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } cfg.SetupAgents() @@ -133,7 +133,8 @@ func PushPopCrushEnv() func() { return restore } -func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { +func (s *Service) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { + c := s.cfg knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() @@ -220,7 +221,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know switch { case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: // Claude Code subscription is not supported anymore. Remove to show onboarding. - c.removeConfigField("providers.anthropic") + s.RemoveConfigField("providers.anthropic") c.Providers.Del(string(p.ID)) continue case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil: @@ -471,7 +472,8 @@ func (c *Config) applyLSPDefaults() { } } -func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { +func (s *Service) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { + c := s.cfg if len(knownProviders) == 0 && c.Providers.Len() == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return largeModel, smallModel, err @@ -540,8 +542,9 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large return largeModel, smallModel, err } -func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error { - defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders) +func (s *Service) configureSelectedModels(knownProviders []catwalk.Provider) error { + c := s.cfg + defaultLarge, defaultSmall, err := s.defaultModelSelection(knownProviders) if err != nil { return fmt.Errorf("failed to select default models: %w", err) } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 60a0b7379501a7d766b33c4828c644cdb390bada..71b94ca5af1a09d98583099c503b9555dab5e462 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -14,6 +14,10 @@ import ( "github.com/stretchr/testify/require" ) +func serviceFor(cfg *Config) *Service { + return &Service{cfg: cfg} +} + func TestMain(m *testing.M) { slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil))) @@ -74,7 +78,7 @@ func TestConfig_configureProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, 1, cfg.Providers.Len()) @@ -117,7 +121,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, 1, cfg.Providers.Len()) @@ -159,7 +163,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Should be to because of the env variable require.Equal(t, cfg.Providers.Len(), 2) @@ -195,7 +199,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -221,7 +225,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without credentials require.Equal(t, cfg.Providers.Len(), 0) @@ -246,7 +250,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { "AWS_SECRET_ACCESS_KEY": "test-secret-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.Error(t, err) } @@ -269,7 +273,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { "VERTEXAI_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -301,7 +305,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without proper credentials require.Equal(t, cfg.Providers.Len(), 0) @@ -326,7 +330,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { "GOOGLE_CLOUD_LOCATION": "us-central1", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without project require.Equal(t, cfg.Providers.Len(), 0) @@ -350,7 +354,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -541,7 +545,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -569,7 +573,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -592,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -614,7 +618,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -639,7 +643,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -664,7 +668,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -692,7 +696,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -722,7 +726,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -757,7 +761,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { "GOOGLE_GENAI_USE_VERTEXAI": "false", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -788,7 +792,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -819,7 +823,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 0) @@ -852,7 +856,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) require.Equal(t, cfg.Providers.Len(), 1) @@ -886,10 +890,10 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - large, small, err := cfg.defaultModelSelection(knownProviders) + large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders) require.NoError(t, err) require.Equal(t, "large-model", large.Model) require.Equal(t, "openai", large.Provider) @@ -922,10 +926,10 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - _, _, err = cfg.defaultModelSelection(knownProviders) + _, _, err = serviceFor(cfg).defaultModelSelection(knownProviders) require.Error(t, err) }) t.Run("should error if model is missing", func(t *testing.T) { @@ -952,9 +956,9 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - _, _, err = cfg.defaultModelSelection(knownProviders) + _, _, err = serviceFor(cfg).defaultModelSelection(knownProviders) require.Error(t, err) }) @@ -995,9 +999,9 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - large, small, err := cfg.defaultModelSelection(knownProviders) + large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders) require.NoError(t, err) require.Equal(t, "model", large.Model) require.Equal(t, "custom", large.Provider) @@ -1039,9 +1043,9 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - _, _, err = cfg.defaultModelSelection(knownProviders) + _, _, err = serviceFor(cfg).defaultModelSelection(knownProviders) require.Error(t, err) }) t.Run("should use the default provider first", func(t *testing.T) { @@ -1081,9 +1085,9 @@ func TestConfig_defaultModelSelection(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - large, small, err := cfg.defaultModelSelection(knownProviders) + large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders) require.NoError(t, err) require.Equal(t, "large-model", large.Model) require.Equal(t, "openai", large.Provider) @@ -1126,7 +1130,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // openai should NOT be present because it lacks base_url and models. @@ -1169,7 +1173,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "OPENAI_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Only fully specified provider should be present. @@ -1223,7 +1227,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { "ANTHROPIC_API_KEY": "test-key", }) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Both providers should be present. @@ -1251,7 +1255,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) // Provider should be rejected for missing models. @@ -1275,7 +1279,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) // Provider should be rejected for missing base_url. @@ -1340,10 +1344,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = serviceFor(cfg).configureSelectedModels(knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] small := cfg.Models[SelectedModelTypeSmall] @@ -1402,10 +1406,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = serviceFor(cfg).configureSelectedModels(knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] small := cfg.Models[SelectedModelTypeSmall] @@ -1447,10 +1451,10 @@ func TestConfig_configureSelectedModels(t *testing.T) { cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) resolver := NewEnvironmentVariableResolver(env) - err := cfg.configureProviders(env, resolver, knownProviders) + err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - err = cfg.configureSelectedModels(knownProviders) + err = serviceFor(cfg).configureSelectedModels(knownProviders) require.NoError(t, err) large := cfg.Models[SelectedModelTypeLarge] require.Equal(t, "large-model", large.Model)