@@ -11,6 +11,7 @@ import (
"strings"
"sync"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
@@ -80,30 +81,34 @@ func Load(workingDir string, debug bool) (*Config, error) {
var testResults sync.Map
var wg sync.WaitGroup
- for _, p := range cfg.Providers {
- if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
- wg.Add(1)
- go func(provider ProviderConfig) {
- defer wg.Done()
- err := provider.TestConnection(cfg.resolver)
- testResults.Store(provider.ID, err == nil)
- if err != nil {
- slog.Error("Provider connection test failed", "provider", provider.ID, "error", err)
- }
- }(p)
- }
- }
- wg.Wait()
-
- // Remove failed providers
- testResults.Range(func(key, value any) bool {
- providerID := key.(string)
- passed := value.(bool)
- if !passed {
- delete(cfg.Providers, providerID)
+ go func() {
+ slog.Info("Testing provider connections")
+ defer slog.Info("Provider connection tests completed")
+ for _, p := range cfg.Providers.Seq2() {
+ if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
+ wg.Add(1)
+ go func(provider ProviderConfig) {
+ defer wg.Done()
+ err := provider.TestConnection(cfg.resolver)
+ testResults.Store(provider.ID, err == nil)
+ if err != nil {
+ slog.Error("Provider connection test failed", "provider", provider.ID, "error", err)
+ }
+ }(p)
+ }
}
- return true
- })
+ wg.Wait()
+
+ // Remove failed providers
+ testResults.Range(func(key, value any) bool {
+ providerID := key.(string)
+ passed := value.(bool)
+ if !passed {
+ cfg.Providers.Del(providerID)
+ }
+ return true
+ })
+ }()
if !cfg.IsConfigured() {
slog.Warn("No providers configured")
@@ -121,12 +126,12 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
knownProviderNames := make(map[string]bool)
for _, p := range knownProviders {
knownProviderNames[string(p.ID)] = true
- config, configExists := c.Providers[string(p.ID)]
+ config, configExists := c.Providers.Get(string(p.ID))
// if the user configured a known provider we need to allow it to override a couple of parameters
if configExists {
if config.Disable {
slog.Debug("Skipping provider due to disable flag", "provider", p.ID)
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
continue
}
if config.BaseURL != "" {
@@ -182,7 +187,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if !hasVertexCredentials(env) {
if configExists {
slog.Warn("Skipping Vertex AI provider due to missing credentials")
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
@@ -193,7 +198,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if err != nil || endpoint == "" {
if configExists {
slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
@@ -203,7 +208,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if !hasAWSCredentials(env) {
if configExists {
slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
@@ -218,16 +223,16 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if v == "" || err != nil {
if configExists {
slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
}
- c.Providers[string(p.ID)] = prepared
+ c.Providers.Set(string(p.ID), prepared)
}
// validate the custom providers
- for id, providerConfig := range c.Providers {
+ for id, providerConfig := range c.Providers.Seq2() {
if knownProviderNames[id] {
continue
}
@@ -244,7 +249,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if providerConfig.Disable {
slog.Debug("Skipping custom provider due to disable flag", "provider", id)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
if providerConfig.APIKey == "" {
@@ -252,17 +257,17 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
}
if providerConfig.BaseURL == "" {
slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
if len(providerConfig.Models) == 0 {
slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
if providerConfig.Type != provider.TypeOpenAI {
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
@@ -273,11 +278,11 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
if baseURL == "" || err != nil {
slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
- c.Providers[id] = providerConfig
+ c.Providers.Set(id, providerConfig)
}
return nil
}
@@ -297,7 +302,7 @@ func (c *Config) setDefaults(workingDir string) {
c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
}
if c.Providers == nil {
- c.Providers = make(map[string]ProviderConfig)
+ c.Providers = csync.NewMap[string, ProviderConfig]()
}
if c.Models == nil {
c.Models = make(map[SelectedModelType]SelectedModel)
@@ -316,7 +321,7 @@ func (c *Config) setDefaults(workingDir string) {
}
func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
- if len(knownProviders) == 0 && len(c.Providers) == 0 {
+ if len(knownProviders) == 0 { // TODO:}&& len(c.Providers) == 0 {
err = fmt.Errorf("no providers configured, please configure at least one provider")
return
}
@@ -324,7 +329,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg
// Use the first provider enabled based on the known providers order
// if no provider found that is known use the first provider configured
for _, p := range knownProviders {
- providerConfig, ok := c.Providers[string(p.ID)]
+ providerConfig, ok := c.Providers.Get(string(p.ID))
if !ok || providerConfig.Disable {
continue
}
@@ -8,6 +8,7 @@ import (
"strings"
"testing"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/stretchr/testify/assert"
@@ -29,9 +30,10 @@ func TestConfig_LoadFromReaders(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, loadedConfig)
- assert.Len(t, loadedConfig.Providers, 1)
- assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey)
- assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL)
+ assert.Equal(t, 1, loadedConfig.Providers.Len())
+ pc, _ := loadedConfig.Providers.Get("openai")
+ assert.Equal(t, "key2", pc.APIKey)
+ assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
}
func TestConfig_setDefaults(t *testing.T) {
@@ -73,10 +75,11 @@ func TestConfig_configureProviders(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder
- assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey)
+ pc, _ := cfg.Providers.Get("openai")
+ assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
}
func TestConfig_configureProvidersWithOverride(t *testing.T) {
@@ -92,22 +95,21 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
- "openai": {
- APIKey: "xyz",
- BaseURL: "https://api.openai.com/v2",
- Models: []provider.Model{
- {
- ID: "test-model",
- Model: "Updated",
- },
- {
- ID: "another-model",
- },
- },
+ Providers: csync.NewMap[string, ProviderConfig](),
+ }
+ cfg.Providers.Set("openai", ProviderConfig{
+ APIKey: "xyz",
+ BaseURL: "https://api.openai.com/v2",
+ Models: []provider.Model{
+ {
+ ID: "test-model",
+ Model: "Updated",
+ },
+ {
+ ID: "another-model",
},
},
- }
+ })
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{
@@ -116,13 +118,14 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder
- assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
- assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
- assert.Len(t, cfg.Providers["openai"].Models, 2)
- assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model)
+ pc, _ := cfg.Providers.Get("openai")
+ assert.Equal(t, "xyz", pc.APIKey)
+ assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
+ assert.Len(t, pc.Models, 2)
+ assert.Equal(t, "Updated", pc.Models[0].Model)
}
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
@@ -138,7 +141,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "xyz",
BaseURL: "https://api.someendpoint.com/v2",
@@ -148,7 +151,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
},
},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{
@@ -158,16 +161,17 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Should be to because of the env variable
- assert.Len(t, cfg.Providers, 2)
+ assert.Equal(t, cfg.Providers.Len(), 2)
// We want to make sure that we keep the configured API key as a placeholder
- assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey)
+ pc, _ := cfg.Providers.Get("custom")
+ assert.Equal(t, "xyz", pc.APIKey)
// Make sure we set the ID correctly
- assert.Equal(t, "custom", cfg.Providers["custom"].ID)
- assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL)
- assert.Len(t, cfg.Providers["custom"].Models, 1)
+ assert.Equal(t, "custom", pc.ID)
+ assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
+ assert.Len(t, pc.Models, 1)
- _, ok := cfg.Providers["openai"]
+ _, ok := cfg.Providers.Get("openai")
assert.True(t, ok, "OpenAI provider should still be present")
}
@@ -192,9 +196,9 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, cfg.Providers.Len(), 1)
- bedrockProvider, ok := cfg.Providers["bedrock"]
+ bedrockProvider, ok := cfg.Providers.Get("bedrock")
assert.True(t, ok, "Bedrock provider should be present")
assert.Len(t, bedrockProvider.Models, 1)
assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
@@ -219,7 +223,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without credentials
- assert.Len(t, cfg.Providers, 0)
+ assert.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
@@ -267,9 +271,9 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, cfg.Providers.Len(), 1)
- vertexProvider, ok := cfg.Providers["vertexai"]
+ vertexProvider, ok := cfg.Providers.Get("vertexai")
assert.True(t, ok, "VertexAI provider should be present")
assert.Len(t, vertexProvider.Models, 1)
assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
@@ -300,7 +304,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without proper credentials
- assert.Len(t, cfg.Providers, 0)
+ assert.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
@@ -325,7 +329,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without project
- assert.Len(t, cfg.Providers, 0)
+ assert.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
@@ -348,16 +352,17 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, cfg.Providers.Len(), 1)
// Provider ID should be set
- assert.Equal(t, "openai", cfg.Providers["openai"].ID)
+ pc, _ := cfg.Providers.Get("openai")
+ assert.Equal(t, "openai", pc.ID)
}
func TestConfig_EnabledProviders(t *testing.T) {
t.Run("all providers enabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -368,7 +373,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
APIKey: "key2",
Disable: false,
},
- },
+ }),
}
enabled := cfg.EnabledProviders()
@@ -377,7 +382,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
t.Run("some providers disabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -388,7 +393,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
APIKey: "key2",
Disable: true,
},
- },
+ }),
}
enabled := cfg.EnabledProviders()
@@ -398,7 +403,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
t.Run("empty providers map", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{},
+ Providers: csync.NewMap[string, ProviderConfig](),
}
enabled := cfg.EnabledProviders()
@@ -409,13 +414,13 @@ func TestConfig_EnabledProviders(t *testing.T) {
func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
Disable: false,
},
- },
+ }),
}
assert.True(t, cfg.IsConfigured())
@@ -423,7 +428,7 @@ func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns false when no providers are configured", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{},
+ Providers: csync.NewMap[string, ProviderConfig](),
}
assert.False(t, cfg.IsConfigured())
@@ -431,7 +436,7 @@ func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns false when all providers are disabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -442,7 +447,7 @@ func TestConfig_IsConfigured(t *testing.T) {
APIKey: "key2",
Disable: true,
},
- },
+ }),
}
assert.False(t, cfg.IsConfigured())
@@ -462,11 +467,11 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
Disable: true,
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -478,15 +483,15 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
assert.NoError(t, err)
// Provider should be removed from config when disabled
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["openai"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("openai")
assert.False(t, exists)
}
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{{
@@ -496,7 +501,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
"openai": {
APIKey: "$MISSING",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -505,21 +510,21 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 1)
+ _, exists := cfg.Providers.Get("custom")
assert.True(t, exists)
})
t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
Models: []provider.Model{{
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -528,20 +533,20 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
t.Run("custom provider with no models is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -550,14 +555,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -566,7 +571,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -575,14 +580,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -591,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -600,8 +605,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- customProvider, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 1)
+ customProvider, exists := cfg.Providers.Get("custom")
assert.True(t, exists)
assert.Equal(t, "custom", customProvider.ID)
assert.Equal(t, "test-key", customProvider.APIKey)
@@ -610,7 +615,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("disabled custom provider is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -620,7 +625,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -629,8 +634,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
}
@@ -649,11 +654,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"vertexai": {
BaseURL: "custom-url",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -664,8 +669,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["vertexai"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("vertexai")
assert.False(t, exists)
})
@@ -682,11 +687,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"bedrock": {
BaseURL: "custom-url",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -695,8 +700,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["bedrock"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("bedrock")
assert.False(t, exists)
})
@@ -713,11 +718,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
BaseURL: "custom-url",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -726,8 +731,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["openai"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("openai")
assert.False(t, exists)
})
@@ -744,11 +749,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
APIKey: "test-key",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -759,8 +764,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- _, exists := cfg.Providers["openai"]
+ assert.Equal(t, cfg.Providers.Len(), 1)
+ _, exists := cfg.Providers.Get("openai")
assert.True(t, exists)
})
}
@@ -883,7 +888,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -894,7 +899,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
},
},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
@@ -932,13 +937,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
@@ -969,7 +974,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -980,7 +985,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
},
},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})