diff --git a/internal/config/load.go b/internal/config/load.go index ab8c35f7bd1cfaf6077cb841d4ba6e8e4dee1403..21b5d6b55c0ccd7e9a892a46b613eeb6275ec6cd 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -81,6 +81,10 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { if err != nil { return nil, err } + // TODO(tauraamui): need to internally emit a known providers data update event or something + // to re-apply basically all of the following: + // + + // + cfg.knownProviders = providers env := env.New() @@ -100,6 +104,8 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to configure selected models: %w", err) } cfg.SetupAgents() + // + + // + return cfg, nil } diff --git a/internal/config/provider.go b/internal/config/provider.go index 0e664e0b48e7a5f61778e580c7c120ecd2a8a255..4cd373a30766021c427438b33e57cd105774fa83 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -22,9 +22,10 @@ type ProviderClient interface { } var ( - providerOnce sync.Once + providerMu sync.RWMutex providerList []catwalk.Provider providerErr error + initialized bool ) // file to cache provider data @@ -118,21 +119,44 @@ func UpdateProviders(pathOrUrl string) error { return nil } -// NOTE(tauraamui) : see note (REF#1), basically this looks like some logic -// should be shared/consolidated. func Providers(cfg *Config) ([]catwalk.Provider, error) { - providerOnce.Do(func() { + providerMu.Lock() + if !initialized { catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) client := catwalk.NewWithURL(catwalkURL) path := providerCacheFileData() autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate - providerList, providerErr = loadProviders(autoUpdateDisabled, client, path) - }) + providerList, providerErr = loadProviders(autoUpdateDisabled, client, path, cfg) + initialized = true + } + providerMu.Unlock() + + providerMu.RLock() + defer providerMu.RUnlock() return providerList, providerErr } -func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) { +func reloadProviders(path string) { + providerMu.Lock() + defer providerMu.Unlock() + + providers, err := loadProvidersFromCache(path) + if err != nil { + slog.Error("Failed to reload providers from cache", "error", err) + return + } + if len(providers) == 0 { + slog.Error("Empty providers list after reload") + return + } + + providerList = providers + providerErr = nil + slog.Info("Providers reloaded successfully", "count", len(providers)) +} + +func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string, cfg *Config) ([]catwalk.Provider, error) { cacheIsStale, cacheExists := isCacheStale(path) catwalkGetAndSave := func() ([]catwalk.Provider, error) { @@ -164,7 +188,10 @@ func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) } if err := saveProvidersInCache(path, providers); err != nil { slog.Error("Failed to update providers.json in background", "error", err) + return } + + reloadProviders(path) }() } diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index f3691c320ad4e3509b327374c8ce7f5285c39590..7e98cd9288ff5a080c876606e93bd16a2f0c3a19 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -19,7 +19,8 @@ func TestProvider_loadProvidersEmptyResult(t *testing.T) { client := &emptyProviderClient{} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) + cfg := &Config{} + providers, err := loadProviders(false, client, tmpPath, cfg) require.Contains(t, err.Error(), "Crush was unable to fetch an updated list of providers") require.Empty(t, providers) require.Len(t, providers, 0) @@ -39,7 +40,8 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) { require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) // Should refresh and get real providers instead of using empty cache - providers, err := loadProviders(false, client, tmpPath) + cfg := &Config{} + providers, err := loadProviders(false, client, tmpPath, cfg) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 8b499919bca666915a89d38c1e5014a911f4d2d1..ed5dc2b7e65380643fc532db3a993533194648aa 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -28,7 +28,8 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { func TestProvider_loadProvidersNoIssues(t *testing.T) { client := &mockProviderClient{shouldFail: false} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) + cfg := &Config{} + providers, err := loadProviders(false, client, tmpPath, cfg) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) @@ -57,7 +58,8 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { if err != nil { t.Fatalf("Failed to write old providers to file: %v", err) } - providers, err := loadProviders(false, client, tmpPath) + cfg := &Config{} + providers, err := loadProviders(false, client, tmpPath, cfg) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) @@ -67,7 +69,8 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) + cfg := &Config{} + providers, err := loadProviders(false, client, tmpPath, cfg) require.Error(t, err) require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") }