diff --git a/go.mod b/go.mod index 1ab2021a0b2c32a7ce9eeb6cf4827459e8bb27a6..32d9dc28d09cc133801c00d0628d6630922eacf5 100644 --- a/go.mod +++ b/go.mod @@ -42,11 +42,12 @@ require ( github.com/tidwall/sjson v1.2.5 github.com/u-root/u-root v0.14.1-0.20250724181933-b01901710169 github.com/zeebo/xxh3 v1.0.2 - golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 gopkg.in/natefinch/lumberjack.v2 v2.2.1 mvdan.cc/sh/v3 v3.11.0 ) +require golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect + require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.13.0 // indirect diff --git a/internal/app/app.go b/internal/app/app.go index 05022370977b7c2e0ff6ef6911d3ba40e37982e4..50e117ea1ae272156dbd11baa1a5f157a74333f1 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -84,7 +84,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { app.setupEvents() // Initialize LSP clients in the background. - go app.initLSPClients(ctx) + app.initLSPClients(ctx) // TODO: remove the concept of agent config, most likely. if cfg.IsConfigured() { diff --git a/internal/config/config.go b/internal/config/config.go index f75b99e2f34ae184908eceb9df8ad01a1042232c..9709c11a0636d91cb492b7735b63e46e5e843c74 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "context" "fmt" + "log/slog" "net/http" "os" "slices" @@ -10,9 +11,9 @@ import ( "time" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/tidwall/sjson" - "golang.org/x/exp/slog" ) const ( @@ -240,7 +241,7 @@ type Config struct { Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` // The providers that are configured - Providers map[string]ProviderConfig `json:"providers,omitempty"` + Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"` MCP MCPs `json:"mcp,omitempty"` @@ -265,8 +266,8 @@ func (c *Config) WorkingDir() string { } func (c *Config) EnabledProviders() []ProviderConfig { - enabled := make([]ProviderConfig, 0, len(c.Providers)) - for _, p := range c.Providers { + var enabled []ProviderConfig + for _, p := range c.Providers.Seq2() { if !p.Disable { enabled = append(enabled, p) } @@ -280,7 +281,7 @@ func (c *Config) IsConfigured() bool { } func (c *Config) GetModel(provider, model string) *catwalk.Model { - if providerConfig, ok := c.Providers[provider]; ok { + if providerConfig, ok := c.Providers.Get(provider); ok { for _, m := range providerConfig.Models { if m.ID == model { return &m @@ -295,7 +296,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi if !ok { return nil } - if providerConfig, ok := c.Providers[model.Provider]; ok { + if providerConfig, ok := c.Providers.Get(model.Provider); ok { return &providerConfig } return nil @@ -376,14 +377,10 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return fmt.Errorf("failed to save API key to config file: %w", err) } - if c.Providers == nil { - c.Providers = make(map[string]ProviderConfig) - } - - providerConfig, exists := c.Providers[providerID] + providerConfig, exists := c.Providers.Get(providerID) if exists { providerConfig.APIKey = apiKey - c.Providers[providerID] = providerConfig + c.Providers.Set(providerID, providerConfig) return nil } @@ -412,7 +409,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error { return fmt.Errorf("provider with ID %s not found in known providers", providerID) } // Store the updated provider config - c.Providers[providerID] = providerConfig + c.Providers.Set(providerID, providerConfig) return nil } diff --git a/internal/config/init.go b/internal/config/init.go index 827a287718e40e1fc5b9b761293c00799ec5ef3d..ff44d43bb878f579d003c84537fcd970f9e52f9e 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -5,7 +5,6 @@ import ( "os" "path/filepath" "strings" - "sync" "sync/atomic" ) @@ -18,26 +17,20 @@ type ProjectInitFlag struct { } // TODO: we need to remove the global config instance keeping it now just until everything is migrated -var ( - instance atomic.Pointer[Config] - cwd string - once sync.Once // Ensures the initialization happens only once -) +var instance atomic.Pointer[Config] func Init(workingDir string, debug bool) (*Config, error) { - var err error - once.Do(func() { - cwd = workingDir - var cfg *Config - cfg, err = Load(cwd, debug) - instance.Store(cfg) - }) - - return instance.Load(), err + cfg, err := Load(workingDir, debug) + if err != nil { + return nil, err + } + instance.Store(cfg) + return instance.Load(), nil } func Get() *Config { - return instance.Load() + cfg := instance.Load() + return cfg } func ProjectNeedsInitialization() (bool, error) { diff --git a/internal/config/load.go b/internal/config/load.go index bf0fc3d562d6a38544399a095f4efe0a5f75fcd2..98569d41be810dd0b9382c4df56cfb3e9c1c5842 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -4,17 +4,17 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "os" "path/filepath" "runtime" "slices" "strings" - "sync" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/log" - "golang.org/x/exp/slog" ) const catwalkURL = "https://catwalk.charm.sh" @@ -63,7 +63,7 @@ func Load(workingDir string, debug bool) (*Config, error) { ) // Load known providers, this loads the config from catwalk - providers, err := LoadProviders(catwalk.NewWithURL(catwalkURL)) + providers, err := Providers() if err != nil || len(providers) == 0 { return nil, fmt.Errorf("failed to load providers: %w", err) } @@ -77,35 +77,6 @@ func Load(workingDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to configure providers: %w", err) } - // Test provider connections in parallel - var testResults sync.Map - var wg sync.WaitGroup - - for _, p := range cfg.Providers { - if p.Type == catwalk.TypeOpenAI || p.Type == catwalk.TypeAnthropic { - wg.Add(1) - go func(provider ProviderConfig) { - defer wg.Done() - err := provider.TestConnection(cfg.resolver) - testResults.Store(provider.ID, err == nil) - if err != nil { - slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) - } - }(p) - } - } - wg.Wait() - - // Remove failed providers - testResults.Range(func(key, value any) bool { - providerID := key.(string) - passed := value.(bool) - if !passed { - delete(cfg.Providers, providerID) - } - return true - }) - if !cfg.IsConfigured() { slog.Warn("No providers configured") return cfg, nil @@ -122,12 +93,12 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know knownProviderNames := make(map[string]bool) for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true - config, configExists := c.Providers[string(p.ID)] + config, configExists := c.Providers.Get(string(p.ID)) // if the user configured a known provider we need to allow it to override a couple of parameters if configExists { if config.Disable { slog.Debug("Skipping provider due to disable flag", "provider", p.ID) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) continue } if config.BaseURL != "" { @@ -183,7 +154,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if !hasVertexCredentials(env) { if configExists { slog.Warn("Skipping Vertex AI provider due to missing credentials") - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -194,7 +165,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if err != nil || endpoint == "" { if configExists { slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -204,7 +175,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if !hasAWSCredentials(env) { if configExists { slog.Warn("Skipping Bedrock provider due to missing AWS credentials") - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } @@ -219,16 +190,16 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if v == "" || err != nil { if configExists { slog.Warn("Skipping provider due to missing API key", "provider", p.ID) - delete(c.Providers, string(p.ID)) + c.Providers.Del(string(p.ID)) } continue } } - c.Providers[string(p.ID)] = prepared + c.Providers.Set(string(p.ID), prepared) } // validate the custom providers - for id, providerConfig := range c.Providers { + for id, providerConfig := range c.Providers.Seq2() { if knownProviderNames[id] { continue } @@ -245,7 +216,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know if providerConfig.Disable { slog.Debug("Skipping custom provider due to disable flag", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if providerConfig.APIKey == "" { @@ -253,17 +224,17 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } if providerConfig.BaseURL == "" { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if len(providerConfig.Models) == 0 { slog.Warn("Skipping custom provider because the provider has no models", "provider", id) - delete(c.Providers, id) + c.Providers.Del(id) continue } if providerConfig.Type != catwalk.TypeOpenAI { slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) - delete(c.Providers, id) + c.Providers.Del(id) continue } @@ -274,11 +245,11 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know baseURL, err := resolver.ResolveValue(providerConfig.BaseURL) if baseURL == "" || err != nil { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err) - delete(c.Providers, id) + c.Providers.Del(id) continue } - c.Providers[id] = providerConfig + c.Providers.Set(id, providerConfig) } return nil } @@ -298,7 +269,7 @@ func (c *Config) setDefaults(workingDir string) { c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory) } if c.Providers == nil { - c.Providers = make(map[string]ProviderConfig) + c.Providers = csync.NewMap[string, ProviderConfig]() } if c.Models == nil { c.Models = make(map[SelectedModelType]SelectedModel) @@ -317,7 +288,7 @@ func (c *Config) setDefaults(workingDir string) { } func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { - if len(knownProviders) == 0 && len(c.Providers) == 0 { + if len(knownProviders) == 0 && c.Providers.Len() == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return } @@ -325,7 +296,7 @@ func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (large // Use the first provider enabled based on the known providers order // if no provider found that is known use the first provider configured for _, p := range knownProviders { - providerConfig, ok := c.Providers[string(p.ID)] + providerConfig, ok := c.Providers.Get(string(p.ID)) if !ok || providerConfig.Disable { continue } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index a3d224d443b8995b747481871a82a097afa02e1b..5a52426f51ace9ee9e26bb42208511a72009dc3b 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/assert" ) @@ -29,9 +30,10 @@ func TestConfig_LoadFromReaders(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, loadedConfig) - assert.Len(t, loadedConfig.Providers, 1) - assert.Equal(t, "key2", loadedConfig.Providers["openai"].APIKey) - assert.Equal(t, "https://api.openai.com/v2", loadedConfig.Providers["openai"].BaseURL) + assert.Equal(t, 1, loadedConfig.Providers.Len()) + pc, _ := loadedConfig.Providers.Get("openai") + assert.Equal(t, "key2", pc.APIKey) + assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) } func TestConfig_setDefaults(t *testing.T) { @@ -73,10 +75,11 @@ func TestConfig_configureProviders(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, 1, cfg.Providers.Len()) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "$OPENAI_API_KEY", cfg.Providers["openai"].APIKey) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey) } func TestConfig_configureProvidersWithOverride(t *testing.T) { @@ -92,22 +95,21 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ - "openai": { - APIKey: "xyz", - BaseURL: "https://api.openai.com/v2", - Models: []catwalk.Model{ - { - ID: "test-model", - Name: "Updated", - }, - { - ID: "another-model", - }, - }, + Providers: csync.NewMap[string, ProviderConfig](), + } + cfg.Providers.Set("openai", ProviderConfig{ + APIKey: "xyz", + BaseURL: "https://api.openai.com/v2", + Models: []catwalk.Model{ + { + ID: "test-model", + Name: "Updated", + }, + { + ID: "another-model", }, }, - } + }) cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ @@ -116,13 +118,14 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, 1, cfg.Providers.Len()) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey) - assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL) - assert.Len(t, cfg.Providers["openai"].Models, 2) - assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "xyz", pc.APIKey) + assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) + assert.Len(t, pc.Models, 2) + assert.Equal(t, "Updated", pc.Models[0].Name) } func TestConfig_configureProvidersWithNewProvider(t *testing.T) { @@ -138,7 +141,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "xyz", BaseURL: "https://api.someendpoint.com/v2", @@ -148,7 +151,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{ @@ -158,16 +161,17 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Should be to because of the env variable - assert.Len(t, cfg.Providers, 2) + assert.Equal(t, cfg.Providers.Len(), 2) // We want to make sure that we keep the configured API key as a placeholder - assert.Equal(t, "xyz", cfg.Providers["custom"].APIKey) + pc, _ := cfg.Providers.Get("custom") + assert.Equal(t, "xyz", pc.APIKey) // Make sure we set the ID correctly - assert.Equal(t, "custom", cfg.Providers["custom"].ID) - assert.Equal(t, "https://api.someendpoint.com/v2", cfg.Providers["custom"].BaseURL) - assert.Len(t, cfg.Providers["custom"].Models, 1) + assert.Equal(t, "custom", pc.ID) + assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL) + assert.Len(t, pc.Models, 1) - _, ok := cfg.Providers["openai"] + _, ok := cfg.Providers.Get("openai") assert.True(t, ok, "OpenAI provider should still be present") } @@ -192,9 +196,9 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) - bedrockProvider, ok := cfg.Providers["bedrock"] + bedrockProvider, ok := cfg.Providers.Get("bedrock") assert.True(t, ok, "Bedrock provider should be present") assert.Len(t, bedrockProvider.Models, 1) assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID) @@ -219,7 +223,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without credentials - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { @@ -267,9 +271,9 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) - vertexProvider, ok := cfg.Providers["vertexai"] + vertexProvider, ok := cfg.Providers.Get("vertexai") assert.True(t, ok, "VertexAI provider should be present") assert.Len(t, vertexProvider.Models, 1) assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID) @@ -300,7 +304,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without proper credentials - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { @@ -325,7 +329,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) // Provider should not be configured without project - assert.Len(t, cfg.Providers, 0) + assert.Equal(t, cfg.Providers.Len(), 0) } func TestConfig_configureProvidersSetProviderID(t *testing.T) { @@ -348,16 +352,17 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) + assert.Equal(t, cfg.Providers.Len(), 1) // Provider ID should be set - assert.Equal(t, "openai", cfg.Providers["openai"].ID) + pc, _ := cfg.Providers.Get("openai") + assert.Equal(t, "openai", pc.ID) } func TestConfig_EnabledProviders(t *testing.T) { t.Run("all providers enabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -368,7 +373,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: false, }, - }, + }), } enabled := cfg.EnabledProviders() @@ -377,7 +382,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("some providers disabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -388,7 +393,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: true, }, - }, + }), } enabled := cfg.EnabledProviders() @@ -398,7 +403,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("empty providers map", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{}, + Providers: csync.NewMap[string, ProviderConfig](), } enabled := cfg.EnabledProviders() @@ -409,13 +414,13 @@ func TestConfig_EnabledProviders(t *testing.T) { func TestConfig_IsConfigured(t *testing.T) { t.Run("returns true when at least one provider is enabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", Disable: false, }, - }, + }), } assert.True(t, cfg.IsConfigured()) @@ -423,7 +428,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when no providers are configured", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{}, + Providers: csync.NewMap[string, ProviderConfig](), } assert.False(t, cfg.IsConfigured()) @@ -431,7 +436,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when all providers are disabled", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -442,7 +447,7 @@ func TestConfig_IsConfigured(t *testing.T) { APIKey: "key2", Disable: true, }, - }, + }), } assert.False(t, cfg.IsConfigured()) @@ -462,11 +467,11 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { Disable: true, }, - }, + }), } cfg.setDefaults("/tmp") @@ -478,15 +483,15 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { assert.NoError(t, err) // Provider should be removed from config when disabled - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("openai") assert.False(t, exists) } func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { BaseURL: "https://api.custom.com/v1", Models: []catwalk.Model{{ @@ -496,7 +501,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "openai": { APIKey: "$MISSING", }, - }, + }), } cfg.setDefaults("/tmp") @@ -505,21 +510,21 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 1) + _, exists := cfg.Providers.Get("custom") assert.True(t, exists) }) t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", Models: []catwalk.Model{{ ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -528,20 +533,20 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("custom provider with no models is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []catwalk.Model{}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -550,14 +555,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("custom provider with unsupported type is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -566,7 +571,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -575,14 +580,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) t.Run("valid custom provider is kept and ID is set", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -591,7 +596,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -600,8 +605,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - customProvider, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 1) + customProvider, exists := cfg.Providers.Get("custom") assert.True(t, exists) assert.Equal(t, "custom", customProvider.ID) assert.Equal(t, "test-key", customProvider.APIKey) @@ -610,7 +615,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("disabled custom provider is removed", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -620,7 +625,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -629,8 +634,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["custom"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("custom") assert.False(t, exists) }) } @@ -649,11 +654,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "vertexai": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -664,8 +669,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["vertexai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("vertexai") assert.False(t, exists) }) @@ -682,11 +687,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "bedrock": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -695,8 +700,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["bedrock"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("bedrock") assert.False(t, exists) }) @@ -713,11 +718,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { BaseURL: "custom-url", }, - }, + }), } cfg.setDefaults("/tmp") @@ -726,8 +731,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 0) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 0) + _, exists := cfg.Providers.Get("openai") assert.False(t, exists) }) @@ -744,11 +749,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "openai": { APIKey: "test-key", }, - }, + }), } cfg.setDefaults("/tmp") @@ -759,8 +764,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, knownProviders) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - _, exists := cfg.Providers["openai"] + assert.Equal(t, cfg.Providers.Len(), 1) + _, exists := cfg.Providers.Get("openai") assert.True(t, exists) }) } @@ -883,7 +888,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -894,7 +899,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) @@ -932,13 +937,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []catwalk.Model{}, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) @@ -969,7 +974,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -980,7 +985,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }, + }), } cfg.setDefaults("/tmp") env := env.NewFromMap(map[string]string{}) diff --git a/internal/config/provider.go b/internal/config/provider.go index 9b5cdf608c36c36d62faffdb19e84c74013a1884..ba02f9d8e1bc0f2ec58c2ed3e736a87e1d7a614b 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -2,10 +2,13 @@ package config import ( "encoding/json" + "fmt" + "log/slog" "os" "path/filepath" "runtime" "sync" + "time" "github.com/charmbracelet/catwalk/pkg/catwalk" ) @@ -41,57 +44,88 @@ func providerCacheFileData() string { } func saveProvidersInCache(path string, providers []catwalk.Provider) error { - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err + slog.Info("Saving cached provider data", "path", path) + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("failed to create directory for provider cache: %w", err) } data, err := json.MarshalIndent(providers, "", " ") if err != nil { - return err + return fmt.Errorf("failed to marshal provider data: %w", err) } - return os.WriteFile(path, data, 0o644) + if err := os.WriteFile(path, data, 0o644); err != nil { + return fmt.Errorf("failed to write provider data to cache: %w", err) + } + return nil } func loadProvidersFromCache(path string) ([]catwalk.Provider, error) { data, err := os.ReadFile(path) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read provider cache file: %w", err) } var providers []catwalk.Provider - err = json.Unmarshal(data, &providers) - return providers, err -} - -func loadProviders(path string, client ProviderClient) ([]catwalk.Provider, error) { - providers, err := client.GetProviders() - if err != nil { - fallbackToCache, err := loadProvidersFromCache(path) - if err != nil { - return nil, err - } - providers = fallbackToCache - } else { - if err := saveProvidersInCache(path, providerList); err != nil { - return nil, err - } + if err := json.Unmarshal(data, &providers); err != nil { + return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err) } return providers, nil } func Providers() ([]catwalk.Provider, error) { - return LoadProviders(catwalk.NewWithURL(catwalkURL)) + client := catwalk.NewWithURL(catwalkURL) + path := providerCacheFileData() + return loadProvidersOnce(client, path) } -func LoadProviders(client ProviderClient) ([]catwalk.Provider, error) { +func loadProvidersOnce(client ProviderClient, path string) ([]catwalk.Provider, error) { var err error providerOnce.Do(func() { - providerList, err = loadProviders(providerCacheFileData(), client) + providerList, err = loadProviders(client, path) }) if err != nil { return nil, err } return providerList, nil } + +func loadProviders(client ProviderClient, path string) (providerList []catwalk.Provider, err error) { + // if cache is not stale, load from it + stale, exists := isCacheStale(path) + if !stale { + slog.Info("Using cached provider data", "path", path) + providerList, err = loadProvidersFromCache(path) + if len(providerList) > 0 && err == nil { + go func() { + slog.Info("Updating provider cache in background") + updated, uerr := client.GetProviders() + if len(updated) > 0 && uerr == nil { + _ = saveProvidersInCache(path, updated) + } + }() + return + } + } + + slog.Info("Getting live provider data") + providerList, err = client.GetProviders() + if len(providerList) > 0 && err == nil { + err = saveProvidersInCache(path, providerList) + return + } + if !exists { + err = fmt.Errorf("failed to load providers") + return + } + providerList, err = loadProvidersFromCache(path) + return +} + +func isCacheStale(path string) (stale, exists bool) { + info, err := os.Stat(path) + if err != nil { + return true, false + } + return time.Since(info.ModTime()) > 24*time.Hour, true +} diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cb71cabfa5a01cb16b6ef2b6708d1780e31951a9 --- /dev/null +++ b/internal/config/provider_empty_test.go @@ -0,0 +1,47 @@ +package config + +import ( + "encoding/json" + "os" + "testing" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type emptyProviderClient struct{} + +func (m *emptyProviderClient) GetProviders() ([]catwalk.Provider, error) { + return []catwalk.Provider{}, nil +} + +func TestProvider_loadProvidersEmptyResult(t *testing.T) { + client := &emptyProviderClient{} + tmpPath := t.TempDir() + "/providers.json" + + providers, err := loadProviders(client, tmpPath) + require.EqualError(t, err, "failed to load providers") + require.Empty(t, providers) + require.Len(t, providers, 0) + + // Check that no cache file was created for empty results + require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results") +} + +func TestProvider_loadProvidersEmptyCache(t *testing.T) { + client := &mockProviderClient{shouldFail: false} + tmpPath := t.TempDir() + "/providers.json" + + // Create an empty cache file + emptyProviders := []catwalk.Provider{} + data, err := json.Marshal(emptyProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) + + // Should refresh and get real providers instead of using empty cache + providers, err := loadProviders(client, tmpPath) + require.NoError(t, err) + require.NotNil(t, providers) + require.Len(t, providers, 1) + require.Equal(t, "Mock", providers[0].Name) +} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index a63099ee27c96abb97d2781b186bb5aa9e060396..e6a1f331716d88285ef4c9929a23a474ed3597a0 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { func TestProvider_loadProvidersNoIssues(t *testing.T) { client := &mockProviderClient{shouldFail: false} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.NoError(t, err) assert.NotNil(t, providers) assert.Len(t, providers, 1) @@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { if err != nil { t.Fatalf("Failed to write old providers to file: %v", err) } - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.NoError(t, err) assert.NotNil(t, providers) assert.Len(t, providers, 1) @@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.Error(t, err) assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") } diff --git a/internal/csync/maps.go b/internal/csync/maps.go new file mode 100644 index 0000000000000000000000000000000000000000..45e426630a4e50b45125d41dcca54d4e183b4f6f --- /dev/null +++ b/internal/csync/maps.go @@ -0,0 +1,92 @@ +package csync + +import ( + "encoding/json" + "iter" + "maps" + "sync" +) + +// Map is a concurrent map implementation that provides thread-safe access. +type Map[K comparable, V any] struct { + inner map[K]V + mu sync.RWMutex +} + +// NewMap creates a new thread-safe map with the specified key and value types. +func NewMap[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + inner: make(map[K]V), + } +} + +// NewMapFrom creates a new thread-safe map from an existing map. +func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] { + return &Map[K, V]{ + inner: m, + } +} + +// Set sets the value for the specified key in the map. +func (m *Map[K, V]) Set(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + m.inner[key] = value +} + +// Del deletes the specified key from the map. +func (m *Map[K, V]) Del(key K) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.inner, key) +} + +// Get gets the value for the specified key from the map. +func (m *Map[K, V]) Get(key K) (V, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + v, ok := m.inner[key] + return v, ok +} + +// Len returns the number of items in the map. +func (m *Map[K, V]) Len() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.inner) +} + +// Seq2 returns an iter.Seq2 that yields key-value pairs from the map. +func (m *Map[K, V]) Seq2() iter.Seq2[K, V] { + dst := make(map[K]V) + m.mu.RLock() + maps.Copy(dst, m.inner) + m.mu.RUnlock() + return func(yield func(K, V) bool) { + for k, v := range dst { + if !yield(k, v) { + return + } + } + } +} + +var ( + _ json.Unmarshaler = &Map[string, any]{} + _ json.Marshaler = &Map[string, any]{} +) + +// UnmarshalJSON implements json.Unmarshaler. +func (m *Map[K, V]) UnmarshalJSON(data []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + m.inner = make(map[K]V) + return json.Unmarshal(data, &m.inner) +} + +// MarshalJSON implements json.Marshaler. +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return json.Marshal(m.inner) +} diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go new file mode 100644 index 0000000000000000000000000000000000000000..73e6f1db245231e9fad82103366d96a326acc4f6 --- /dev/null +++ b/internal/csync/maps_test.go @@ -0,0 +1,452 @@ +package csync + +import ( + "encoding/json" + "maps" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + assert.NotNil(t, m) + assert.NotNil(t, m.inner) + assert.Equal(t, 0, m.Len()) +} + +func TestNewMapFrom(t *testing.T) { + t.Parallel() + + original := map[string]int{ + "key1": 1, + "key2": 2, + } + + m := NewMapFrom(original) + assert.NotNil(t, m) + assert.Equal(t, original, m.inner) + assert.Equal(t, 2, m.Len()) + + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) +} + +func TestMap_Set(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + m.Set("key1", 42) + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, 1, m.Len()) + + m.Set("key1", 100) + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 100, value) + assert.Equal(t, 1, m.Len()) +} + +func TestMap_Get(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + value, ok := m.Get("nonexistent") + assert.False(t, ok) + assert.Equal(t, 0, value) + + m.Set("key1", 42) + value, ok = m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 42, value) +} + +func TestMap_Del(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 42) + m.Set("key2", 100) + + assert.Equal(t, 2, m.Len()) + + m.Del("key1") + _, ok := m.Get("key1") + assert.False(t, ok) + assert.Equal(t, 1, m.Len()) + + value, ok := m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 100, value) + + m.Del("nonexistent") + assert.Equal(t, 1, m.Len()) +} + +func TestMap_Len(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + assert.Equal(t, 0, m.Len()) + + m.Set("key1", 1) + assert.Equal(t, 1, m.Len()) + + m.Set("key2", 2) + assert.Equal(t, 2, m.Len()) + + m.Del("key1") + assert.Equal(t, 1, m.Len()) + + m.Del("key2") + assert.Equal(t, 0, m.Len()) +} + +func TestMap_Seq2(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + collected := maps.Collect(m.Seq2()) + + assert.Equal(t, 3, len(collected)) + assert.Equal(t, 1, collected["key1"]) + assert.Equal(t, 2, collected["key2"]) + assert.Equal(t, 3, collected["key3"]) +} + +func TestMap_Seq2_EarlyReturn(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + m.Set("key3", 3) + + count := 0 + for range m.Seq2() { + count++ + if count == 2 { + break + } + } + + assert.Equal(t, 2, count) +} + +func TestMap_Seq2_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + count := 0 + for range m.Seq2() { + count++ + } + + assert.Equal(t, 0, count) +} + +func TestMap_MarshalJSON(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("key1", 1) + m.Set("key2", 2) + + data, err := json.Marshal(m) + assert.NoError(t, err) + + result := &Map[string, int]{} + err = json.Unmarshal(data, result) + assert.NoError(t, err) + assert.Equal(t, 2, result.Len()) + v1, _ := result.Get("key1") + v2, _ := result.Get("key2") + assert.Equal(t, 1, v1) + assert.Equal(t, 2, v2) +} + +func TestMap_MarshalJSON_EmptyMap(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + + data, err := json.Marshal(m) + assert.NoError(t, err) + assert.Equal(t, "{}", string(data)) +} + +func TestMap_UnmarshalJSON(t *testing.T) { + t.Parallel() + + jsonData := `{"key1": 1, "key2": 2}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + + assert.Equal(t, 2, m.Len()) + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) + + value, ok = m.Get("key2") + assert.True(t, ok) + assert.Equal(t, 2, value) +} + +func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) { + t.Parallel() + + jsonData := `{}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + assert.Equal(t, 0, m.Len()) +} + +func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + jsonData := `{"key1": 1, "key2":}` + + m := NewMap[string, int]() + err := json.Unmarshal([]byte(jsonData), m) + assert.Error(t, err) +} + +func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) { + t.Parallel() + + m := NewMap[string, int]() + m.Set("existing", 999) + + jsonData := `{"key1": 1, "key2": 2}` + err := json.Unmarshal([]byte(jsonData), m) + assert.NoError(t, err) + + assert.Equal(t, 2, m.Len()) + _, ok := m.Get("existing") + assert.False(t, ok) + + value, ok := m.Get("key1") + assert.True(t, ok) + assert.Equal(t, 1, value) +} + +func TestMap_JSONRoundTrip(t *testing.T) { + t.Parallel() + + original := NewMap[string, int]() + original.Set("key1", 1) + original.Set("key2", 2) + original.Set("key3", 3) + + data, err := json.Marshal(original) + assert.NoError(t, err) + + restored := NewMap[string, int]() + err = json.Unmarshal(data, restored) + assert.NoError(t, err) + + assert.Equal(t, original.Len(), restored.Len()) + + for k, v := range original.Seq2() { + restoredValue, ok := restored.Get(k) + assert.True(t, ok) + assert.Equal(t, v, restoredValue) + } +} + +func TestMap_ConcurrentAccess(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numGoroutines = 100 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := range numGoroutines { + go func(id int) { + defer wg.Done() + for j := range numOperations { + key := id*numOperations + j + m.Set(key, key*2) + value, ok := m.Get(key) + assert.True(t, ok) + assert.Equal(t, key*2, value) + } + }(i) + } + + wg.Wait() + + assert.Equal(t, numGoroutines*numOperations, m.Len()) +} + +func TestMap_ConcurrentReadWrite(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + const numReaders = 50 + const numWriters = 50 + const numOperations = 100 + + for i := range 1000 { + m.Set(i, i) + } + + var wg sync.WaitGroup + wg.Add(numReaders + numWriters) + + for range numReaders { + go func() { + defer wg.Done() + for j := range numOperations { + key := j % 1000 + value, ok := m.Get(key) + if ok { + assert.Equal(t, key, value) + } + _ = m.Len() + } + }() + } + + for i := range numWriters { + go func(id int) { + defer wg.Done() + for j := range numOperations { + key := 1000 + id*numOperations + j + m.Set(key, key) + if j%10 == 0 { + m.Del(key) + } + } + }(i) + } + + wg.Wait() +} + +func TestMap_ConcurrentSeq2(t *testing.T) { + t.Parallel() + + m := NewMap[int, int]() + for i := range 100 { + m.Set(i, i*2) + } + + var wg sync.WaitGroup + const numIterators = 10 + + wg.Add(numIterators) + for range numIterators { + go func() { + defer wg.Done() + count := 0 + for k, v := range m.Seq2() { + assert.Equal(t, k*2, v) + count++ + } + assert.Equal(t, 100, count) + }() + } + + wg.Wait() +} + +func TestMap_TypeSafety(t *testing.T) { + t.Parallel() + + stringIntMap := NewMap[string, int]() + stringIntMap.Set("key", 42) + value, ok := stringIntMap.Get("key") + assert.True(t, ok) + assert.Equal(t, 42, value) + + intStringMap := NewMap[int, string]() + intStringMap.Set(42, "value") + strValue, ok := intStringMap.Get(42) + assert.True(t, ok) + assert.Equal(t, "value", strValue) + + structMap := NewMap[string, struct{ Name string }]() + structMap.Set("key", struct{ Name string }{Name: "test"}) + structValue, ok := structMap.Get("key") + assert.True(t, ok) + assert.Equal(t, "test", structValue.Name) +} + +func TestMap_InterfaceCompliance(t *testing.T) { + t.Parallel() + + var _ json.Marshaler = &Map[string, any]{} + var _ json.Unmarshaler = &Map[string, any]{} +} + +func BenchmarkMap_Set(b *testing.B) { + m := NewMap[int, int]() + + for i := 0; b.Loop(); i++ { + m.Set(i, i*2) + } +} + +func BenchmarkMap_Get(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for i := 0; b.Loop(); i++ { + m.Get(i % 1000) + } +} + +func BenchmarkMap_Seq2(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + for b.Loop() { + for range m.Seq2() { + } + } +} + +func BenchmarkMap_ConcurrentReadWrite(b *testing.B) { + m := NewMap[int, int]() + for i := range 1000 { + m.Set(i, i*2) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + m.Get(i % 1000) + } else { + m.Set(i+1000, i*2) + } + i++ + } + }) +} diff --git a/internal/csync/slices.go b/internal/csync/slices.go new file mode 100644 index 0000000000000000000000000000000000000000..be723655079ccc6b07f55c3237b706a17bb14d40 --- /dev/null +++ b/internal/csync/slices.go @@ -0,0 +1,36 @@ +package csync + +import ( + "iter" + "sync" +) + +// LazySlice is a thread-safe lazy-loaded slice. +type LazySlice[K any] struct { + inner []K + wg sync.WaitGroup +} + +// NewLazySlice creates a new slice and runs the [load] function in a goroutine +// to populate it. +func NewLazySlice[K any](load func() []K) *LazySlice[K] { + s := &LazySlice[K]{} + s.wg.Add(1) + go func() { + s.inner = load() + s.wg.Done() + }() + return s +} + +// Seq returns an iterator that yields elements from the slice. +func (s *LazySlice[K]) Seq() iter.Seq[K] { + s.wg.Wait() + return func(yield func(K) bool) { + for _, v := range s.inner { + if !yield(v) { + return + } + } + } +} diff --git a/internal/csync/slices_test.go b/internal/csync/slices_test.go new file mode 100644 index 0000000000000000000000000000000000000000..731cb96f55dd24cae74f55c0ef8e97ebd28aacaa --- /dev/null +++ b/internal/csync/slices_test.go @@ -0,0 +1,87 @@ +package csync + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLazySlice_Seq(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c"} + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Seq() { + result = append(result, v) + } + + assert.Equal(t, data, result) +} + +func TestLazySlice_SeqWaitsForLoading(t *testing.T) { + t.Parallel() + + var loaded atomic.Bool + data := []string{"x", "y", "z"} + + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(100 * time.Millisecond) + loaded.Store(true) + return data + }) + + assert.False(t, loaded.Load(), "should not be loaded immediately") + + var result []string + for v := range s.Seq() { + result = append(result, v) + } + + assert.True(t, loaded.Load(), "should be loaded after Seq") + assert.Equal(t, data, result) +} + +func TestLazySlice_EmptySlice(t *testing.T) { + t.Parallel() + + s := NewLazySlice(func() []string { + return []string{} + }) + + var result []string + for v := range s.Seq() { + result = append(result, v) + } + + assert.Empty(t, result) +} + +func TestLazySlice_EarlyBreak(t *testing.T) { + t.Parallel() + + data := []string{"a", "b", "c", "d", "e"} + s := NewLazySlice(func() []string { + // TODO: use synctest when new Go is out. + time.Sleep(10 * time.Millisecond) // Small delay to ensure loading happens + return data + }) + + var result []string + for v := range s.Seq() { + result = append(result, v) + if len(result) == 2 { + break + } + } + + assert.Equal(t, []string{"a", "b"}, result) +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b16905f7736c3914dea9b25d8b53c661284d2faf..2c3876ccac9ed028b1714ed96b0c6de0cce007c9 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -12,6 +12,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/llm/provider" @@ -68,7 +69,8 @@ type agent struct { sessions session.Service messages message.Service - tools []tools.BaseTool + tools *csync.LazySlice[tools.BaseTool] + provider provider.Provider providerID string @@ -95,25 +97,8 @@ func NewAgent( ) (Service, error) { ctx := context.Background() cfg := config.Get() - otherTools := GetMCPTools(ctx, permissions, cfg) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - - cwd := cfg.WorkingDir() - allTools := []tools.BaseTool{ - tools.NewBashTool(permissions, cwd), - tools.NewDownloadTool(permissions, cwd), - tools.NewEditTool(lspClients, permissions, history, cwd), - tools.NewFetchTool(permissions, cwd), - tools.NewGlobTool(cwd), - tools.NewGrepTool(cwd), - tools.NewLsTool(cwd), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients, cwd), - tools.NewWriteTool(lspClients, permissions, history, cwd), - } + var agentTool tools.BaseTool if agentCfg.ID == "coder" { taskAgentCfg := config.Get().Agents["task"] if taskAgentCfg.ID == "" { @@ -124,17 +109,9 @@ func NewAgent( return nil, fmt.Errorf("failed to create task agent: %w", err) } - allTools = append( - allTools, - NewAgentTool( - taskAgent, - sessions, - messages, - ), - ) + agentTool = NewAgentTool(taskAgent, sessions, messages) } - allTools = append(allTools, otherTools...) providerCfg := config.Get().GetProviderForModel(agentCfg.Model) if providerCfg == nil { return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) @@ -191,32 +168,63 @@ func NewAgent( return nil, err } - agentTools := []tools.BaseTool{} - if agentCfg.AllowedTools == nil { - agentTools = allTools - } else { + toolFn := func() []tools.BaseTool { + slog.Info("Initializing agent tools", "agent", agentCfg.ID) + defer func() { + slog.Info("Initialized agent tools", "agent", agentCfg.ID) + }() + + cwd := cfg.WorkingDir() + allTools := []tools.BaseTool{ + tools.NewBashTool(permissions, cwd), + tools.NewDownloadTool(permissions, cwd), + tools.NewEditTool(lspClients, permissions, history, cwd), + tools.NewFetchTool(permissions, cwd), + tools.NewGlobTool(cwd), + tools.NewGrepTool(cwd), + tools.NewLsTool(cwd), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients, cwd), + tools.NewWriteTool(lspClients, permissions, history, cwd), + } + + mcpTools := GetMCPTools(ctx, permissions, cfg) + allTools = append(allTools, mcpTools...) + + if len(lspClients) > 0 { + allTools = append(allTools, tools.NewDiagnosticsTool(lspClients)) + } + + if agentTool != nil { + allTools = append(allTools, agentTool) + } + + if agentCfg.AllowedTools == nil { + return allTools + } + + var filteredTools []tools.BaseTool for _, tool := range allTools { if slices.Contains(agentCfg.AllowedTools, tool.Name()) { - agentTools = append(agentTools, tool) + filteredTools = append(filteredTools, tool) } } + return filteredTools } - agent := &agent{ + return &agent{ Broker: pubsub.NewBroker[AgentEvent](), agentCfg: agentCfg, provider: agentProvider, providerID: string(providerCfg.ID), messages: messages, sessions: sessions, - tools: agentTools, titleProvider: titleProvider, summarizeProvider: summarizeProvider, summarizeProviderID: string(smallModelProviderCfg.ID), activeRequests: sync.Map{}, - } - - return agent, nil + tools: csync.NewLazySlice(toolFn), + }, nil } func (a *agent) Model() catwalk.Model { @@ -284,7 +292,7 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str Parts: parts, }, }, - make([]tools.BaseTool, 0), + nil, ) var finalResponse *provider.ProviderResponse @@ -438,7 +446,7 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq())) assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, @@ -487,7 +495,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for _, availableTool := range a.tools { + for availableTool := range a.tools.Seq() { if availableTool.Info().Name == toolCall.Name { tool = availableTool break @@ -737,7 +745,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { response := a.summarizeProvider.StreamResponse( summarizeCtx, msgsWithPrompt, - make([]tools.BaseTool, 0), + nil, ) var finalResponse *provider.ProviderResponse for r := range response { @@ -899,7 +907,7 @@ func (a *agent) UpdateModel() error { smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig - for _, p := range cfg.Providers { + for _, p := range cfg.Providers.Seq2() { if p.ID == smallModelCfg.Provider { smallModelProviderCfg = p break diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 4fa0cff4d17c28da16528d33ff54e2a905521387..b2d1da11148e74362e7b529b9ec78dc1810d0f0d 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -188,9 +188,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, } - if len(tools) > 0 { - config.Tools = g.convertTools(tools) - } + config.Tools = g.convertTools(tools) chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 @@ -290,9 +288,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, }, } - if len(tools) > 0 { - config.Tools = g.convertTools(tools) - } + config.Tools = g.convertTools(tools) chat, _ := g.client.Chats.Create(ctx, model.ID, config, history) attempts := 0 diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go index 26c4d85ae35bbf4681719a12b568befccd8012af..ef79803c8a8aa1ee3fe6cb7de8bc8fa86f26c03c 100644 --- a/internal/llm/provider/openai_test.go +++ b/internal/llm/provider/openai_test.go @@ -11,7 +11,6 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" @@ -79,7 +78,7 @@ func TestOpenAIClientStreamChoices(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - eventsChan := client.stream(ctx, messages, []tools.BaseTool{}) + eventsChan := client.stream(ctx, messages, nil) // Collect events - this will panic without the bounds check for event := range eventsChan { diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 3aa25dfe13c1cf24de1e2fea3fa651bec1b07eb3..4c139d7ea4236feed33998d41535a68842778b1e 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -425,7 +425,7 @@ func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk. func (s *splashCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers[providerID]; ok { + if _, ok := cfg.Providers.Get(providerID); ok { return true } return false diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 13051067413379b7b80968ca4d8eec4bc354d893..5a36ab736351f2c92154da997f01ba7360470d8a 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -103,7 +103,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { if err != nil { return util.ReportError(err) } - for providerID, providerConfig := range cfg.Providers { + for providerID, providerConfig := range cfg.Providers.Seq2() { if providerConfig.Disable { continue } @@ -164,7 +164,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } // Check if this provider is configured and not disabled - if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable { + if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable { continue } @@ -174,7 +174,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { } section := commands.NewItemSection(name) - if _, ok := cfg.Providers[string(provider.ID)]; ok { + if _, ok := cfg.Providers.Get(string(provider.ID)); ok { section.SetInfo(configured) } modelItems = append(modelItems, section) diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index eb0ed9eebcb5ebce41eff33ab09f7c0e5b995bde..795e2585760391bcd711491533a156f9b2c810ba 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -357,7 +357,7 @@ func (m *modelDialogCmp) modelTypeRadio() string { func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { cfg := config.Get() - if _, ok := cfg.Providers[providerID]; ok { + if _, ok := cfg.Providers.Get(providerID); ok { return true } return false