fix: load providers in background

Carlos Alexandro Becker created

Change summary

internal/config/load.go                |  5 -
internal/config/provider.go            | 84 +++++++++++++++++++--------
internal/config/provider_empty_test.go | 47 +++++++++++++++
internal/config/provider_test.go       |  6 +-
4 files changed, 111 insertions(+), 31 deletions(-)

Detailed changes

internal/config/load.go 🔗

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log/slog"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -13,10 +14,8 @@ import (
 
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/env"
-	"github.com/charmbracelet/crush/internal/fur/client"
 	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/log"
-	"golang.org/x/exp/slog"
 )
 
 // LoadReader config via io.Reader.
@@ -63,7 +62,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
 	)
 
 	// Load known providers, this loads the config from fur
-	providers, err := LoadProviders(client.New())
+	providers, err := Providers()
 	if err != nil || len(providers) == 0 {
 		return nil, fmt.Errorf("failed to load providers: %w", err)
 	}

internal/config/provider.go 🔗

@@ -2,10 +2,13 @@ package config
 
 import (
 	"encoding/json"
+	"fmt"
+	"log/slog"
 	"os"
 	"path/filepath"
 	"runtime"
 	"sync"
+	"time"
 
 	"github.com/charmbracelet/crush/internal/fur/client"
 	"github.com/charmbracelet/crush/internal/fur/provider"
@@ -42,57 +45,88 @@ func providerCacheFileData() string {
 }
 
 func saveProvidersInCache(path string, providers []provider.Provider) error {
-	dir := filepath.Dir(path)
-	if err := os.MkdirAll(dir, 0o755); err != nil {
-		return err
+	slog.Info("Caching provider data")
+	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+		return fmt.Errorf("failed to create directory for provider cache: %w", err)
 	}
 
 	data, err := json.MarshalIndent(providers, "", "  ")
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to marshal provider data: %w", err)
 	}
 
-	return os.WriteFile(path, data, 0o644)
+	if err := os.WriteFile(path, data, 0o644); err != nil {
+		return fmt.Errorf("failed to write provider data to cache: %w", err)
+	}
+	return nil
 }
 
 func loadProvidersFromCache(path string) ([]provider.Provider, error) {
 	data, err := os.ReadFile(path)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to read provider cache file: %w", err)
 	}
 
 	var providers []provider.Provider
-	err = json.Unmarshal(data, &providers)
-	return providers, err
-}
-
-func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
-	providers, err := client.GetProviders()
-	if err != nil {
-		fallbackToCache, err := loadProvidersFromCache(path)
-		if err != nil {
-			return nil, err
-		}
-		providers = fallbackToCache
-	} else {
-		if err := saveProvidersInCache(path, providerList); err != nil {
-			return nil, err
-		}
+	if err := json.Unmarshal(data, &providers); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
 	}
 	return providers, nil
 }
 
 func Providers() ([]provider.Provider, error) {
-	return LoadProviders(client.New())
+	client := client.New()
+	path := providerCacheFileData()
+	return loadProvidersOnce(client, path)
 }
 
-func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
+func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider, error) {
 	var err error
 	providerOnce.Do(func() {
-		providerList, err = loadProviders(providerCacheFileData(), client)
+		providerList, err = loadProviders(client, path)
 	})
 	if err != nil {
 		return nil, err
 	}
 	return providerList, nil
 }
+
+func loadProviders(client ProviderClient, path string) (providerList []provider.Provider, err error) {
+	// if cache is not stale, load from it
+	stale, exists := isCacheStale(path)
+	if !stale {
+		slog.Info("Using cached provider data")
+		providerList, err = loadProvidersFromCache(path)
+		if len(providerList) > 0 && err == nil {
+			go func() {
+				slog.Info("Updating provider cache in background")
+				updated, uerr := client.GetProviders()
+				if len(updated) == 0 && uerr == nil {
+					_ = saveProvidersInCache(path, updated)
+				}
+			}()
+			return
+		}
+	}
+
+	slog.Info("Getting live provider data")
+	providerList, err = client.GetProviders()
+	if len(providerList) > 0 && err == nil {
+		err = saveProvidersInCache(path, providerList)
+		return
+	}
+	if !exists {
+		err = fmt.Errorf("failed to load providers")
+		return
+	}
+	providerList, err = loadProvidersFromCache(path)
+	return
+}
+
+func isCacheStale(path string) (stale, exists bool) {
+	info, err := os.Stat(path)
+	if err != nil {
+		return true, false
+	}
+	return time.Since(info.ModTime()) > 24*time.Hour, true
+}

internal/config/provider_empty_test.go 🔗

@@ -0,0 +1,47 @@
+package config
+
+import (
+	"encoding/json"
+	"os"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/fur/provider"
+	"github.com/stretchr/testify/require"
+)
+
+type emptyProviderClient struct{}
+
+func (m *emptyProviderClient) GetProviders() ([]provider.Provider, error) {
+	return []provider.Provider{}, nil
+}
+
+func TestProvider_loadProvidersEmptyResult(t *testing.T) {
+	client := &emptyProviderClient{}
+	tmpPath := t.TempDir() + "/providers.json"
+
+	providers, err := loadProviders(client, tmpPath)
+	require.EqualError(t, err, "failed to load providers")
+	require.Empty(t, providers)
+	require.Len(t, providers, 0)
+
+	// Check that no cache file was created for empty results
+	require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results")
+}
+
+func TestProvider_loadProvidersEmptyCache(t *testing.T) {
+	client := &mockProviderClient{shouldFail: false}
+	tmpPath := t.TempDir() + "/providers.json"
+
+	// Create an empty cache file
+	emptyProviders := []provider.Provider{}
+	data, err := json.Marshal(emptyProviders)
+	require.NoError(t, err)
+	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
+
+	// Should refresh and get real providers instead of using empty cache
+	providers, err := loadProviders(client, tmpPath)
+	require.NoError(t, err)
+	require.NotNil(t, providers)
+	require.Len(t, providers, 1)
+	require.Equal(t, "Mock", providers[0].Name)
+}

internal/config/provider_test.go 🔗

@@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
 func TestProvider_loadProvidersNoIssues(t *testing.T) {
 	client := &mockProviderClient{shouldFail: false}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(tmpPath, client)
+	providers, err := loadProviders(client, tmpPath)
 	assert.NoError(t, err)
 	assert.NotNil(t, providers)
 	assert.Len(t, providers, 1)
@@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to write old providers to file: %v", err)
 	}
-	providers, err := loadProviders(tmpPath, client)
+	providers, err := loadProviders(client, tmpPath)
 	assert.NoError(t, err)
 	assert.NotNil(t, providers)
 	assert.Len(t, providers, 1)
@@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
 func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
 	client := &mockProviderClient{shouldFail: true}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(tmpPath, client)
+	providers, err := loadProviders(client, tmpPath)
 	assert.Error(t, err)
 	assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
 }