initial attempt at reloading provider list post background sync

tauraamui created

Change summary

internal/config/load.go                |  6 ++++
internal/config/provider.go            | 41 +++++++++++++++++++++++----
internal/config/provider_empty_test.go |  6 ++-
internal/config/provider_test.go       |  9 ++++--
4 files changed, 50 insertions(+), 12 deletions(-)

Detailed changes

internal/config/load.go 🔗

@@ -81,6 +81,10 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 	if err != nil {
 		return nil, err
 	}
+	// TODO(tauraamui): need to internally emit a known providers data update event or something
+	// to re-apply basically all of the following:
+	// +
+	// +
 	cfg.knownProviders = providers
 
 	env := env.New()
@@ -100,6 +104,8 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 		return nil, fmt.Errorf("failed to configure selected models: %w", err)
 	}
 	cfg.SetupAgents()
+	// +
+	// +
 	return cfg, nil
 }
 

internal/config/provider.go 🔗

@@ -22,9 +22,10 @@ type ProviderClient interface {
 }
 
 var (
-	providerOnce sync.Once
+	providerMu   sync.RWMutex
 	providerList []catwalk.Provider
 	providerErr  error
+	initialized  bool
 )
 
 // file to cache provider data
@@ -118,21 +119,44 @@ func UpdateProviders(pathOrUrl string) error {
 	return nil
 }
 
-// NOTE(tauraamui) <REF#2>: see note (REF#1), basically this looks like some logic
-// should be shared/consolidated.
 func Providers(cfg *Config) ([]catwalk.Provider, error) {
-	providerOnce.Do(func() {
+	providerMu.Lock()
+	if !initialized {
 		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 		client := catwalk.NewWithURL(catwalkURL)
 		path := providerCacheFileData()
 
 		autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
-		providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
-	})
+		providerList, providerErr = loadProviders(autoUpdateDisabled, client, path, cfg)
+		initialized = true
+	}
+	providerMu.Unlock()
+
+	providerMu.RLock()
+	defer providerMu.RUnlock()
 	return providerList, providerErr
 }
 
-func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
+func reloadProviders(path string) {
+	providerMu.Lock()
+	defer providerMu.Unlock()
+
+	providers, err := loadProvidersFromCache(path)
+	if err != nil {
+		slog.Error("Failed to reload providers from cache", "error", err)
+		return
+	}
+	if len(providers) == 0 {
+		slog.Error("Empty providers list after reload")
+		return
+	}
+
+	providerList = providers
+	providerErr = nil
+	slog.Info("Providers reloaded successfully", "count", len(providers))
+}
+
+func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string, cfg *Config) ([]catwalk.Provider, error) {
 	cacheIsStale, cacheExists := isCacheStale(path)
 
 	catwalkGetAndSave := func() ([]catwalk.Provider, error) {
@@ -164,7 +188,10 @@ func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string)
 			}
 			if err := saveProvidersInCache(path, providers); err != nil {
 				slog.Error("Failed to update providers.json in background", "error", err)
+				return
 			}
+
+			reloadProviders(path)
 		}()
 	}
 

internal/config/provider_empty_test.go 🔗

@@ -19,7 +19,8 @@ func TestProvider_loadProvidersEmptyResult(t *testing.T) {
 	client := &emptyProviderClient{}
 	tmpPath := t.TempDir() + "/providers.json"
 
-	providers, err := loadProviders(false, client, tmpPath)
+	cfg := &Config{}
+	providers, err := loadProviders(false, client, tmpPath, cfg)
 	require.Contains(t, err.Error(), "Crush was unable to fetch an updated list of providers")
 	require.Empty(t, providers)
 	require.Len(t, providers, 0)
@@ -39,7 +40,8 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) {
 	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
 
 	// Should refresh and get real providers instead of using empty cache
-	providers, err := loadProviders(false, client, tmpPath)
+	cfg := &Config{}
+	providers, err := loadProviders(false, client, tmpPath, cfg)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)

internal/config/provider_test.go 🔗

@@ -28,7 +28,8 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
 func TestProvider_loadProvidersNoIssues(t *testing.T) {
 	client := &mockProviderClient{shouldFail: false}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(false, client, tmpPath)
+	cfg := &Config{}
+	providers, err := loadProviders(false, client, tmpPath, cfg)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)
@@ -57,7 +58,8 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to write old providers to file: %v", err)
 	}
-	providers, err := loadProviders(false, client, tmpPath)
+	cfg := &Config{}
+	providers, err := loadProviders(false, client, tmpPath, cfg)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)
@@ -67,7 +69,8 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
 func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
 	client := &mockProviderClient{shouldFail: true}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(false, client, tmpPath)
+	cfg := &Config{}
+	providers, err := loadProviders(false, client, tmpPath, cfg)
 	require.Error(t, err)
 	require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
 }