From 726840efa887b29ab66a6ed4497f6a2c43f407c8 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 6 Feb 2026 13:17:07 +0100 Subject: [PATCH] refactor: change Config.Providers from csync.Map to plain map MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace *csync.Map[string, ProviderConfig] with map[string]ProviderConfig on Config. Update all .Get/.Set/.Del/.Len /.Seq2/.Copy calls to plain map operations across the codebase. Remove csync dependency from config package. 🐾 Generated with Crush Assisted-by: Claude Opus 4.6 via Crush --- internal/agent/agent_tool.go | 2 +- internal/agent/agentic_fetch_tool.go | 2 +- internal/agent/coordinator.go | 12 +- internal/app/app.go | 3 +- internal/cmd/models.go | 2 +- internal/config/config.go | 9 +- internal/config/load.go | 35 +++-- internal/config/load_test.go | 215 +++++++++++++-------------- internal/config/service.go | 10 +- internal/ui/chat/messages.go | 2 +- internal/ui/dialog/models.go | 6 +- internal/ui/model/sidebar.go | 2 +- internal/ui/model/ui.go | 2 +- 13 files changed, 150 insertions(+), 152 deletions(-) diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index ceda8a093a8925a3de67141eacef4e453a842d75..64e95f5f403dad02502864cd39d1209b252c0621 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -67,7 +67,7 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) maxTokens = model.ModelCfg.MaxTokens } - providerCfg, ok := c.cfgSvc.Config().Providers.Get(model.ModelCfg.Provider) + providerCfg, ok := c.cfgSvc.Config().Providers[model.ModelCfg.Provider] if !ok { return fantasy.ToolResponse{}, errors.New("model provider not configured") } diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index 0c27c06e094f02fe016b4e8fcb0b9c16a836514a..4a8a4aac74ebd0aea83b42ed8e3953529550f622 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -156,7 +156,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err) } - smallProviderCfg, ok := c.cfgSvc.Config().Providers.Get(small.ModelCfg.Provider) + smallProviderCfg, ok := c.cfgSvc.Config().Providers[small.ModelCfg.Provider] if !ok { return fantasy.ToolResponse{}, errors.New("small model provider not configured") } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 86801cf2081a70434ee3a45d2f15cb042a1314ba..dfc88f490dc7f594f5d859b7b5b608518b47d9bc 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -147,7 +147,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments = filteredAttachments } - providerCfg, ok := c.cfgSvc.Config().Providers.Get(model.ModelCfg.Provider) + providerCfg, ok := c.cfgSvc.Config().Providers[model.ModelCfg.Provider] if !ok { return nil, errors.New("model provider not configured") } @@ -360,7 +360,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age return nil, err } - largeProviderCfg, _ := c.cfgSvc.Config().Providers.Get(large.ModelCfg.Provider) + largeProviderCfg, _ := c.cfgSvc.Config().Providers[large.ModelCfg.Provider] result := NewSessionAgent(SessionAgentOptions{ large, small, @@ -488,7 +488,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Mo return Model{}, Model{}, errors.New("small model not selected") } - largeProviderCfg, ok := c.cfgSvc.Config().Providers.Get(largeModelCfg.Provider) + largeProviderCfg, ok := c.cfgSvc.Config().Providers[largeModelCfg.Provider] if !ok { return Model{}, Model{}, errors.New("large model provider not configured") } @@ -498,7 +498,7 @@ func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Mo return Model{}, Model{}, err } - smallProviderCfg, ok := c.cfgSvc.Config().Providers.Get(smallModelCfg.Provider) + smallProviderCfg, ok := c.cfgSvc.Config().Providers[smallModelCfg.Provider] if !ok { return Model{}, Model{}, errors.New("large model provider not configured") } @@ -881,7 +881,7 @@ func (c *coordinator) QueuedPromptsList(sessionID string) []string { } func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { - providerCfg, ok := c.cfgSvc.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider) + providerCfg, ok := c.cfgSvc.Config().Providers[c.currentAgent.Model().ModelCfg.Provider] if !ok { return errors.New("model provider not configured") } @@ -912,7 +912,7 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con } providerCfg.APIKey = newAPIKey - c.cfgSvc.Config().Providers.Set(providerCfg.ID, providerCfg) + c.cfgSvc.Config().Providers[providerCfg.ID] = providerCfg if err := c.UpdateModels(ctx); err != nil { return err diff --git a/internal/app/app.go b/internal/app/app.go index 3412713e74df786f371e98ef470f0ab40719461d..08cf8703ffb83d8c1cb7a26a287de2d0c45bc8fc 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log/slog" + "maps" "os" "strings" "sync" @@ -321,7 +322,7 @@ func (app *App) UpdateAgentModel(ctx context.Context) error { // If largeModel is provided but smallModel is not, the small model defaults to // the provider's default small model. func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, smallModel string) error { - providers := app.configService.Config().Providers.Copy() + providers := maps.Clone(app.configService.Config().Providers) largeMatches, smallMatches, err := findModels(providers, largeModel, smallModel) if err != nil { diff --git a/internal/cmd/models.go b/internal/cmd/models.go index 3b2c4f7443f5389171e69c85b96e08db3b4a0255..bdfee5e1b61186ef3b09e0ef4e381ecdd4b6c5e9 100644 --- a/internal/cmd/models.go +++ b/internal/cmd/models.go @@ -55,7 +55,7 @@ crush models gpt5`, var providerIDs []string providerModels := make(map[string][]string) - for providerID, provider := range cfg.Config().Providers.Seq2() { + for providerID, provider := range cfg.Config().Providers { if provider.Disable { continue } diff --git a/internal/config/config.go b/internal/config/config.go index a386f2d39cdc0376765db4b772000efddebb2a6e..7995139ca9e0002955818978d9beac7609b1e0bf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,7 +12,6 @@ import ( "time" "charm.land/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" @@ -366,7 +365,7 @@ type Config struct { RecentModels map[SelectedModelType][]SelectedModel `json:"recent_models,omitempty" jsonschema:"-"` // The providers that are configured - Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty" jsonschema:"description=AI provider configurations"` + Providers map[string]ProviderConfig `json:"providers,omitempty" jsonschema:"description=AI provider configurations"` MCP MCPs `json:"mcp,omitempty" jsonschema:"description=Model Context Protocol server configurations"` @@ -383,7 +382,7 @@ type Config struct { func (c *Config) EnabledProviders() []ProviderConfig { var enabled []ProviderConfig - for p := range c.Providers.Seq() { + for _, p := range c.Providers { if !p.Disable { enabled = append(enabled, p) } @@ -397,7 +396,7 @@ func (c *Config) IsConfigured() bool { } func (c *Config) GetModel(provider, model string) *catwalk.Model { - if providerConfig, ok := c.Providers.Get(provider); ok { + if providerConfig, ok := c.Providers[provider]; ok { for _, m := range providerConfig.Models { if m.ID == model { return &m @@ -412,7 +411,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi if !ok { return nil } - if providerConfig, ok := c.Providers.Get(model.Provider); ok { + if providerConfig, ok := c.Providers[model.Provider]; ok { return &providerConfig } return nil diff --git a/internal/config/load.go b/internal/config/load.go index adb23c2495a94e4afe5fc8eb47712462cad03eff..22442f1558ea1bd7cb4f8d2a61d81fc5f8fed375 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -18,7 +18,6 @@ import ( "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent/hyper" - "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/home" @@ -143,7 +142,7 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true - config, configExists := c.Providers.Get(string(p.ID)) + config, configExists := c.Providers[string(p.ID)] // if the user configured a known provider we need to allow it to override a couple of parameters if configExists { if config.BaseURL != "" { @@ -216,7 +215,7 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil: // Claude Code subscription is not supported anymore. Remove to show onboarding. s.RemoveConfigField("providers.anthropic") - c.Providers.Del(string(p.ID)) + delete(c.Providers, string(p.ID)) continue case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil: prepared.SetupGitHubCopilot() @@ -228,7 +227,7 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno if !hasVertexCredentials(env) { if configExists { slog.Warn("Skipping Vertex AI provider due to missing credentials") - c.Providers.Del(string(p.ID)) + delete(c.Providers, string(p.ID)) } continue } @@ -239,7 +238,7 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno if err != nil || endpoint == "" { if configExists { slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err) - c.Providers.Del(string(p.ID)) + delete(c.Providers, string(p.ID)) } continue } @@ -249,7 +248,7 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno if !hasAWSCredentials(env) { if configExists { slog.Warn("Skipping Bedrock provider due to missing AWS credentials") - c.Providers.Del(string(p.ID)) + delete(c.Providers, string(p.ID)) } continue } @@ -268,16 +267,16 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno if v == "" || err != nil { if configExists { slog.Warn("Skipping provider due to missing API key", "provider", p.ID) - c.Providers.Del(string(p.ID)) + delete(c.Providers, string(p.ID)) } continue } } - c.Providers.Set(string(p.ID), prepared) + c.Providers[string(p.ID)] = prepared } // validate the custom providers - for id, providerConfig := range c.Providers.Seq2() { + for id, providerConfig := range c.Providers { if knownProviderNames[id] { continue } @@ -293,13 +292,13 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno } if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) && providerConfig.Type != hyper.Name { slog.Warn("Skipping custom provider due to unsupported provider type", "provider", id) - c.Providers.Del(id) + delete(c.Providers, id) continue } if providerConfig.Disable { slog.Debug("Skipping custom provider due to disable flag", "provider", id) - c.Providers.Del(id) + delete(c.Providers, id) continue } if providerConfig.APIKey == "" { @@ -307,12 +306,12 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno } if providerConfig.BaseURL == "" { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id) - c.Providers.Del(id) + delete(c.Providers, id) continue } if len(providerConfig.Models) == 0 { slog.Warn("Skipping custom provider because the provider has no models", "provider", id) - c.Providers.Del(id) + delete(c.Providers, id) continue } apiKey, err := resolver.ResolveValue(providerConfig.APIKey) @@ -322,7 +321,7 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno baseURL, err := resolver.ResolveValue(providerConfig.BaseURL) if baseURL == "" || err != nil { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err) - c.Providers.Del(id) + delete(c.Providers, id) continue } @@ -335,7 +334,7 @@ func (s *Service) configureProviders(env env.Env, resolver VariableResolver, kno providerConfig.ExtraHeaders[k] = resolved } - c.Providers.Set(id, providerConfig) + c.Providers[id] = providerConfig } return nil } @@ -363,7 +362,7 @@ func (c *Config) setDefaults(workingDir, dataDir string) { } } if c.Providers == nil { - c.Providers = csync.NewMap[string, ProviderConfig]() + c.Providers = make(map[string]ProviderConfig) } if c.Models == nil { c.Models = make(map[SelectedModelType]SelectedModel) @@ -467,7 +466,7 @@ func (c *Config) applyLSPDefaults() { func (s *Service) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { c := s.cfg - if len(knownProviders) == 0 && c.Providers.Len() == 0 { + if len(knownProviders) == 0 && len(c.Providers) == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return largeModel, smallModel, err } @@ -475,7 +474,7 @@ func (s *Service) defaultModelSelection(knownProviders []catwalk.Provider) (larg // Use the first provider enabled based on the known providers order // if no provider found that is known use the first provider configured for _, p := range knownProviders { - providerConfig, ok := c.Providers.Get(string(p.ID)) + providerConfig, ok := c.Providers[string(p.ID)] if !ok || providerConfig.Disable { continue } diff --git a/internal/config/load_test.go b/internal/config/load_test.go index c406c471ef82add9363edf99995f3289a7ad030b..f1f5c4b366be4053dd3756b8e82e99a69d633e58 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -8,7 +8,6 @@ import ( "testing" "charm.land/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,8 +33,8 @@ func TestConfig_LoadFromBytes(t *testing.T) { require.NoError(t, err) require.NotNil(t, loadedConfig) - require.Equal(t, 1, loadedConfig.Providers.Len()) - pc, _ := loadedConfig.Providers.Get("openai") + require.Equal(t, 1, len(loadedConfig.Providers)) + pc, _ := loadedConfig.Providers["openai"] require.Equal(t, "key2", pc.APIKey) require.Equal(t, "https://api.openai.com/v2", pc.BaseURL) } @@ -79,10 +78,10 @@ func TestConfig_configureProviders(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, 1, cfg.Providers.Len()) + require.Equal(t, 1, len(cfg.Providers)) // We want to make sure that we keep the configured API key as a placeholder - pc, _ := cfg.Providers.Get("openai") + pc, _ := cfg.Providers["openai"] require.Equal(t, "$OPENAI_API_KEY", pc.APIKey) } @@ -99,9 +98,9 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMap[string, ProviderConfig](), + Providers: make(map[string]ProviderConfig), } - cfg.Providers.Set("openai", ProviderConfig{ + cfg.Providers["openai"] = ProviderConfig{ APIKey: "xyz", BaseURL: "https://api.openai.com/v2", Models: []catwalk.Model{ @@ -113,7 +112,7 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { ID: "another-model", }, }, - }) + } cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ @@ -122,10 +121,10 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, 1, cfg.Providers.Len()) + require.Equal(t, 1, len(cfg.Providers)) // We want to make sure that we keep the configured API key as a placeholder - pc, _ := cfg.Providers.Get("openai") + pc, _ := cfg.Providers["openai"] require.Equal(t, "xyz", pc.APIKey) require.Equal(t, "https://api.openai.com/v2", pc.BaseURL) require.Len(t, pc.Models, 2) @@ -145,7 +144,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "xyz", BaseURL: "https://api.someendpoint.com/v2", @@ -155,7 +154,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { }, }, }, - }), + }, } cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{ @@ -165,17 +164,17 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Should be to because of the env variable - require.Equal(t, cfg.Providers.Len(), 2) + require.Equal(t, len(cfg.Providers), 2) // We want to make sure that we keep the configured API key as a placeholder - pc, _ := cfg.Providers.Get("custom") + pc, _ := cfg.Providers["custom"] require.Equal(t, "xyz", pc.APIKey) // Make sure we set the ID correctly require.Equal(t, "custom", pc.ID) require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL) require.Len(t, pc.Models, 1) - _, ok := cfg.Providers.Get("openai") + _, ok := cfg.Providers["openai"] require.True(t, ok, "OpenAI provider should still be present") } @@ -200,9 +199,9 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) + require.Equal(t, len(cfg.Providers), 1) - bedrockProvider, ok := cfg.Providers.Get("bedrock") + bedrockProvider, ok := cfg.Providers["bedrock"] require.True(t, ok, "Bedrock provider should be present") require.Len(t, bedrockProvider.Models, 1) require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID) @@ -227,7 +226,7 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without credentials - require.Equal(t, cfg.Providers.Len(), 0) + require.Equal(t, len(cfg.Providers), 0) } func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { @@ -274,9 +273,9 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) + require.Equal(t, len(cfg.Providers), 1) - vertexProvider, ok := cfg.Providers.Get("vertexai") + vertexProvider, ok := cfg.Providers["vertexai"] require.True(t, ok, "VertexAI provider should be present") require.Len(t, vertexProvider.Models, 1) require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID) @@ -307,7 +306,7 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without proper credentials - require.Equal(t, cfg.Providers.Len(), 0) + require.Equal(t, len(cfg.Providers), 0) } func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { @@ -332,7 +331,7 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) // Provider should not be configured without project - require.Equal(t, cfg.Providers.Len(), 0) + require.Equal(t, len(cfg.Providers), 0) } func TestConfig_configureProvidersSetProviderID(t *testing.T) { @@ -355,17 +354,17 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) { resolver := NewEnvironmentVariableResolver(env) err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) + require.Equal(t, len(cfg.Providers), 1) // Provider ID should be set - pc, _ := cfg.Providers.Get("openai") + pc, _ := cfg.Providers["openai"] require.Equal(t, "openai", pc.ID) } func TestConfig_EnabledProviders(t *testing.T) { t.Run("all providers enabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -376,7 +375,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: false, }, - }), + }, } enabled := cfg.EnabledProviders() @@ -385,7 +384,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("some providers disabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -396,7 +395,7 @@ func TestConfig_EnabledProviders(t *testing.T) { APIKey: "key2", Disable: true, }, - }), + }, } enabled := cfg.EnabledProviders() @@ -406,7 +405,7 @@ func TestConfig_EnabledProviders(t *testing.T) { t.Run("empty providers map", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMap[string, ProviderConfig](), + Providers: make(map[string]ProviderConfig), } enabled := cfg.EnabledProviders() @@ -417,13 +416,13 @@ func TestConfig_EnabledProviders(t *testing.T) { func TestConfig_IsConfigured(t *testing.T) { t.Run("returns true when at least one provider is enabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", Disable: false, }, - }), + }, } require.True(t, cfg.IsConfigured()) @@ -431,7 +430,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when no providers are configured", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMap[string, ProviderConfig](), + Providers: make(map[string]ProviderConfig), } require.False(t, cfg.IsConfigured()) @@ -439,7 +438,7 @@ func TestConfig_IsConfigured(t *testing.T) { t.Run("returns false when all providers are disabled", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { ID: "openai", APIKey: "key1", @@ -450,7 +449,7 @@ func TestConfig_IsConfigured(t *testing.T) { APIKey: "key2", Disable: true, }, - }), + }, } require.False(t, cfg.IsConfigured()) @@ -532,11 +531,11 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { Disable: true, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -547,8 +546,8 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) - prov, exists := cfg.Providers.Get("openai") + require.Equal(t, len(cfg.Providers), 1) + prov, exists := cfg.Providers["openai"] require.True(t, exists) require.True(t, prov.Disable) } @@ -556,7 +555,7 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) { func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { BaseURL: "https://api.custom.com/v1", Models: []catwalk.Model{{ @@ -566,7 +565,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { "openai": { APIKey: "$MISSING", }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -575,21 +574,21 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) - _, exists := cfg.Providers.Get("custom") + require.Equal(t, len(cfg.Providers), 1) + _, exists := cfg.Providers["custom"] require.True(t, exists) }) t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", Models: []catwalk.Model{{ ID: "test-model", }}, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -598,20 +597,20 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 0) - _, exists := cfg.Providers.Get("custom") + require.Equal(t, len(cfg.Providers), 0) + _, exists := cfg.Providers["custom"] require.False(t, exists) }) t.Run("custom provider with no models is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []catwalk.Model{}, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -620,14 +619,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 0) - _, exists := cfg.Providers.Get("custom") + require.Equal(t, len(cfg.Providers), 0) + _, exists := cfg.Providers["custom"] require.False(t, exists) }) t.Run("custom provider with unsupported type is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -636,7 +635,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -645,14 +644,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 0) - _, exists := cfg.Providers.Get("custom") + require.Equal(t, len(cfg.Providers), 0) + _, exists := cfg.Providers["custom"] require.False(t, exists) }) t.Run("valid custom provider is kept and ID is set", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -661,7 +660,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -670,8 +669,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) - customProvider, exists := cfg.Providers.Get("custom") + require.Equal(t, len(cfg.Providers), 1) + customProvider, exists := cfg.Providers["custom"] require.True(t, exists) require.Equal(t, "custom", customProvider.ID) require.Equal(t, "test-key", customProvider.APIKey) @@ -680,7 +679,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom anthropic provider is supported", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom-anthropic": { APIKey: "test-key", BaseURL: "https://api.anthropic.com/v1", @@ -689,7 +688,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "claude-3-sonnet", }}, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -698,8 +697,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) - customProvider, exists := cfg.Providers.Get("custom-anthropic") + require.Equal(t, len(cfg.Providers), 1) + customProvider, exists := cfg.Providers["custom-anthropic"] require.True(t, exists) require.Equal(t, "custom-anthropic", customProvider.ID) require.Equal(t, "test-key", customProvider.APIKey) @@ -709,7 +708,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("disabled custom provider is removed", func(t *testing.T) { cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -719,7 +718,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "test-model", }}, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -728,8 +727,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, []catwalk.Provider{}) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 0) - _, exists := cfg.Providers.Get("custom") + require.Equal(t, len(cfg.Providers), 0) + _, exists := cfg.Providers["custom"] require.False(t, exists) }) } @@ -748,11 +747,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "vertexai": { BaseURL: "custom-url", }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -763,8 +762,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 0) - _, exists := cfg.Providers.Get("vertexai") + require.Equal(t, len(cfg.Providers), 0) + _, exists := cfg.Providers["vertexai"] require.False(t, exists) }) @@ -781,11 +780,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "bedrock": { BaseURL: "custom-url", }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -794,8 +793,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 0) - _, exists := cfg.Providers.Get("bedrock") + require.Equal(t, len(cfg.Providers), 0) + _, exists := cfg.Providers["bedrock"] require.False(t, exists) }) @@ -812,11 +811,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { BaseURL: "custom-url", }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -825,8 +824,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 0) - _, exists := cfg.Providers.Get("openai") + require.Equal(t, len(cfg.Providers), 0) + _, exists := cfg.Providers["openai"] require.False(t, exists) }) @@ -843,11 +842,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { APIKey: "test-key", }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -858,8 +857,8 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { err := serviceFor(cfg).configureProviders(env, resolver, knownProviders) require.NoError(t, err) - require.Equal(t, cfg.Providers.Len(), 1) - _, exists := cfg.Providers.Get("openai") + require.Equal(t, len(cfg.Providers), 1) + _, exists := cfg.Providers["openai"] require.True(t, exists) }) } @@ -982,7 +981,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -993,7 +992,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }), + }, } cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) @@ -1031,13 +1030,13 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", Models: []catwalk.Model{}, }, - }), + }, } cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) @@ -1068,7 +1067,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { } cfg := &Config{ - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "custom": { APIKey: "test-key", BaseURL: "https://api.custom.com/v1", @@ -1079,7 +1078,7 @@ func TestConfig_defaultModelSelection(t *testing.T) { }, }, }, - }), + }, } cfg.setDefaults("/tmp", "") env := env.NewFromMap(map[string]string{}) @@ -1117,11 +1116,11 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { Options: &Options{ DisableDefaultProviders: true, }, - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { APIKey: "$OPENAI_API_KEY", }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -1133,8 +1132,8 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { require.NoError(t, err) // openai should NOT be present because it lacks base_url and models. - require.Equal(t, 0, cfg.Providers.Len()) - _, exists := cfg.Providers.Get("openai") + require.Equal(t, 0, len(cfg.Providers)) + _, exists := cfg.Providers["openai"] require.False(t, exists, "openai should not be present without full specification") }) @@ -1155,7 +1154,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { Options: &Options{ DisableDefaultProviders: true, }, - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "my-llm": { APIKey: "$MY_API_KEY", BaseURL: "https://my-llm.example.com/v1", @@ -1163,7 +1162,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { ID: "my-model", }}, }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -1176,14 +1175,14 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { require.NoError(t, err) // Only fully specified provider should be present. - require.Equal(t, 1, cfg.Providers.Len()) - provider, exists := cfg.Providers.Get("my-llm") + require.Equal(t, 1, len(cfg.Providers)) + provider, exists := cfg.Providers["my-llm"] require.True(t, exists, "my-llm should be present") require.Equal(t, "https://my-llm.example.com/v1", provider.BaseURL) require.Len(t, provider.Models, 1) // Default openai should NOT be present. - _, exists = cfg.Providers.Get("openai") + _, exists = cfg.Providers["openai"] require.False(t, exists, "openai should not be present") }) @@ -1213,11 +1212,11 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { Options: &Options{ DisableDefaultProviders: false, }, - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "openai": { APIKey: "$OPENAI_API_KEY", }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -1230,10 +1229,10 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { require.NoError(t, err) // Both providers should be present. - require.Equal(t, 2, cfg.Providers.Len()) - _, exists := cfg.Providers.Get("openai") + require.Equal(t, 2, len(cfg.Providers)) + _, exists := cfg.Providers["openai"] require.True(t, exists, "openai should be present") - _, exists = cfg.Providers.Get("anthropic") + _, exists = cfg.Providers["anthropic"] require.True(t, exists, "anthropic should be present") }) @@ -1242,13 +1241,13 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { Options: &Options{ DisableDefaultProviders: true, }, - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "my-llm": { APIKey: "test-key", BaseURL: "https://my-llm.example.com/v1", Models: []catwalk.Model{}, // No models. }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -1258,7 +1257,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { require.NoError(t, err) // Provider should be rejected for missing models. - require.Equal(t, 0, cfg.Providers.Len()) + require.Equal(t, 0, len(cfg.Providers)) }) t.Run("when enabled, provider missing base_url is rejected", func(t *testing.T) { @@ -1266,13 +1265,13 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { Options: &Options{ DisableDefaultProviders: true, }, - Providers: csync.NewMapFrom(map[string]ProviderConfig{ + Providers: map[string]ProviderConfig{ "my-llm": { APIKey: "test-key", Models: []catwalk.Model{{ID: "model"}}, // No BaseURL. }, - }), + }, } cfg.setDefaults("/tmp", "") @@ -1282,7 +1281,7 @@ func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) { require.NoError(t, err) // Provider should be rejected for missing base_url. - require.Equal(t, 0, cfg.Providers.Len()) + require.Equal(t, 0, len(cfg.Providers)) }) } diff --git a/internal/config/service.go b/internal/config/service.go index 8124ff5e48cf1b6ac2ca190da9efa5fdbf736346..1f3107243648cb2967eee40b4a954ac8add2313c 100644 --- a/internal/config/service.go +++ b/internal/config/service.go @@ -364,7 +364,7 @@ func (s *Service) recordRecentModel(modelType SelectedModelType, model SelectedM // RefreshOAuthToken refreshes the OAuth token for the given provider. func (s *Service) RefreshOAuthToken(ctx context.Context, providerID string) error { cfg := s.cfg - providerConfig, exists := cfg.Providers.Get(providerID) + providerConfig, exists := cfg.Providers[providerID] if !exists { return fmt.Errorf("provider %s not found", providerID) } @@ -396,7 +396,7 @@ func (s *Service) RefreshOAuthToken(ctx context.Context, providerID string) erro providerConfig.SetupGitHubCopilot() } - cfg.Providers.Set(providerID, providerConfig) + cfg.Providers[providerID] = providerConfig if err := cmp.Or( s.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken), @@ -439,10 +439,10 @@ func (s *Service) SetProviderAPIKey(providerID string, apiKey any) error { } } - providerConfig, exists = cfg.Providers.Get(providerID) + providerConfig, exists = cfg.Providers[providerID] if exists { setKeyOrToken() - cfg.Providers.Set(providerID, providerConfig) + cfg.Providers[providerID] = providerConfig return nil } @@ -469,7 +469,7 @@ func (s *Service) SetProviderAPIKey(providerID string, apiKey any) error { } else { return fmt.Errorf("provider with ID %s not found in known providers", providerID) } - cfg.Providers.Set(providerID, providerConfig) + cfg.Providers[providerID] = providerConfig return nil } diff --git a/internal/ui/chat/messages.go b/internal/ui/chat/messages.go index 5dac49c08d32ae2315f9d8096f0410b2511ecb04..1aaccdc911fadfc41208e69c6903e8b4e9e90212 100644 --- a/internal/ui/chat/messages.go +++ b/internal/ui/chat/messages.go @@ -239,7 +239,7 @@ func (a *AssistantInfoItem) renderContent(width int) string { } modelFormatted := a.sty.Chat.Message.AssistantInfoModel.Render(model.Name) providerName := a.message.Provider - if providerConfig, ok := a.cfg.Providers.Get(a.message.Provider); ok { + if providerConfig, ok := a.cfg.Providers[a.message.Provider]; ok { providerName = providerConfig.Name } provider := a.sty.Chat.Message.AssistantInfoProvider.Render(fmt.Sprintf("via %s", providerName)) diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 1451c0f4fe04d3dc4f0207bdce95007fe782ffcf..f96e547ccdddb6441762f968db79b62fd27d6ff7 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -361,7 +361,7 @@ func (m *Models) setProviderItems() error { // itemsMap contains the keys of added model items. itemsMap := make(map[string]*ModelItem) groups := []ModelGroup{} - for id, p := range cfg.Providers.Seq2() { + for id, p := range cfg.Providers { if p.Disable { continue } @@ -411,7 +411,7 @@ func (m *Models) setProviderItems() error { continue } - providerConfig, providerConfigured := cfg.Providers.Get(providerID) + providerConfig, providerConfigured := cfg.Providers[providerID] if providerConfigured && providerConfig.Disable { continue } @@ -519,7 +519,7 @@ func getFilteredProviders(svc *config.Service) ([]catwalk.Provider, error) { isCopilot = p.ID == catwalk.InferenceProviderCopilot isHyper = string(p.ID) == "hyper" hasAPIKeyEnv = strings.HasPrefix(p.APIKey, "$") - _, isConfigured = cfg.Providers.Get(string(p.ID)) + _, isConfigured = cfg.Providers[string(p.ID)] ) if isAzure || isCopilot || isHyper || hasAPIKeyEnv || isConfigured { filteredProviders = append(filteredProviders, p) diff --git a/internal/ui/model/sidebar.go b/internal/ui/model/sidebar.go index 001242fad9f2b01f437e1cf3caca39a062175a77..7c37584afb2f054f1410659b2b103056e1f42c7b 100644 --- a/internal/ui/model/sidebar.go +++ b/internal/ui/model/sidebar.go @@ -21,7 +21,7 @@ func (m *UI) modelInfo(width int) string { if model != nil { // Get provider name first - providerConfig, ok := m.com.ConfigService().Config().Providers.Get(model.ModelCfg.Provider) + providerConfig, ok := m.com.ConfigService().Config().Providers[model.ModelCfg.Provider] if ok { providerName = providerConfig.Name diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 757aa2ceb315e08c18dc7c2a25731505e25fa7dd..fd32910cfed0ab95d0864d78cef417acd1830ff0 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1244,7 +1244,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { var ( providerID = msg.Model.Provider isCopilot = providerID == string(catwalk.InferenceProviderCopilot) - isConfigured = func() bool { _, ok := cfg.Providers.Get(providerID); return ok } + isConfigured = func() bool { _, ok := cfg.Providers[providerID]; return ok } ) // Attempt to import GitHub Copilot tokens from VSCode if available.