diff --git a/internal/config/load.go b/internal/config/load.go index 044ef504859bcbcc051b93322099f6d03b1fa601..48ef9b1caf1e5d9ec1877f7fc9c3a53ab996d129 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "os" "path/filepath" "runtime" @@ -13,10 +14,8 @@ import ( "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/log" - "golang.org/x/exp/slog" ) // LoadReader config via io.Reader. @@ -63,7 +62,7 @@ func Load(workingDir string, debug bool) (*Config, error) { ) // Load known providers, this loads the config from fur - providers, err := LoadProviders(client.New()) + providers, err := Providers() if err != nil || len(providers) == 0 { return nil, fmt.Errorf("failed to load providers: %w", err) } diff --git a/internal/config/provider.go b/internal/config/provider.go index b8369b934963aca0a7f449fb219764ee079493ef..caeba48707be933d222313729934cc69c819f68e 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -2,10 +2,13 @@ package config import ( "encoding/json" + "fmt" + "log/slog" "os" "path/filepath" "runtime" "sync" + "time" "github.com/charmbracelet/crush/internal/fur/client" "github.com/charmbracelet/crush/internal/fur/provider" @@ -42,57 +45,88 @@ func providerCacheFileData() string { } func saveProvidersInCache(path string, providers []provider.Provider) error { - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err + slog.Info("Caching provider data") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("failed to create directory for provider cache: %w", err) } data, err := json.MarshalIndent(providers, "", " ") if err != nil { - return err + return fmt.Errorf("failed to marshal provider data: %w", err) } - return os.WriteFile(path, data, 0o644) + if err := os.WriteFile(path, data, 0o644); err != nil { + return fmt.Errorf("failed to write provider data to cache: %w", err) + } + return nil } func loadProvidersFromCache(path string) ([]provider.Provider, error) { data, err := os.ReadFile(path) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read provider cache file: %w", err) } var providers []provider.Provider - err = json.Unmarshal(data, &providers) - return providers, err -} - -func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) { - providers, err := client.GetProviders() - if err != nil { - fallbackToCache, err := loadProvidersFromCache(path) - if err != nil { - return nil, err - } - providers = fallbackToCache - } else { - if err := saveProvidersInCache(path, providerList); err != nil { - return nil, err - } + if err := json.Unmarshal(data, &providers); err != nil { + return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err) } return providers, nil } func Providers() ([]provider.Provider, error) { - return LoadProviders(client.New()) + client := client.New() + path := providerCacheFileData() + return loadProvidersOnce(client, path) } -func LoadProviders(client ProviderClient) ([]provider.Provider, error) { +func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider, error) { var err error providerOnce.Do(func() { - providerList, err = loadProviders(providerCacheFileData(), client) + providerList, err = loadProviders(client, path) }) if err != nil { return nil, err } return providerList, nil } + +func loadProviders(client ProviderClient, path string) (providerList []provider.Provider, err error) { + // if cache is not stale, load from it + stale, exists := isCacheStale(path) + if !stale { + slog.Info("Using cached provider data") + providerList, err = loadProvidersFromCache(path) + if len(providerList) > 0 && err == nil { + go func() { + slog.Info("Updating provider cache in background") + updated, uerr := client.GetProviders() + if len(updated) == 0 && uerr == nil { + _ = saveProvidersInCache(path, updated) + } + }() + return + } + } + + slog.Info("Getting live provider data") + providerList, err = client.GetProviders() + if len(providerList) > 0 && err == nil { + err = saveProvidersInCache(path, providerList) + return + } + if !exists { + err = fmt.Errorf("failed to load providers") + return + } + providerList, err = loadProvidersFromCache(path) + return +} + +func isCacheStale(path string) (stale, exists bool) { + info, err := os.Stat(path) + if err != nil { + return true, false + } + return time.Since(info.ModTime()) > 24*time.Hour, true +} diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go new file mode 100644 index 0000000000000000000000000000000000000000..480869d98e4d69087aefc5759de0776f7910ebec --- /dev/null +++ b/internal/config/provider_empty_test.go @@ -0,0 +1,47 @@ +package config + +import ( + "encoding/json" + "os" + "testing" + + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/stretchr/testify/require" +) + +type emptyProviderClient struct{} + +func (m *emptyProviderClient) GetProviders() ([]provider.Provider, error) { + return []provider.Provider{}, nil +} + +func TestProvider_loadProvidersEmptyResult(t *testing.T) { + client := &emptyProviderClient{} + tmpPath := t.TempDir() + "/providers.json" + + providers, err := loadProviders(client, tmpPath) + require.EqualError(t, err, "failed to load providers") + require.Empty(t, providers) + require.Len(t, providers, 0) + + // Check that no cache file was created for empty results + require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results") +} + +func TestProvider_loadProvidersEmptyCache(t *testing.T) { + client := &mockProviderClient{shouldFail: false} + tmpPath := t.TempDir() + "/providers.json" + + // Create an empty cache file + emptyProviders := []provider.Provider{} + data, err := json.Marshal(emptyProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) + + // Should refresh and get real providers instead of using empty cache + providers, err := loadProviders(client, tmpPath) + require.NoError(t, err) + require.NotNil(t, providers) + require.Len(t, providers, 1) + require.Equal(t, "Mock", providers[0].Name) +} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index a3562838c7103239aa303c906c866220164a4ba0..abfb6592bcd5e46a7cbf40dba54a10722ee69980 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) { func TestProvider_loadProvidersNoIssues(t *testing.T) { client := &mockProviderClient{shouldFail: false} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.NoError(t, err) assert.NotNil(t, providers) assert.Len(t, providers, 1) @@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { if err != nil { t.Fatalf("Failed to write old providers to file: %v", err) } - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.NoError(t, err) assert.NotNil(t, providers) assert.Len(t, providers, 1) @@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(tmpPath, client) + providers, err := loadProviders(client, tmpPath) assert.Error(t, err) assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") }