Detailed changes
@@ -0,0 +1,34 @@
+package csync
+
+import (
+ "iter"
+ "sync"
+)
+
+type LazySlice[K any] struct {
+ inner []K
+ mu sync.Mutex
+}
+
+func NewLazySlice[K any](load func() []K) *LazySlice[K] {
+ s := &LazySlice[K]{}
+ s.mu.Lock()
+ go func() {
+ s.inner = load()
+ s.mu.Unlock()
+ }()
+ return s
+}
+
+func (s *LazySlice[K]) Iter() iter.Seq[K] {
+ s.mu.Lock()
+ inner := s.inner
+ s.mu.Unlock()
+ return func(yield func(K) bool) {
+ for _, v := range inner {
+ if !yield(v) {
+ return
+ }
+ }
+ }
+}
@@ -0,0 +1,86 @@
+package csync
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestLazySlice_Iter(t *testing.T) {
+ t.Parallel()
+
+ data := []string{"a", "b", "c"}
+ s := NewLazySlice(func() []string {
+ // TODO: use synctest when new Go is out.
+ time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens
+ return data
+ })
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ }
+
+ assert.Equal(t, data, result)
+}
+
+func TestLazySlice_IterWaitsForLoading(t *testing.T) {
+ t.Parallel()
+
+ var loaded atomic.Bool
+ data := []string{"x", "y", "z"}
+
+ s := NewLazySlice(func() []string {
+ // TODO: use synctest when new Go is out.
+ time.Sleep(100 * time.Millisecond)
+ loaded.Store(true)
+ return data
+ })
+
+ assert.False(t, loaded.Load(), "should not be loaded immediately")
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ }
+
+ assert.True(t, loaded.Load(), "should be loaded after Iter")
+ assert.Equal(t, data, result)
+}
+
+func TestLazySlice_EmptySlice(t *testing.T) {
+ t.Parallel()
+
+ s := NewLazySlice(func() []string {
+ return []string{}
+ })
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ }
+
+ assert.Empty(t, result)
+}
+
+func TestLazySlice_EarlyBreak(t *testing.T) {
+ t.Parallel()
+
+ data := []string{"a", "b", "c", "d", "e"}
+ s := NewLazySlice(func() []string {
+ time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens
+ return data
+ })
+
+ var result []string
+ for v := range s.Iter() {
+ result = append(result, v)
+ if len(result) == 2 {
+ break
+ }
+ }
+
+ assert.Equal(t, []string{"a", "b"}, result)
+}
@@ -9,6 +9,7 @@ import (
"strings"
"time"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/tidwall/sjson"
@@ -236,7 +237,7 @@ type Config struct {
Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
// The providers that are configured
- Providers map[string]ProviderConfig `json:"providers,omitempty"`
+ Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"`
MCP MCPs `json:"mcp,omitempty"`
@@ -259,8 +260,8 @@ func (c *Config) WorkingDir() string {
}
func (c *Config) EnabledProviders() []ProviderConfig {
- enabled := make([]ProviderConfig, 0, len(c.Providers))
- for _, p := range c.Providers {
+ var enabled []ProviderConfig
+ for _, p := range c.Providers.Seq2() {
if !p.Disable {
enabled = append(enabled, p)
}
@@ -274,7 +275,7 @@ func (c *Config) IsConfigured() bool {
}
func (c *Config) GetModel(provider, model string) *provider.Model {
- if providerConfig, ok := c.Providers[provider]; ok {
+ if providerConfig, ok := c.Providers.Get(provider); ok {
for _, m := range providerConfig.Models {
if m.ID == model {
return &m
@@ -289,7 +290,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi
if !ok {
return nil
}
- if providerConfig, ok := c.Providers[model.Provider]; ok {
+ if providerConfig, ok := c.Providers.Get(model.Provider); ok {
return &providerConfig
}
return nil
@@ -370,14 +371,10 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
return fmt.Errorf("failed to save API key to config file: %w", err)
}
- if c.Providers == nil {
- c.Providers = make(map[string]ProviderConfig)
- }
-
- providerConfig, exists := c.Providers[providerID]
+ providerConfig, exists := c.Providers.Get(providerID)
if exists {
providerConfig.APIKey = apiKey
- c.Providers[providerID] = providerConfig
+ c.Providers.Set(providerID, providerConfig)
return nil
}
@@ -406,7 +403,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
return fmt.Errorf("provider with ID %s not found in known providers", providerID)
}
// Store the updated provider config
- c.Providers[providerID] = providerConfig
+ c.Providers.Set(providerID, providerConfig)
return nil
}
@@ -11,6 +11,7 @@ import (
"strings"
"sync"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
@@ -80,30 +81,34 @@ func Load(workingDir string, debug bool) (*Config, error) {
var testResults sync.Map
var wg sync.WaitGroup
- for _, p := range cfg.Providers {
- if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
- wg.Add(1)
- go func(provider ProviderConfig) {
- defer wg.Done()
- err := provider.TestConnection(cfg.resolver)
- testResults.Store(provider.ID, err == nil)
- if err != nil {
- slog.Error("Provider connection test failed", "provider", provider.ID, "error", err)
- }
- }(p)
- }
- }
- wg.Wait()
-
- // Remove failed providers
- testResults.Range(func(key, value any) bool {
- providerID := key.(string)
- passed := value.(bool)
- if !passed {
- delete(cfg.Providers, providerID)
+ go func() {
+ slog.Info("Testing provider connections")
+ defer slog.Info("Provider connection tests completed")
+ for _, p := range cfg.Providers.Seq2() {
+ if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
+ wg.Add(1)
+ go func(provider ProviderConfig) {
+ defer wg.Done()
+ err := provider.TestConnection(cfg.resolver)
+ testResults.Store(provider.ID, err == nil)
+ if err != nil {
+ slog.Error("Provider connection test failed", "provider", provider.ID, "error", err)
+ }
+ }(p)
+ }
}
- return true
- })
+ wg.Wait()
+
+ // Remove failed providers
+ testResults.Range(func(key, value any) bool {
+ providerID := key.(string)
+ passed := value.(bool)
+ if !passed {
+ cfg.Providers.Del(providerID)
+ }
+ return true
+ })
+ }()
if !cfg.IsConfigured() {
slog.Warn("No providers configured")
@@ -121,12 +126,12 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
knownProviderNames := make(map[string]bool)
for _, p := range knownProviders {
knownProviderNames[string(p.ID)] = true
- config, configExists := c.Providers[string(p.ID)]
+ config, configExists := c.Providers.Get(string(p.ID))
// if the user configured a known provider we need to allow it to override a couple of parameters
if configExists {
if config.Disable {
slog.Debug("Skipping provider due to disable flag", "provider", p.ID)
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
continue
}
if config.BaseURL != "" {
@@ -182,7 +187,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if !hasVertexCredentials(env) {
if configExists {
slog.Warn("Skipping Vertex AI provider due to missing credentials")
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
@@ -193,7 +198,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if err != nil || endpoint == "" {
if configExists {
slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
@@ -203,7 +208,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if !hasAWSCredentials(env) {
if configExists {
slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
@@ -218,16 +223,16 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if v == "" || err != nil {
if configExists {
slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
- delete(c.Providers, string(p.ID))
+ c.Providers.Del(string(p.ID))
}
continue
}
}
- c.Providers[string(p.ID)] = prepared
+ c.Providers.Set(string(p.ID), prepared)
}
// validate the custom providers
- for id, providerConfig := range c.Providers {
+ for id, providerConfig := range c.Providers.Seq2() {
if knownProviderNames[id] {
continue
}
@@ -244,7 +249,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
if providerConfig.Disable {
slog.Debug("Skipping custom provider due to disable flag", "provider", id)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
if providerConfig.APIKey == "" {
@@ -252,17 +257,17 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
}
if providerConfig.BaseURL == "" {
slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
if len(providerConfig.Models) == 0 {
slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
if providerConfig.Type != provider.TypeOpenAI {
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
@@ -273,11 +278,11 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
if baseURL == "" || err != nil {
slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
- delete(c.Providers, id)
+ c.Providers.Del(id)
continue
}
- c.Providers[id] = providerConfig
+ c.Providers.Set(id, providerConfig)
}
return nil
}
@@ -297,7 +302,7 @@ func (c *Config) setDefaults(workingDir string) {
c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
}
if c.Providers == nil {
- c.Providers = make(map[string]ProviderConfig)
+ c.Providers = csync.NewMap[string, ProviderConfig]()
}
if c.Models == nil {
c.Models = make(map[SelectedModelType]SelectedModel)
@@ -316,7 +321,7 @@ func (c *Config) setDefaults(workingDir string) {
}
func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
- if len(knownProviders) == 0 && len(c.Providers) == 0 {
+ if len(knownProviders) == 0 { // TODO:}&& len(c.Providers) == 0 {
err = fmt.Errorf("no providers configured, please configure at least one provider")
return
}
@@ -324,7 +329,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg
// Use the first provider enabled based on the known providers order
// if no provider found that is known use the first provider configured
for _, p := range knownProviders {
- providerConfig, ok := c.Providers[string(p.ID)]
+ providerConfig, ok := c.Providers.Get(string(p.ID))
if !ok || providerConfig.Disable {
continue
}
@@ -8,6 +8,7 @@ import (
"strings"
"testing"
+ "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/stretchr/testify/assert"
@@ -29,9 +30,10 @@ func TestConfig_LoadFromReaders(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, loadedConfig)
- assert.Len(t, loadedConfig.Providers, 1)
- assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey)
- assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL)
+ assert.Equal(t, 1, loadedConfig.Providers.Len())
+ pc, _ := loadedConfig.Providers.Get("openai")
+ assert.Equal(t, "key2", pc.APIKey)
+ assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
}
func TestConfig_setDefaults(t *testing.T) {
@@ -73,10 +75,11 @@ func TestConfig_configureProviders(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder
- assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey)
+ pc, _ := cfg.Providers.Get("openai")
+ assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
}
func TestConfig_configureProvidersWithOverride(t *testing.T) {
@@ -92,22 +95,21 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
- "openai": {
- APIKey: "xyz",
- BaseURL: "https://api.openai.com/v2",
- Models: []provider.Model{
- {
- ID: "test-model",
- Model: "Updated",
- },
- {
- ID: "another-model",
- },
- },
+ Providers: csync.NewMap[string, ProviderConfig](),
+ }
+ cfg.Providers.Set("openai", ProviderConfig{
+ APIKey: "xyz",
+ BaseURL: "https://api.openai.com/v2",
+ Models: []provider.Model{
+ {
+ ID: "test-model",
+ Model: "Updated",
+ },
+ {
+ ID: "another-model",
},
},
- }
+ })
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{
@@ -116,13 +118,14 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder
- assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
- assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
- assert.Len(t, cfg.Providers["openai"].Models, 2)
- assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model)
+ pc, _ := cfg.Providers.Get("openai")
+ assert.Equal(t, "xyz", pc.APIKey)
+ assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
+ assert.Len(t, pc.Models, 2)
+ assert.Equal(t, "Updated", pc.Models[0].Model)
}
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
@@ -138,7 +141,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "xyz",
BaseURL: "https://api.someendpoint.com/v2",
@@ -148,7 +151,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
},
},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{
@@ -158,16 +161,17 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Should be to because of the env variable
- assert.Len(t, cfg.Providers, 2)
+ assert.Equal(t, cfg.Providers.Len(), 2)
// We want to make sure that we keep the configured API key as a placeholder
- assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey)
+ pc, _ := cfg.Providers.Get("custom")
+ assert.Equal(t, "xyz", pc.APIKey)
// Make sure we set the ID correctly
- assert.Equal(t, "custom", cfg.Providers["custom"].ID)
- assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL)
- assert.Len(t, cfg.Providers["custom"].Models, 1)
+ assert.Equal(t, "custom", pc.ID)
+ assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
+ assert.Len(t, pc.Models, 1)
- _, ok := cfg.Providers["openai"]
+ _, ok := cfg.Providers.Get("openai")
assert.True(t, ok, "OpenAI provider should still be present")
}
@@ -192,9 +196,9 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, cfg.Providers.Len(), 1)
- bedrockProvider, ok := cfg.Providers["bedrock"]
+ bedrockProvider, ok := cfg.Providers.Get("bedrock")
assert.True(t, ok, "Bedrock provider should be present")
assert.Len(t, bedrockProvider.Models, 1)
assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
@@ -219,7 +223,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without credentials
- assert.Len(t, cfg.Providers, 0)
+ assert.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
@@ -267,9 +271,9 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, cfg.Providers.Len(), 1)
- vertexProvider, ok := cfg.Providers["vertexai"]
+ vertexProvider, ok := cfg.Providers.Get("vertexai")
assert.True(t, ok, "VertexAI provider should be present")
assert.Len(t, vertexProvider.Models, 1)
assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
@@ -300,7 +304,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without proper credentials
- assert.Len(t, cfg.Providers, 0)
+ assert.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
@@ -325,7 +329,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
// Provider should not be configured without project
- assert.Len(t, cfg.Providers, 0)
+ assert.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
@@ -348,16 +352,17 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
+ assert.Equal(t, cfg.Providers.Len(), 1)
// Provider ID should be set
- assert.Equal(t, "openai", cfg.Providers["openai"].ID)
+ pc, _ := cfg.Providers.Get("openai")
+ assert.Equal(t, "openai", pc.ID)
}
func TestConfig_EnabledProviders(t *testing.T) {
t.Run("all providers enabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -368,7 +373,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
APIKey: "key2",
Disable: false,
},
- },
+ }),
}
enabled := cfg.EnabledProviders()
@@ -377,7 +382,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
t.Run("some providers disabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -388,7 +393,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
APIKey: "key2",
Disable: true,
},
- },
+ }),
}
enabled := cfg.EnabledProviders()
@@ -398,7 +403,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
t.Run("empty providers map", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{},
+ Providers: csync.NewMap[string, ProviderConfig](),
}
enabled := cfg.EnabledProviders()
@@ -409,13 +414,13 @@ func TestConfig_EnabledProviders(t *testing.T) {
func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
Disable: false,
},
- },
+ }),
}
assert.True(t, cfg.IsConfigured())
@@ -423,7 +428,7 @@ func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns false when no providers are configured", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{},
+ Providers: csync.NewMap[string, ProviderConfig](),
}
assert.False(t, cfg.IsConfigured())
@@ -431,7 +436,7 @@ func TestConfig_IsConfigured(t *testing.T) {
t.Run("returns false when all providers are disabled", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
ID: "openai",
APIKey: "key1",
@@ -442,7 +447,7 @@ func TestConfig_IsConfigured(t *testing.T) {
APIKey: "key2",
Disable: true,
},
- },
+ }),
}
assert.False(t, cfg.IsConfigured())
@@ -462,11 +467,11 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
Disable: true,
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -478,15 +483,15 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
assert.NoError(t, err)
// Provider should be removed from config when disabled
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["openai"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("openai")
assert.False(t, exists)
}
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{{
@@ -496,7 +501,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
"openai": {
APIKey: "$MISSING",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -505,21 +510,21 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 1)
+ _, exists := cfg.Providers.Get("custom")
assert.True(t, exists)
})
t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
Models: []provider.Model{{
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -528,20 +533,20 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
t.Run("custom provider with no models is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -550,14 +555,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -566,7 +571,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -575,14 +580,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -591,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -600,8 +605,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- customProvider, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 1)
+ customProvider, exists := cfg.Providers.Get("custom")
assert.True(t, exists)
assert.Equal(t, "custom", customProvider.ID)
assert.Equal(t, "test-key", customProvider.APIKey)
@@ -610,7 +615,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
t.Run("disabled custom provider is removed", func(t *testing.T) {
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -620,7 +625,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
ID: "test-model",
}},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -629,8 +634,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, []provider.Provider{})
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["custom"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
})
}
@@ -649,11 +654,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"vertexai": {
BaseURL: "custom-url",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -664,8 +669,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["vertexai"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("vertexai")
assert.False(t, exists)
})
@@ -682,11 +687,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"bedrock": {
BaseURL: "custom-url",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -695,8 +700,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["bedrock"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("bedrock")
assert.False(t, exists)
})
@@ -713,11 +718,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
BaseURL: "custom-url",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -726,8 +731,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 0)
- _, exists := cfg.Providers["openai"]
+ assert.Equal(t, cfg.Providers.Len(), 0)
+ _, exists := cfg.Providers.Get("openai")
assert.False(t, exists)
})
@@ -744,11 +749,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"openai": {
APIKey: "test-key",
},
- },
+ }),
}
cfg.setDefaults("/tmp")
@@ -759,8 +764,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
- assert.Len(t, cfg.Providers, 1)
- _, exists := cfg.Providers["openai"]
+ assert.Equal(t, cfg.Providers.Len(), 1)
+ _, exists := cfg.Providers.Get("openai")
assert.True(t, exists)
})
}
@@ -883,7 +888,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -894,7 +899,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
},
},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
@@ -932,13 +937,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
@@ -969,7 +974,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
}
cfg := &Config{
- Providers: map[string]ProviderConfig{
+ Providers: csync.NewMapFrom(map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
@@ -980,7 +985,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
},
},
},
- },
+ }),
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
@@ -0,0 +1,84 @@
+package csync
+
+import (
+ "encoding/json"
+ "iter"
+ "maps"
+ "sync"
+)
+
+type Map[K comparable, V any] struct {
+ inner map[K]V
+ mu sync.RWMutex
+}
+
+func NewMap[K comparable, V any]() *Map[K, V] {
+ return &Map[K, V]{
+ inner: make(map[K]V),
+ }
+}
+
+func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] {
+ return &Map[K, V]{
+ inner: m,
+ }
+}
+
+func (m *Map[K, V]) Set(key K, value V) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.inner[key] = value
+}
+
+func (m *Map[K, V]) Del(key K) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ delete(m.inner, key)
+}
+
+func (m *Map[K, V]) Get(key K) (V, bool) {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ v, ok := m.inner[key]
+ return v, ok
+}
+
+func (m *Map[K, V]) Len() int {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return len(m.inner)
+}
+
+func (m *Map[K, V]) Seq2() iter.Seq2[K, V] {
+ dst := make(map[K]V)
+ m.mu.RLock()
+ maps.Copy(dst, m.inner)
+ m.mu.RUnlock()
+ return func(yield func(K, V) bool) {
+ for k, v := range dst {
+ if !yield(k, v) {
+ return
+ }
+ }
+ }
+}
+
+var (
+ _ json.Unmarshaler = &Map[string, any]{}
+ _ json.Marshaler = &Map[string, any]{}
+)
+
+// UnmarshalJSON implements json.Unmarshaler.
+func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.inner = make(map[K]V)
+ return json.Unmarshal(data, &m.inner)
+}
+
+// MarshalJSON implements json.Marshaler.
+func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return json.Marshal(m.inner)
+}
@@ -0,0 +1,450 @@
+package csync
+
+import (
+ "encoding/json"
+ "maps"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNewMap(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ assert.NotNil(t, m)
+ assert.NotNil(t, m.inner)
+ assert.Equal(t, 0, m.Len())
+}
+
+func TestNewMapFrom(t *testing.T) {
+ t.Parallel()
+
+ original := map[string]int{
+ "key1": 1,
+ "key2": 2,
+ }
+
+ m := NewMapFrom(original)
+ assert.NotNil(t, m)
+ assert.Equal(t, original, m.inner)
+ assert.Equal(t, 2, m.Len())
+
+ value, ok := m.Get("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 1, value)
+}
+
+func TestMap_Set(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+
+ m.Set("key1", 42)
+ value, ok := m.Get("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 42, value)
+ assert.Equal(t, 1, m.Len())
+
+ m.Set("key1", 100)
+ value, ok = m.Get("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 100, value)
+ assert.Equal(t, 1, m.Len())
+}
+
+func TestMap_Get(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+
+ value, ok := m.Get("nonexistent")
+ assert.False(t, ok)
+ assert.Equal(t, 0, value)
+
+ m.Set("key1", 42)
+ value, ok = m.Get("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 42, value)
+}
+
+func TestMap_Del(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 42)
+ m.Set("key2", 100)
+
+ assert.Equal(t, 2, m.Len())
+
+ m.Del("key1")
+ _, ok := m.Get("key1")
+ assert.False(t, ok)
+ assert.Equal(t, 1, m.Len())
+
+ value, ok := m.Get("key2")
+ assert.True(t, ok)
+ assert.Equal(t, 100, value)
+
+ m.Del("nonexistent")
+ assert.Equal(t, 1, m.Len())
+}
+
+func TestMap_Len(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ assert.Equal(t, 0, m.Len())
+
+ m.Set("key1", 1)
+ assert.Equal(t, 1, m.Len())
+
+ m.Set("key2", 2)
+ assert.Equal(t, 2, m.Len())
+
+ m.Del("key1")
+ assert.Equal(t, 1, m.Len())
+
+ m.Del("key2")
+ assert.Equal(t, 0, m.Len())
+}
+
+func TestMap_Seq2(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 1)
+ m.Set("key2", 2)
+ m.Set("key3", 3)
+
+ collected := maps.Collect(m.Seq2())
+
+ assert.Equal(t, 3, len(collected))
+ assert.Equal(t, 1, collected["key1"])
+ assert.Equal(t, 2, collected["key2"])
+ assert.Equal(t, 3, collected["key3"])
+}
+
+func TestMap_Seq2_EarlyReturn(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 1)
+ m.Set("key2", 2)
+ m.Set("key3", 3)
+
+ count := 0
+ for range m.Seq2() {
+ count++
+ if count == 2 {
+ break
+ }
+ }
+
+ assert.Equal(t, 2, count)
+}
+
+func TestMap_Seq2_EmptyMap(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+
+ count := 0
+ for range m.Seq2() {
+ count++
+ }
+
+ assert.Equal(t, 0, count)
+}
+
+func TestMap_MarshalJSON(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("key1", 1)
+ m.Set("key2", 2)
+
+ data, err := json.Marshal(m)
+ assert.NoError(t, err)
+
+ var result map[string]int
+ err = json.Unmarshal(data, &result)
+ assert.NoError(t, err)
+ assert.Equal(t, 2, len(result))
+ assert.Equal(t, 1, result["key1"])
+ assert.Equal(t, 2, result["key2"])
+}
+
+func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+
+ data, err := json.Marshal(m)
+ assert.NoError(t, err)
+ assert.Equal(t, "{}", string(data))
+}
+
+func TestMap_UnmarshalJSON(t *testing.T) {
+ t.Parallel()
+
+ jsonData := `{"key1": 1, "key2": 2}`
+
+ m := NewMap[string, int]()
+ err := json.Unmarshal([]byte(jsonData), m)
+ assert.NoError(t, err)
+
+ assert.Equal(t, 2, m.Len())
+ value, ok := m.Get("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 1, value)
+
+ value, ok = m.Get("key2")
+ assert.True(t, ok)
+ assert.Equal(t, 2, value)
+}
+
+func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
+ t.Parallel()
+
+ jsonData := `{}`
+
+ m := NewMap[string, int]()
+ err := json.Unmarshal([]byte(jsonData), m)
+ assert.NoError(t, err)
+ assert.Equal(t, 0, m.Len())
+}
+
+func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
+ t.Parallel()
+
+ jsonData := `{"key1": 1, "key2":}`
+
+ m := NewMap[string, int]()
+ err := json.Unmarshal([]byte(jsonData), m)
+ assert.Error(t, err)
+}
+
+func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[string, int]()
+ m.Set("existing", 999)
+
+ jsonData := `{"key1": 1, "key2": 2}`
+ err := json.Unmarshal([]byte(jsonData), m)
+ assert.NoError(t, err)
+
+ assert.Equal(t, 2, m.Len())
+ _, ok := m.Get("existing")
+ assert.False(t, ok)
+
+ value, ok := m.Get("key1")
+ assert.True(t, ok)
+ assert.Equal(t, 1, value)
+}
+
+func TestMap_JSONRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ original := NewMap[string, int]()
+ original.Set("key1", 1)
+ original.Set("key2", 2)
+ original.Set("key3", 3)
+
+ data, err := json.Marshal(original)
+ assert.NoError(t, err)
+
+ restored := NewMap[string, int]()
+ err = json.Unmarshal(data, restored)
+ assert.NoError(t, err)
+
+ assert.Equal(t, original.Len(), restored.Len())
+
+ for k, v := range original.Seq2() {
+ restoredValue, ok := restored.Get(k)
+ assert.True(t, ok)
+ assert.Equal(t, v, restoredValue)
+ }
+}
+
+func TestMap_ConcurrentAccess(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[int, int]()
+ const numGoroutines = 100
+ const numOperations = 100
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+
+ for i := range numGoroutines {
+ go func(id int) {
+ defer wg.Done()
+ for j := range numOperations {
+ key := id*numOperations + j
+ m.Set(key, key*2)
+ value, ok := m.Get(key)
+ assert.True(t, ok)
+ assert.Equal(t, key*2, value)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+
+ assert.Equal(t, numGoroutines*numOperations, m.Len())
+}
+
+func TestMap_ConcurrentReadWrite(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[int, int]()
+ const numReaders = 50
+ const numWriters = 50
+ const numOperations = 100
+
+ for i := range 1000 {
+ m.Set(i, i)
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(numReaders + numWriters)
+
+ for range numReaders {
+ go func() {
+ defer wg.Done()
+ for j := range numOperations {
+ key := j % 1000
+ value, ok := m.Get(key)
+ if ok {
+ assert.Equal(t, key, value)
+ }
+ _ = m.Len()
+ }
+ }()
+ }
+
+ for i := range numWriters {
+ go func(id int) {
+ defer wg.Done()
+ for j := range numOperations {
+ key := 1000 + id*numOperations + j
+ m.Set(key, key)
+ if j%10 == 0 {
+ m.Del(key)
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+func TestMap_ConcurrentSeq2(t *testing.T) {
+ t.Parallel()
+
+ m := NewMap[int, int]()
+ for i := range 100 {
+ m.Set(i, i*2)
+ }
+
+ var wg sync.WaitGroup
+ const numIterators = 10
+
+ wg.Add(numIterators)
+ for range numIterators {
+ go func() {
+ defer wg.Done()
+ count := 0
+ for k, v := range m.Seq2() {
+ assert.Equal(t, k*2, v)
+ count++
+ }
+ assert.Equal(t, 100, count)
+ }()
+ }
+
+ wg.Wait()
+}
+
+func TestMap_TypeSafety(t *testing.T) {
+ t.Parallel()
+
+ stringIntMap := NewMap[string, int]()
+ stringIntMap.Set("key", 42)
+ value, ok := stringIntMap.Get("key")
+ assert.True(t, ok)
+ assert.Equal(t, 42, value)
+
+ intStringMap := NewMap[int, string]()
+ intStringMap.Set(42, "value")
+ strValue, ok := intStringMap.Get(42)
+ assert.True(t, ok)
+ assert.Equal(t, "value", strValue)
+
+ structMap := NewMap[string, struct{ Name string }]()
+ structMap.Set("key", struct{ Name string }{Name: "test"})
+ structValue, ok := structMap.Get("key")
+ assert.True(t, ok)
+ assert.Equal(t, "test", structValue.Name)
+}
+
+func TestMap_InterfaceCompliance(t *testing.T) {
+ t.Parallel()
+
+ var _ json.Marshaler = &Map[string, any]{}
+ var _ json.Unmarshaler = &Map[string, any]{}
+}
+
+func BenchmarkMap_Set(b *testing.B) {
+ m := NewMap[int, int]()
+
+ for i := 0; b.Loop(); i++ {
+ m.Set(i, i*2)
+ }
+}
+
+func BenchmarkMap_Get(b *testing.B) {
+ m := NewMap[int, int]()
+ for i := range 1000 {
+ m.Set(i, i*2)
+ }
+
+ for i := 0; b.Loop(); i++ {
+ m.Get(i % 1000)
+ }
+}
+
+func BenchmarkMap_Seq2(b *testing.B) {
+ m := NewMap[int, int]()
+ for i := range 1000 {
+ m.Set(i, i*2)
+ }
+
+ for b.Loop() {
+ for range m.Seq2() {
+ }
+ }
+}
+
+func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
+ m := NewMap[int, int]()
+ for i := range 1000 {
+ m.Set(i, i*2)
+ }
+
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ i := 0
+ for pb.Next() {
+ if i%2 == 0 {
+ m.Get(i % 1000)
+ } else {
+ m.Set(i+1000, i*2)
+ }
+ i++
+ }
+ })
+}
@@ -8,9 +8,9 @@ import (
"slices"
"strings"
"sync"
- "sync/atomic"
"time"
+ "github.com/charmbracelet/crush/csync"
"github.com/charmbracelet/crush/internal/config"
fur "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/history"
@@ -68,8 +68,7 @@ type agent struct {
sessions session.Service
messages message.Service
- toolsDone atomic.Bool
- tools []tools.BaseTool
+ tools *csync.LazySlice[tools.BaseTool]
provider provider.Provider
providerID string
@@ -168,24 +167,10 @@ func NewAgent(
return nil, err
}
- agent := &agent{
- Broker: pubsub.NewBroker[AgentEvent](),
- agentCfg: agentCfg,
- provider: agentProvider,
- providerID: string(providerCfg.ID),
- messages: messages,
- sessions: sessions,
- titleProvider: titleProvider,
- summarizeProvider: summarizeProvider,
- summarizeProviderID: string(smallModelProviderCfg.ID),
- activeRequests: sync.Map{},
- }
-
- go func() {
+ toolFn := func() []tools.BaseTool {
slog.Info("Initializing agent tools", "agent", agentCfg.ID)
defer func() {
slog.Info("Initialized agent tools", "agent", agentCfg.ID)
- agent.toolsDone.Store(true)
}()
cwd := cfg.WorkingDir()
@@ -214,8 +199,7 @@ func NewAgent(
}
if agentCfg.AllowedTools == nil {
- agent.tools = allTools
- return
+ return allTools
}
var filteredTools []tools.BaseTool
@@ -224,10 +208,22 @@ func NewAgent(
filteredTools = append(filteredTools, tool)
}
}
- agent.tools = filteredTools
- }()
+ return filteredTools
+ }
- return agent, nil
+ return &agent{
+ Broker: pubsub.NewBroker[AgentEvent](),
+ agentCfg: agentCfg,
+ provider: agentProvider,
+ providerID: string(providerCfg.ID),
+ messages: messages,
+ sessions: sessions,
+ titleProvider: titleProvider,
+ summarizeProvider: summarizeProvider,
+ summarizeProviderID: string(smallModelProviderCfg.ID),
+ activeRequests: sync.Map{},
+ tools: csync.NewLazySlice(toolFn),
+ }, nil
}
func (a *agent) Model() fur.Model {
@@ -449,10 +445,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- if !a.toolsDone.Load() {
- return message.Message{}, nil, fmt.Errorf("agent is still initializing, please wait a moment and try again")
- }
- eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
+ eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter()))
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
@@ -501,7 +494,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
default:
// Continue processing
var tool tools.BaseTool
- for _, availableTool := range a.tools {
+ for availableTool := range a.tools.Iter() {
if availableTool.Info().Name == toolCall.Name {
tool = availableTool
break
@@ -911,7 +904,7 @@ func (a *agent) UpdateModel() error {
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg config.ProviderConfig
- for _, p := range cfg.Providers {
+ for _, p := range cfg.Providers.Seq2() {
if p.ID == smallModelCfg.Provider {
smallModelProviderCfg = p
break
@@ -422,7 +422,7 @@ func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provide
func (s *splashCmp) isProviderConfigured(providerID string) bool {
cfg := config.Get()
- if _, ok := cfg.Providers[providerID]; ok {
+ if _, ok := cfg.Providers.Get(providerID); ok {
return true
}
return false
@@ -103,7 +103,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
if err != nil {
return util.ReportError(err)
}
- for providerID, providerConfig := range cfg.Providers {
+ for providerID, providerConfig := range cfg.Providers.Seq2() {
if providerConfig.Disable {
continue
}
@@ -164,7 +164,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
// Check if this provider is configured and not disabled
- if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
+ if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
continue
}
@@ -174,7 +174,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
section := commands.NewItemSection(name)
- if _, ok := cfg.Providers[string(provider.ID)]; ok {
+ if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
section.SetInfo(configured)
}
modelItems = append(modelItems, section)
@@ -357,7 +357,7 @@ func (m *modelDialogCmp) modelTypeRadio() string {
func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
cfg := config.Get()
- if _, ok := cfg.Providers[providerID]; ok {
+ if _, ok := cfg.Providers.Get(providerID); ok {
return true
}
return false