@@ -14,6 +14,10 @@ import (
"github.com/stretchr/testify/require"
)
+func serviceFor(cfg *Config) *Service {
+ return &Service{cfg: cfg}
+}
+
func TestMain(m *testing.M) {
slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
@@ -74,7 +78,7 @@ func TestConfig_configureProviders(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, 1, cfg.Providers.Len())
@@ -117,7 +121,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, 1, cfg.Providers.Len())
@@ -159,7 +163,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
// Should be to because of the env variable
require.Equal(t, cfg.Providers.Len(), 2)
@@ -195,7 +199,7 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -221,7 +225,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
// Provider should not be configured without credentials
require.Equal(t, cfg.Providers.Len(), 0)
@@ -246,7 +250,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.Error(t, err)
}
@@ -269,7 +273,7 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
"VERTEXAI_LOCATION": "us-central1",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -301,7 +305,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
"GOOGLE_CLOUD_LOCATION": "us-central1",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
// Provider should not be configured without proper credentials
require.Equal(t, cfg.Providers.Len(), 0)
@@ -326,7 +330,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
"GOOGLE_CLOUD_LOCATION": "us-central1",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
// Provider should not be configured without project
require.Equal(t, cfg.Providers.Len(), 0)
@@ -350,7 +354,7 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -541,7 +545,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -569,7 +573,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -592,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -614,7 +618,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -639,7 +643,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -664,7 +668,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -692,7 +696,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -722,7 +726,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -757,7 +761,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
"GOOGLE_GENAI_USE_VERTEXAI": "false",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -788,7 +792,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -819,7 +823,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 0)
@@ -852,7 +856,7 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
@@ -886,10 +890,10 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- large, small, err := cfg.defaultModelSelection(knownProviders)
+ large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders)
require.NoError(t, err)
require.Equal(t, "large-model", large.Model)
require.Equal(t, "openai", large.Provider)
@@ -922,10 +926,10 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- _, _, err = cfg.defaultModelSelection(knownProviders)
+ _, _, err = serviceFor(cfg).defaultModelSelection(knownProviders)
require.Error(t, err)
})
t.Run("should error if model is missing", func(t *testing.T) {
@@ -952,9 +956,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- _, _, err = cfg.defaultModelSelection(knownProviders)
+ _, _, err = serviceFor(cfg).defaultModelSelection(knownProviders)
require.Error(t, err)
})
@@ -995,9 +999,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- large, small, err := cfg.defaultModelSelection(knownProviders)
+ large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders)
require.NoError(t, err)
require.Equal(t, "model", large.Model)
require.Equal(t, "custom", large.Provider)
@@ -1039,9 +1043,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- _, _, err = cfg.defaultModelSelection(knownProviders)
+ _, _, err = serviceFor(cfg).defaultModelSelection(knownProviders)
require.Error(t, err)
})
t.Run("should use the default provider first", func(t *testing.T) {
@@ -1081,9 +1085,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- large, small, err := cfg.defaultModelSelection(knownProviders)
+ large, small, err := serviceFor(cfg).defaultModelSelection(knownProviders)
require.NoError(t, err)
require.Equal(t, "large-model", large.Model)
require.Equal(t, "openai", large.Provider)
@@ -1126,7 +1130,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
// openai should NOT be present because it lacks base_url and models.
@@ -1169,7 +1173,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
"OPENAI_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
// Only fully specified provider should be present.
@@ -1223,7 +1227,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
"ANTHROPIC_API_KEY": "test-key",
})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
// Both providers should be present.
@@ -1251,7 +1255,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
// Provider should be rejected for missing models.
@@ -1275,7 +1279,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
+ err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{})
require.NoError(t, err)
// Provider should be rejected for missing base_url.
@@ -1340,10 +1344,10 @@ func TestConfig_configureSelectedModels(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- err = cfg.configureSelectedModels(knownProviders)
+ err = serviceFor(cfg).configureSelectedModels(knownProviders)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall]
@@ -1402,10 +1406,10 @@ func TestConfig_configureSelectedModels(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- err = cfg.configureSelectedModels(knownProviders)
+ err = serviceFor(cfg).configureSelectedModels(knownProviders)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall]
@@ -1447,10 +1451,10 @@ func TestConfig_configureSelectedModels(t *testing.T) {
cfg.setDefaults("/tmp", "")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
- err := cfg.configureProviders(env, resolver, knownProviders)
+ err := serviceFor(cfg).configureProviders(env, resolver, knownProviders)
require.NoError(t, err)
- err = cfg.configureSelectedModels(knownProviders)
+ err = serviceFor(cfg).configureSelectedModels(knownProviders)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
require.Equal(t, "large-model", large.Model)