Detailed changes
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
+ "log/slog"
"os"
"path/filepath"
"runtime"
@@ -13,10 +14,8 @@ import (
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
- "github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/log"
- "golang.org/x/exp/slog"
)
// LoadReader config via io.Reader.
@@ -63,7 +62,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
)
// Load known providers, this loads the config from fur
- providers, err := LoadProviders(client.New())
+ providers, err := Providers()
if err != nil || len(providers) == 0 {
return nil, fmt.Errorf("failed to load providers: %w", err)
}
@@ -2,10 +2,13 @@ package config
import (
"encoding/json"
+ "fmt"
+ "log/slog"
"os"
"path/filepath"
"runtime"
"sync"
+ "time"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
@@ -42,57 +45,88 @@ func providerCacheFileData() string {
}
func saveProvidersInCache(path string, providers []provider.Provider) error {
- dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return err
+ slog.Info("Caching provider data")
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ return fmt.Errorf("failed to create directory for provider cache: %w", err)
}
data, err := json.MarshalIndent(providers, "", " ")
if err != nil {
- return err
+ return fmt.Errorf("failed to marshal provider data: %w", err)
}
- return os.WriteFile(path, data, 0o644)
+ if err := os.WriteFile(path, data, 0o644); err != nil {
+ return fmt.Errorf("failed to write provider data to cache: %w", err)
+ }
+ return nil
}
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
data, err := os.ReadFile(path)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to read provider cache file: %w", err)
}
var providers []provider.Provider
- err = json.Unmarshal(data, &providers)
- return providers, err
-}
-
-func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
- providers, err := client.GetProviders()
- if err != nil {
- fallbackToCache, err := loadProvidersFromCache(path)
- if err != nil {
- return nil, err
- }
- providers = fallbackToCache
- } else {
- if err := saveProvidersInCache(path, providerList); err != nil {
- return nil, err
- }
+ if err := json.Unmarshal(data, &providers); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
}
return providers, nil
}
func Providers() ([]provider.Provider, error) {
- return LoadProviders(client.New())
+ client := client.New()
+ path := providerCacheFileData()
+ return loadProvidersOnce(client, path)
}
-func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
+func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider, error) {
var err error
providerOnce.Do(func() {
- providerList, err = loadProviders(providerCacheFileData(), client)
+ providerList, err = loadProviders(client, path)
})
if err != nil {
return nil, err
}
return providerList, nil
}
+
+func loadProviders(client ProviderClient, path string) (providerList []provider.Provider, err error) {
+ // if cache is not stale, load from it
+ stale, exists := isCacheStale(path)
+ if !stale {
+ slog.Info("Using cached provider data")
+ providerList, err = loadProvidersFromCache(path)
+ if len(providerList) > 0 && err == nil {
+ go func() {
+ slog.Info("Updating provider cache in background")
+ updated, uerr := client.GetProviders()
+ if len(updated) == 0 && uerr == nil {
+ _ = saveProvidersInCache(path, updated)
+ }
+ }()
+ return
+ }
+ }
+
+ slog.Info("Getting live provider data")
+ providerList, err = client.GetProviders()
+ if len(providerList) > 0 && err == nil {
+ err = saveProvidersInCache(path, providerList)
+ return
+ }
+ if !exists {
+ err = fmt.Errorf("failed to load providers")
+ return
+ }
+ providerList, err = loadProvidersFromCache(path)
+ return
+}
+
+func isCacheStale(path string) (stale, exists bool) {
+ info, err := os.Stat(path)
+ if err != nil {
+ return true, false
+ }
+ return time.Since(info.ModTime()) > 24*time.Hour, true
+}
@@ -0,0 +1,47 @@
+package config
+
+import (
+ "encoding/json"
+ "os"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/stretchr/testify/require"
+)
+
+type emptyProviderClient struct{}
+
+func (m *emptyProviderClient) GetProviders() ([]provider.Provider, error) {
+ return []provider.Provider{}, nil
+}
+
+func TestProvider_loadProvidersEmptyResult(t *testing.T) {
+ client := &emptyProviderClient{}
+ tmpPath := t.TempDir() + "/providers.json"
+
+ providers, err := loadProviders(client, tmpPath)
+ require.EqualError(t, err, "failed to load providers")
+ require.Empty(t, providers)
+ require.Len(t, providers, 0)
+
+ // Check that no cache file was created for empty results
+ require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results")
+}
+
+func TestProvider_loadProvidersEmptyCache(t *testing.T) {
+ client := &mockProviderClient{shouldFail: false}
+ tmpPath := t.TempDir() + "/providers.json"
+
+ // Create an empty cache file
+ emptyProviders := []provider.Provider{}
+ data, err := json.Marshal(emptyProviders)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
+
+ // Should refresh and get real providers instead of using empty cache
+ providers, err := loadProviders(client, tmpPath)
+ require.NoError(t, err)
+ require.NotNil(t, providers)
+ require.Len(t, providers, 1)
+ require.Equal(t, "Mock", providers[0].Name)
+}
@@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
func TestProvider_loadProvidersNoIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: false}
tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(tmpPath, client)
+ providers, err := loadProviders(client, tmpPath)
assert.NoError(t, err)
assert.NotNil(t, providers)
assert.Len(t, providers, 1)
@@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
if err != nil {
t.Fatalf("Failed to write old providers to file: %v", err)
}
- providers, err := loadProviders(tmpPath, client)
+ providers, err := loadProviders(client, tmpPath)
assert.NoError(t, err)
assert.NotNil(t, providers)
assert.Len(t, providers, 1)
@@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
client := &mockProviderClient{shouldFail: true}
tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(tmpPath, client)
+ providers, err := loadProviders(client, tmpPath)
assert.Error(t, err)
assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
}