From 1ffc87f48f45da42004ad6a00f3dc253a0a92294 Mon Sep 17 00:00:00 2001 From: tauraamui Date: Thu, 16 Oct 2025 11:37:09 +0100 Subject: [PATCH] refactor: stop passing full config when only single field needed --- internal/config/load.go | 2 +- internal/config/provider.go | 21 ++++++++++--------- internal/config/provider_empty_test.go | 6 ++---- internal/tui/components/chat/splash/splash.go | 2 +- .../tui/components/dialogs/models/list.go | 4 ++-- .../tui/components/dialogs/models/models.go | 2 +- 6 files changed, 18 insertions(+), 19 deletions(-) diff --git a/internal/config/load.go b/internal/config/load.go index 21b5d6b55c0ccd7e9a892a46b613eeb6275ec6cd..d87b94fb5de5d47cfdab5fe36b69784b2af2f908 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -77,7 +77,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { // NOTE(tauraamui): current entrypoint for invoking providers fetch for // the rest of the app as a whole. // Load known providers, this loads the config from catwalk - providers, err := Providers(cfg) + providers, err := Providers(cfg.Options.DisableProviderAutoUpdate) if err != nil { return nil, err } diff --git a/internal/config/provider.go b/internal/config/provider.go index 4cd373a30766021c427438b33e57cd105774fa83..0e5022e691da1f10c825e90c1978af347e17f78d 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -119,18 +119,19 @@ func UpdateProviders(pathOrUrl string) error { return nil } -func Providers(cfg *Config) ([]catwalk.Provider, error) { - providerMu.Lock() - if !initialized { - catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) - client := catwalk.NewWithURL(catwalkURL) - path := providerCacheFileData() +func Providers(autoUpdateDisabled bool) ([]catwalk.Provider, error) { + catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) + client := catwalk.NewWithURL(catwalkURL) + return ProvidersWithClient(autoUpdateDisabled, client, providerCacheFileData()) +} - autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate - providerList, providerErr = loadProviders(autoUpdateDisabled, client, path, cfg) +func ProvidersWithClient(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) { + if !initialized { + providerMu.Lock() + providerList, providerErr = loadProviders(autoUpdateDisabled, client, path) initialized = true + providerMu.Unlock() } - providerMu.Unlock() providerMu.RLock() defer providerMu.RUnlock() @@ -156,7 +157,7 @@ func reloadProviders(path string) { slog.Info("Providers reloaded successfully", "count", len(providers)) } -func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string, cfg *Config) ([]catwalk.Provider, error) { +func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) { cacheIsStale, cacheExists := isCacheStale(path) catwalkGetAndSave := func() ([]catwalk.Provider, error) { diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index 7e98cd9288ff5a080c876606e93bd16a2f0c3a19..f3691c320ad4e3509b327374c8ce7f5285c39590 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -19,8 +19,7 @@ func TestProvider_loadProvidersEmptyResult(t *testing.T) { client := &emptyProviderClient{} tmpPath := t.TempDir() + "/providers.json" - cfg := &Config{} - providers, err := loadProviders(false, client, tmpPath, cfg) + providers, err := loadProviders(false, client, tmpPath) require.Contains(t, err.Error(), "Crush was unable to fetch an updated list of providers") require.Empty(t, providers) require.Len(t, providers, 0) @@ -40,8 +39,7 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) { require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) // Should refresh and get real providers instead of using empty cache - cfg := &Config{} - providers, err := loadProviders(false, client, tmpPath, cfg) + providers, err := loadProviders(false, client, tmpPath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 187fc35e6ec47a858b99f35e135a8cef3500fbf1..13fcd06430acbb72a9d2900a4b50964733116745 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -399,7 +399,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { cfg := config.Get() - providers, err := config.Providers(cfg) + providers, err := config.Providers(cfg.Options.DisableProviderAutoUpdate) if err != nil { return nil, err } diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 77398c4d17d85126ab155a9e9c5b2085c0691672..b78700cb16290f35b28d08fc8a53aef52adb3d28 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -50,7 +50,7 @@ func (m *ModelListComponent) Init() tea.Cmd { var cmds []tea.Cmd if len(m.providers) == 0 { cfg := config.Get() - providers, err := config.Providers(cfg) + providers, err := config.Providers(cfg.Options.DisableProviderAutoUpdate) filteredProviders := []catwalk.Provider{} for _, p := range providers { hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$") @@ -120,7 +120,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { // First, add any configured providers that are not in the known providers list // These should appear at the top of the list - knownProviders, err := config.Providers(cfg) + knownProviders, err := config.Providers(cfg.Options.DisableProviderAutoUpdate) if err != nil { return util.ReportError(err) } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 7c2863706c29180cffcfb88c385a012e39df464c..bff70b40c73eef22680c6a24f8d986ecb1fd0a45 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -353,7 +353,7 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { cfg := config.Get() - providers, err := config.Providers(cfg) + providers, err := config.Providers(cfg.Options.DisableProviderAutoUpdate) if err != nil { return nil, err }