diff --git a/csync/slices.go b/csync/slices.go new file mode 100644 index 0000000000000000000000000000000000000000..388ad074d53a9bd7188418b231afbf39adca0565 --- /dev/null +++ b/csync/slices.go @@ -0,0 +1,34 @@ +package csync + +import ( + "iter" + "sync" +) + +type LazySlice[K any] struct { + inner []K + mu sync.Mutex +} + +func NewLazySlice[K any](load func() []K) *LazySlice[K] { + s := &LazySlice[K]{} + s.mu.Lock() + go func() { + s.inner = load() + s.mu.Unlock() + }() + return s +} + +func (s *LazySlice[K]) Iter() iter.Seq[K] { + s.mu.Lock() + inner := s.inner + s.mu.Unlock() + return func(yield func(K) bool) { + for _, v := range inner { + if !yield(v) { + return + } + } + } +} diff --git a/csync/slices_test.go b/csync/slices_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d1c7af8cf30f3d58a84046f899f8dd89f80beb51 --- /dev/null +++ b/csync/slices_test.go @@ -0,0 +1,86 @@ +package csync + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLazySlice_Iter(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c"} + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.Equal(t, data, result) +} + +func TestLazySlice_IterWaitsForLoading(t *testing.T) { + t.Parallel() + + var loaded atomic.Bool + data := []string{"x", "y", "z"} + + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(100 * time.Millisecond) + loaded.Store(true) + return data + }) + + assert.False(t, loaded.Load(), "should not be loaded immediately") + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.True(t, loaded.Load(), "should be loaded after Iter") + assert.Equal(t, data, result) +} + +func TestLazySlice_EmptySlice(t *testing.T) { + t.Parallel() + + s := NewLazySlice(func() []string { + return []string{} + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + } + + assert.Empty(t, result) +} + +func TestLazySlice_EarlyBreak(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c", "d", "e"} + s := NewLazySlice(func() []string { + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Iter() { + result = append(result, v) + if len(result) == 2 { + break + } + } + + assert.Equal(t, []string{"a", "b"}, result) +} diff --git a/internal/config/config.go b/internal/config/config.go index 1c20188a12a3955fde6b6eeed9f12ea39288e328..18eca04912189415606599c5849e8a7beb592cb4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/tidwall/sjson" @@ -236,7 +237,7 @@ type Config struct { Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` // The providers that are configured - Providers map[string]ProviderConfig `json:"providers,omitempty"` + Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"` MCP MCPs `json:"mcp,omitempty"` @@ -259,8 +260,8 @@ func (c *Config) WorkingDir() string { } func (c *Config) EnabledProviders() []ProviderConfig { - enabled := make([]ProviderConfig, 0, len(c.Providers)) - for _, p := range c.Providers { + var enabled []ProviderConfig + for _, p := range c.Providers.Seq2() { if !p.Disable { enabled = append(enabled, p) } @@ -274,7 +275,7 @@ func (c *Config) IsConfigured() bool { } func (c *Config) GetModel(provider, model string) *provider.Model { - if providerConfig, ok := c.Providers[provider]; ok { + if providerConfig, ok := c.Providers.Get(provider); ok { for _, m := range providerConfig.Models { if m.ID == model { return &m @@ -289,7 +290,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi if !ok { return nil } - if providerConfig, ok := c.Providers[model.Provider]; ok { + if providerConfig, ok := c.Providers.Get(model.Provider); ok { return &providerConfig } return nil @@ -370,14 +371,10 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return fmt.Errorf("failed to save API key to config file: %w", err) } - if c.Providers == nil { - c.Providers = make(map[string]ProviderConfig) - } - - providerConfig, exists := c.Providers[providerID] + providerConfig, exists := c.Providers.Get(providerID) if exists { providerConfig.APIKey = apiKey - c.Providers[providerID] = providerConfig + c.Providers.Set(providerID, providerConfig) return nil } @@ -406,7 +403,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return fmt.Errorf("provider with ID %s not found in known providers", providerID) } // Store the updated provider config - c.Providers[providerID] = providerConfig + c.Providers.Set(providerID, providerConfig) return nil } diff --git a/internal/config/load.go b/internal/config/load.go index cd4ccd08c46e48155091407962137da2cb913869..09d65e5391b94a1f80b15e7e576ba5d3e38ef19d 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -11,6 +11,7 @@ import ( "strings" "sync" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" @@ -80,30 +81,34 @@ func Load(workingDir string, debug bool) (*Config, error) { var testResults sync.Map var wg sync.WaitGroup - for _, p := range cfg.Providers { - if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { - wg.Add(1) - go func(provider ProviderConfig) { - defer wg.Done() - err := provider.TestConnection(cfg.resolver) - testResults.Store(provider.ID, err == nil) - if err != nil { - slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) - } - }(p) - } - } - wg.Wait() - - // Remove failed providers - testResults.Range(func(key, value any) bool { - providerID := key.(string) - passed := value.(bool) - if !passed { - delete(cfg.Providers, providerID) + go func() { + slog.Info("Testing provider connections") + defer slog.Info("Provider connection tests completed") + for _, p := range cfg.Providers.Seq2() { + if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { + wg.Add(1) + go func(provider ProviderConfig) { + defer wg.Done() + err := provider.TestConnection(cfg.resolver) + testResults.Store(provider.ID, err == nil) + if err != nil { + slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) + } + }(p) + } } - return true - }) + wg.Wait() + + // Remove failed providers + testResults.Range(func(key, value any) bool { + providerID := key.(string) + passed := value.(bool) + if !passed { + cfg.Providers.Del(providerID) + } + return true + }) + }() if !cfg.IsConfigured() { slog.Warn("No providers configured") @@ -121,12 +126,12 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know knownProviderNames := make(map[string]bool) for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true - config, configExists := c.Providers[string(p.ID)] + config, configExists := c.Providers.Get(string(p.ID)) // if the user configured a known provider we need to allow it to override a couple of parameters if configExists { if config.Disable { slog.Debug("Skipping provider due to disable flag", "provider", p.ID) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) continue } if config.BaseURL != "" { @@ -182,7 +187,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if !hasVertexCredentials(env) { if configExists { slog.Warn("Skipping Vertex AI provider due to missing credentials") - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -193,7 +198,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if err != nil || endpoint == "" { if configExists { slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -203,7 +208,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if !hasAWSCredentials(env) { if configExists { slog.Warn("Skipping Bedrock provider due to missing AWS credentials") - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -218,16 +223,16 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if v == "" || err != nil { if configExists { slog.Warn("Skipping provider due to missing API key", "provider", p.ID) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } } - c.Providers[string(p.ID)] = prepared + c.Providers.Set(string(p.ID), prepared) } // validate the custom providers - for id, providerConfig := range c.Providers { + for id, providerConfig := range c.Providers.Seq2() { if knownProviderNames[id] { continue } @@ -244,7 +249,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if providerConfig.Disable { slog.Debug("Skipping custom provider due to disable flag", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if providerConfig.APIKey == "" { @@ -252,17 +257,17 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } if providerConfig.BaseURL == "" { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if len(providerConfig.Models) == 0 { slog.Warn("Skipping custom provider because the provider has no models", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if providerConfig.Type != provider.TypeOpenAI { slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) - delete(c.Providers, id) + c.Providers.Del(id) continue } @@ -273,11 +278,11 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know baseURL, err := resolver.ResolveValue(providerConfig.BaseURL) if baseURL == "" || err != nil { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err) - delete(c.Providers, id) + c.Providers.Del(id) continue } - c.Providers[id] = providerConfig + c.Providers.Set(id, providerConfig) } return nil } @@ -297,7 +302,7 @@ func (c *Config) setDefaults(workingDir string) { c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory) } if c.Providers == nil { - c.Providers = make(map[string]ProviderConfig) + c.Providers = csync.NewMap[string, ProviderConfig]() } if c.Models == nil { c.Models = make(map[SelectedModelType]SelectedModel) @@ -316,7 +321,7 @@ func (c *Config) setDefaults(workingDir string) { } func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { - if len(knownProviders) == 0 && len(c.Providers) == 0 { + if len(knownProviders) == 0 { // TODO:}&& len(c.Providers) == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return } @@ -324,7 +329,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg // Use the first provider enabled based on the known providers order // if no provider found that is known use the first provider configured for _, p := range knownProviders { - providerConfig, ok := c.Providers[string(p.ID)] + providerConfig, ok := c.Providers.Get(string(p.ID)) if !ok || providerConfig.Disable { continue } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index b96ca5e81cd265cbcd1bdf9d456603ad3f22c558..86a2356da2021dc22de88de05a80717e95aa492a 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/stretchr/testify/assert" @@ -29,9 +30,10 @@ func TestConfig_LoadFromReaders(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, loadedConfig) - assert.Len(t, loadedConfig.Providers, 1) - assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey) - assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL) + assert.Equal(t, 1, loadedConfig.Providers.Len()) + pc, _ := loadedConfig.Providers.Get("openai") + assert.Equal(t, "key2", pc.APIKey) + assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) } func TestConfig_setDefaults(t *testing.T) { @@ -73,10 +75,11 @@ func TestConfig_configureProviders(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, 1, cfg.Providers.Len()) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey) } func TestConfig_configureProvidersWithOverride(t *testing.T) { @@ -92,22 +95,21 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ - "openai": { - APIKey: "xyz", - BaseURL: "https://api.openai.com/v2", - Models: []provider.Model{ - { - ID: "test-model", - Model: "Updated", - }, - { - ID: "another-model", - }, - }, + Providers: csync.NewMap[string, ProviderConfig](), + } + cfg.Providers.Set("openai", ProviderConfig{ + APIKey: "xyz", + BaseURL: "https://api.openai.com/v2", + Models: []provider.Model{ + { + ID: "test-model", + Model: "Updated", + }, + { + ID: "another-model", }, }, - } + }) cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ @@ -116,13 +118,14 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, 1, cfg.Providers.Len()) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey) - assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL) - assert.Len(t, cfg.Providers["openai"].Models, 2) - assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "xyz", pc.APIKey) + assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) + assert.Len(t, pc.Models, 2) + assert.Equal(t, "Updated", pc.Models[0].Model) } func TestConfig_configureProvidersWithNewProvider(t *testing.T) { @@ -138,7 +141,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "xyz", BaseURL: "https://api.someendpoint.com/v2", @@ -148,7 +151,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ @@ -158,16 +161,17 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Should be to because of the env variable - assert.Len(t, cfg.Providers, 2) + assert.Equal(t, cfg.Providers.Len(), 2) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey) + pc, _ := cfg.Providers.Get("custom") + assert.Equal(t, "xyz", pc.APIKey) // Make sure we set the ID correctly - assert.Equal(t, "custom", cfg.Providers["custom"].ID) - assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL) - assert.Len(t, cfg.Providers["custom"].Models, 1) + assert.Equal(t, "custom", pc.ID) + assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL) + assert.Len(t, pc.Models, 1) - _, ok := cfg.Providers["openai"] + _, ok := cfg.Providers.Get("openai") assert.True(t, ok, "OpenAI provider should still be present") } @@ -192,9 +196,9 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) - bedrockProvider, ok := cfg.Providers["bedrock"] + bedrockProvider, ok := cfg.Providers.Get("bedrock") assert.True(t, ok, "Bedrock provider should be present") assert.Len(t, bedrockProvider.Models, 1) assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID) @@ -219,7 +223,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without credentials - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { @@ -267,9 +271,9 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) - vertexProvider, ok := cfg.Providers["vertexai"] + vertexProvider, ok := cfg.Providers.Get("vertexai") assert.True(t, ok, "VertexAI provider should be present") assert.Len(t, vertexProvider.Models, 1) assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID) @@ -300,7 +304,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without proper credentials - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { @@ -325,7 +329,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without project - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersSetProviderID(t *testing.T) { @@ -348,16 +352,17 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) // Provider ID should be set - assert.Equal(t, "openai", cfg.Providers["openai"].ID) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "openai", pc.ID) } func TestConfig_EnabledProviders(t *testing.T) { t.Run("all providers enabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -368,7 +373,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: false, }, - }, + }), } enabled := cfg.EnabledProviders() @@ -377,7 +382,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("some providers disabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -388,7 +393,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: true, }, - }, + }), } enabled := cfg.EnabledProviders() @@ -398,7 +403,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("empty providers map", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{}, + Providers: csync.NewMap[string, ProviderConfig](), } enabled := cfg.EnabledProviders() @@ -409,13 +414,13 @@ func TestConfig_EnabledProviders(t *testing.T) { func TestConfig_IsConfigured(t *testing.T) { t.Run("returns true when at least one provider is enabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", Disable: false, }, - }, + }), } assert.True(t, cfg.IsConfigured()) @@ -423,7 +428,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when no providers are configured", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{}, + Providers: csync.NewMap[string, ProviderConfig](), } assert.False(t, cfg.IsConfigured()) @@ -431,7 +436,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when all providers are disabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -442,7 +447,7 @@ func TestConfig_IsConfigured(t *testing.T) { APIKey: "key2", Disable: true, }, - }, + }), } assert.False(t, cfg.IsConfigured()) @@ -462,11 +467,11 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { Disable: true, }, - }, + }), } cfg.setDefaults("/tmp") @@ -478,15 +483,15 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { assert.NoError(t, err) // Provider should be removed from config when disabled - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("openai") assert.False(t, exists) } func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { BaseURL: "https://api.custom.com/v1", Models: []provider.Model{{ @@ -496,7 +501,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "openai": { APIKey: "$MISSING", }, - }, + }), } cfg.setDefaults("/tmp") @@ -505,21 +510,21 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 1) + _, exists := cfg.Providers.Get("custom") assert.True(t, exists) }) t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", Models: []provider.Model{{ ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -528,20 +533,20 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("custom provider with no models is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []provider.Model{}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -550,14 +555,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("custom provider with unsupported type is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -566,7 +571,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -575,14 +580,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("valid custom provider is kept and ID is set", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -591,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -600,8 +605,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - customProvider, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 1) + customProvider, exists := cfg.Providers.Get("custom") assert.True(t, exists) assert.Equal(t, "custom", customProvider.ID) assert.Equal(t, "test-key", customProvider.APIKey) @@ -610,7 +615,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("disabled custom provider is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -620,7 +625,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -629,8 +634,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []provider.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) } @@ -649,11 +654,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "vertexai": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -664,8 +669,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["vertexai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("vertexai") assert.False(t, exists) }) @@ -682,11 +687,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "bedrock": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -695,8 +700,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["bedrock"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("bedrock") assert.False(t, exists) }) @@ -713,11 +718,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -726,8 +731,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("openai") assert.False(t, exists) }) @@ -744,11 +749,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { APIKey: "test-key", }, - }, + }), } cfg.setDefaults("/tmp") @@ -759,8 +764,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 1) + _, exists := cfg.Providers.Get("openai") assert.True(t, exists) }) } @@ -883,7 +888,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -894,7 +899,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) @@ -932,13 +937,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []provider.Model{}, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) @@ -969,7 +974,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -980,7 +985,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) diff --git a/internal/csync/maps.go b/internal/csync/maps.go new file mode 100644 index 0000000000000000000000000000000000000000..69b56050d45b13abd189d3eb2da75120fe13589f --- /dev/null +++ b/internal/csync/maps.go @@ -0,0 +1,84 @@ +package csync + +import ( + "encoding/json" + "iter" + "maps" + "sync" +) + +type Map[K comparable, V any] struct { + inner map[K]V + mu sync.RWMutex +} + +func NewMap[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + inner: make(map[K]V), + } +} + +func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] { + return &Map[K, V]{ + inner: m, + } +} + +func (m *Map[K, V]) Set(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + m.inner[key] = value +} + +func (m *Map[K, V]) Del(key K) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.inner, key) +} + +func (m *Map[K, V]) Get(key K) (V, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + v, ok := m.inner[key] + return v, ok +} + +func (m *Map[K, V]) Len() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.inner) +} + +func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { + dst := make(map[K]V) + m.mu.RLock() + maps.Copy(dst, m.inner) + m.mu.RUnlock() + return func(yield func(K, V) bool) { + for k, v := range dst { + if !yield(k, v) { + return + } + } + } +} + +var ( + _ json.Unmarshaler = &Map[string, any]{} + _ json.Marshaler = &Map[string, any]{} +) + +// UnmarshalJSON implements json.Unmarshaler. +func (m *Map[K, V]) UnmarshalJSON(data []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + m.inner = make(map[K]V) + return json.Unmarshal(data, &m.inner) +} + +// MarshalJSON implements json.Marshaler. +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return json.Marshal(m.inner) +} diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5eddd92ce201f12d5f59620817a9e04c4e2f3008 --- /dev/null +++ b/internal/csync/maps_test.go @@ -0,0 +1,450 @@ +package csync + +import ( + "encoding/json" + "maps" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + assert.NotNil(t, m) + assert.NotNil(t, m.inner) + assert.Equal(t, 0, m.Len()) +} + +func TestNewMapFrom(t *testing.T) { + t.Parallel() + + original := map[string]int{ + "key1": 1, + "key2": 2, + } + + m := NewMapFrom(original) + assert.NotNil(t, m) + assert.Equal(t, original, m.inner) + assert.Equal(t, 2, m.Len()) + + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) +} + +func TestMap_Set(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + m.Set("key1", 42) + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 1, m.Len()) + + m.Set("key1", 100) + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 100, value) + assert.Equal(t, 1, m.Len()) +} + +func TestMap_Get(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + value, ok := m.Get("nonexistent") + assert.False(t, ok) + assert.Equal(t, 0, value) + + m.Set("key1", 42) + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) +} + +func TestMap_Del(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + m.Set("key2", 100) + + assert.Equal(t, 2, m.Len()) + + m.Del("key1") + _, ok := m.Get("key1") + assert.False(t, ok) + assert.Equal(t, 1, m.Len()) + + value, ok := m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 100, value) + + m.Del("nonexistent") + assert.Equal(t, 1, m.Len()) +} + +func TestMap_Len(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + assert.Equal(t, 0, m.Len()) + + m.Set("key1", 1) + assert.Equal(t, 1, m.Len()) + + m.Set("key2", 2) + assert.Equal(t, 2, m.Len()) + + m.Del("key1") + assert.Equal(t, 1, m.Len()) + + m.Del("key2") + assert.Equal(t, 0, m.Len()) +} + +func TestMap_Seq2(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + collected := maps.Collect(m.Seq2()) + + assert.Equal(t, 3, len(collected)) + assert.Equal(t, 1, collected["key1"]) + assert.Equal(t, 2, collected["key2"]) + assert.Equal(t, 3, collected["key3"]) +} + +func TestMap_Seq2_EarlyReturn(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + count := 0 + for range m.Seq2() { + count++ + if count == 2 { + break + } + } + + assert.Equal(t, 2, count) +} + +func TestMap_Seq2_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + count := 0 + for range m.Seq2() { + count++ + } + + assert.Equal(t, 0, count) +} + +func TestMap_MarshalJSON(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + + data, err := json.Marshal(m) + assert.NoError(t, err) + + var result map[string]int + err = json.Unmarshal(data, &result) + assert.NoError(t, err) + assert.Equal(t, 2, len(result)) + assert.Equal(t, 1, result["key1"]) + assert.Equal(t, 2, result["key2"]) +} + +func TestMap_MarshalJSON_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + data, err := json.Marshal(m) + assert.NoError(t, err) + assert.Equal(t, "{}", string(data)) +} + +func TestMap_UnmarshalJSON(t *testing.T) { + t.Parallel() + + jsonData := `{"key1": 1, "key2": 2}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + + assert.Equal(t, 2, m.Len()) + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) + + value, ok = m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 2, value) +} + +func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) { + t.Parallel() + + jsonData := `{}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + assert.Equal(t, 0, m.Len()) +} + +func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + jsonData := `{"key1": 1, "key2":}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.Error(t, err) +} + +func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("existing", 999) + + jsonData := `{"key1": 1, "key2": 2}` + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + + assert.Equal(t, 2, m.Len()) + _, ok := m.Get("existing") + assert.False(t, ok) + + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) +} + +func TestMap_JSONRoundTrip(t *testing.T) { + t.Parallel() + + original := NewMap[string, int]() + original.Set("key1", 1) + original.Set("key2", 2) + original.Set("key3", 3) + + data, err := json.Marshal(original) + assert.NoError(t, err) + + restored := NewMap[string, int]() + err = json.Unmarshal(data, restored) + assert.NoError(t, err) + + assert.Equal(t, original.Len(), restored.Len()) + + for k, v := range original.Seq2() { + restoredValue, ok := restored.Get(k) + assert.True(t, ok) + assert.Equal(t, v, restoredValue) + } +} + +func TestMap_ConcurrentAccess(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numGoroutines = 100 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := range numGoroutines { + go func(id int) { + defer wg.Done() + for j := range numOperations { + key := id*numOperations + j + m.Set(key, key*2) + value, ok := m.Get(key) + assert.True(t, ok) + assert.Equal(t, key*2, value) + } + }(i) + } + + wg.Wait() + + assert.Equal(t, numGoroutines*numOperations, m.Len()) +} + +func TestMap_ConcurrentReadWrite(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numReaders = 50 + const numWriters = 50 + const numOperations = 100 + + for i := range 1000 { + m.Set(i, i) + } + + var wg sync.WaitGroup + wg.Add(numReaders + numWriters) + + for range numReaders { + go func() { + defer wg.Done() + for j := range numOperations { + key := j % 1000 + value, ok := m.Get(key) + if ok { + assert.Equal(t, key, value) + } + _ = m.Len() + } + }() + } + + for i := range numWriters { + go func(id int) { + defer wg.Done() + for j := range numOperations { + key := 1000 + id*numOperations + j + m.Set(key, key) + if j%10 == 0 { + m.Del(key) + } + } + }(i) + } + + wg.Wait() +} + +func TestMap_ConcurrentSeq2(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + for i := range 100 { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numIterators = 10 + + wg.Add(numIterators) + for range numIterators { + go func() { + defer wg.Done() + count := 0 + for k, v := range m.Seq2() { + assert.Equal(t, k*2, v) + count++ + } + assert.Equal(t, 100, count) + }() + } + + wg.Wait() +} + +func TestMap_TypeSafety(t *testing.T) { + t.Parallel() + + stringIntMap := NewMap[string, int]() + stringIntMap.Set("key", 42) + value, ok := stringIntMap.Get("key") + assert.True(t, ok) + assert.Equal(t, 42, value) + + intStringMap := NewMap[int, string]() + intStringMap.Set(42, "value") + strValue, ok := intStringMap.Get(42) + assert.True(t, ok) + assert.Equal(t, "value", strValue) + + structMap := NewMap[string, struct{ Name string }]() + structMap.Set("key", struct{ Name string }{Name: "test"}) + structValue, ok := structMap.Get("key") + assert.True(t, ok) + assert.Equal(t, "test", structValue.Name) +} + +func TestMap_InterfaceCompliance(t *testing.T) { + t.Parallel() + + var _ json.Marshaler = &Map[string, any]{} + var _ json.Unmarshaler = &Map[string, any]{} +} + +func BenchmarkMap_Set(b *testing.B) { + m := NewMap[int, int]() + + for i := 0; b.Loop(); i++ { + m.Set(i, i*2) + } +} + +func BenchmarkMap_Get(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for i := 0; b.Loop(); i++ { + m.Get(i % 1000) + } +} + +func BenchmarkMap_Seq2(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for b.Loop() { + for range m.Seq2() { + } + } +} + +func BenchmarkMap_ConcurrentReadWrite(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + m.Get(i % 1000) + } else { + m.Set(i+1000, i*2) + } + i++ + } + }) +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index e920651d0faeb87da765c4ab67735c1c2d285001..2f76cc7771e3f0383f20b4ef1dffe448e06a253c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -8,9 +8,9 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" + "github.com/charmbracelet/crush/csync" "github.com/charmbracelet/crush/internal/config" fur "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/history" @@ -68,8 +68,7 @@ type agent struct { sessions session.Service messages message.Service - toolsDone atomic.Bool - tools []tools.BaseTool + tools *csync.LazySlice[tools.BaseTool] provider provider.Provider providerID string @@ -168,24 +167,10 @@ func NewAgent( return nil, err } - agent := &agent{ - Broker: pubsub.NewBroker[AgentEvent](), - agentCfg: agentCfg, - provider: agentProvider, - providerID: string(providerCfg.ID), - messages: messages, - sessions: sessions, - titleProvider: titleProvider, - summarizeProvider: summarizeProvider, - summarizeProviderID: string(smallModelProviderCfg.ID), - activeRequests: sync.Map{}, - } - - go func() { + toolFn := func() []tools.BaseTool { slog.Info("Initializing agent tools", "agent", agentCfg.ID) defer func() { slog.Info("Initialized agent tools", "agent", agentCfg.ID) - agent.toolsDone.Store(true) }() cwd := cfg.WorkingDir() @@ -214,8 +199,7 @@ func NewAgent( } if agentCfg.AllowedTools == nil { - agent.tools = allTools - return + return allTools } var filteredTools []tools.BaseTool @@ -224,10 +208,22 @@ func NewAgent( filteredTools = append(filteredTools, tool) } } - agent.tools = filteredTools - }() + return filteredTools + } - return agent, nil + return &agent{ + Broker: pubsub.NewBroker[AgentEvent](), + agentCfg: agentCfg, + provider: agentProvider, + providerID: string(providerCfg.ID), + messages: messages, + sessions: sessions, + titleProvider: titleProvider, + summarizeProvider: summarizeProvider, + summarizeProviderID: string(smallModelProviderCfg.ID), + activeRequests: sync.Map{}, + tools: csync.NewLazySlice(toolFn), + }, nil } func (a *agent) Model() fur.Model { @@ -449,10 +445,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - if !a.toolsDone.Load() { - return message.Message{}, nil, fmt.Errorf("agent is still initializing, please wait a moment and try again") - } - eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter())) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, @@ -501,7 +494,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for _, availableTool := range a.tools { + for availableTool := range a.tools.Iter() { if availableTool.Info().Name == toolCall.Name { tool = availableTool break @@ -911,7 +904,7 @@ func (a *agent) UpdateModel() error { smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig - for _, p := range cfg.Providers { + for _, p := range cfg.Providers.Seq2() { if p.ID == smallModelCfg.Provider { smallModelProviderCfg = p break diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index f7a6dce4baa2c3a2798c30baa6b995f6da72d05b..2a2c47a171a7ac685d644005e61e507a3964389f 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -422,7 +422,7 @@ func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provide func (s *splashCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers[providerID]; ok { + if _, ok := cfg.Providers.Get(providerID); ok { return true } return false diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 86b1b9a3fa0b4b6faa56a927a9011673aa8365af..5f558364eec801d77a250c891a80110e0c9a3b86 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -103,7 +103,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { if err != nil { return util.ReportError(err) } - for providerID, providerConfig := range cfg.Providers { + for providerID, providerConfig := range cfg.Providers.Seq2() { if providerConfig.Disable { continue } @@ -164,7 +164,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } // Check if this provider is configured and not disabled - if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable { + if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable { continue } @@ -174,7 +174,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } section := commands.NewItemSection(name) - if _, ok := cfg.Providers[string(provider.ID)]; ok { + if _, ok := cfg.Providers.Get(string(provider.ID)); ok { section.SetInfo(configured) } modelItems = append(modelItems, section) diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index b28efc6010582a503c34e87ad101832925d8acca..b53388d16f17bbae8612cc66d1525e3e0e616db5 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -357,7 +357,7 @@ func (m *modelDialogCmp) modelTypeRadio() string { func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers[providerID]; ok { + if _, ok := cfg.Providers.Get(providerID); ok { return true } return false