diff --git a/go.mod b/go.mod index cbfaa3077007fce91f5c7602c77c7f2353a4682a..b419bebcbb13d92ead60b4e7b444cfd53845715a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/charmbracelet/crush -go 1.25.0 +go 1.25.5 require ( charm.land/bubbles/v2 v2.0.0-rc.1 @@ -16,7 +16,7 @@ require ( github.com/aymanbagabas/go-udiff v0.3.1 github.com/bmatcuk/doublestar/v4 v4.9.1 github.com/charlievieth/fastwalk v1.0.14 - github.com/charmbracelet/catwalk v0.9.5 + github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74 github.com/charmbracelet/colorprofile v0.3.3 github.com/charmbracelet/fang v0.4.4 github.com/charmbracelet/glamour/v2 v2.0.0-20251106195642-800eb8175930 diff --git a/go.sum b/go.sum index 238b46788328c5ba0ada41bc882fd0d40949f09e..c406f84cbc6d629b765101bfaa27ba71f5ed7c30 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4= -github.com/charmbracelet/catwalk v0.9.5 h1:QLqajLJfjGTVh2MIVIdAhww2XvPklxmu+p0Z4wrT7KU= -github.com/charmbracelet/catwalk v0.9.5/go.mod h1:ReU4SdrLfe63jkEjWMdX2wlZMV3k9r11oQAmzN0m+KY= +github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74 h1:pJI8ivIgSWOeCNnZFXeZzL6px1vZVS3LU4Cqk3Gx37I= +github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74/go.mod h1:ReU4SdrLfe63jkEjWMdX2wlZMV3k9r11oQAmzN0m+KY= github.com/charmbracelet/colorprofile v0.3.3 h1:DjJzJtLP6/NZ8p7Cgjno0CKGr7wwRJGxWUwh2IyhfAI= github.com/charmbracelet/colorprofile v0.3.3/go.mod h1:nB1FugsAbzq284eJcjfah2nhdSLppN2NqvfotkfRYP4= github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY= diff --git a/internal/cmd/update_providers.go b/internal/cmd/update_providers.go index 4949c31e2e8b87f212d8ac0ed94e2416f363a53b..599d2c90954ca43888197961ec3bea4372285071 100644 --- a/internal/cmd/update_providers.go +++ b/internal/cmd/update_providers.go @@ -31,12 +31,12 @@ crush update-providers embedded // NOTE(@andreynering): We want to skip logging output do stdout here. slog.SetDefault(slog.New(slog.DiscardHandler)) - var pathOrUrl string + var pathOrURL string if len(args) > 0 { - pathOrUrl = args[0] + pathOrURL = args[0] } - if err := config.UpdateProviders(pathOrUrl); err != nil { + if err := config.UpdateProviders(pathOrURL); err != nil { return err } diff --git a/internal/config/provider.go b/internal/config/provider.go index 7a98caaaa92a1c0401adccd79c3e945f813fbfc7..e9d4dfcc9d0947eebe51d5ae303106404ab47292 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -2,7 +2,9 @@ package config import ( "cmp" + "context" "encoding/json" + "errors" "fmt" "log/slog" "os" @@ -17,7 +19,7 @@ import ( ) type ProviderClient interface { - GetProviders() ([]catwalk.Provider, error) + GetProviders(context.Context, string) ([]catwalk.Provider, error) } var ( @@ -53,7 +55,7 @@ func saveProvidersInCache(path string, providers []catwalk.Provider) error { return fmt.Errorf("failed to create directory for provider cache: %w", err) } - data, err := json.MarshalIndent(providers, "", " ") + data, err := json.Marshal(providers) if err != nil { return fmt.Errorf("failed to marshal provider data: %w", err) } @@ -64,34 +66,35 @@ func saveProvidersInCache(path string, providers []catwalk.Provider) error { return nil } -func loadProvidersFromCache(path string) ([]catwalk.Provider, error) { +func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) { data, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("failed to read provider cache file: %w", err) + return nil, "", fmt.Errorf("failed to read provider cache file: %w", err) } var providers []catwalk.Provider if err := json.Unmarshal(data, &providers); err != nil { - return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err) + return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err) } - return providers, nil + + return providers, catwalk.Etag(data), nil } -func UpdateProviders(pathOrUrl string) error { +func UpdateProviders(pathOrURL string) error { var providers []catwalk.Provider - pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL) + pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL) switch { - case pathOrUrl == "embedded": + case pathOrURL == "embedded": providers = embedded.GetAll() - case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"): + case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"): var err error - providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders() + providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "") if err != nil { return fmt.Errorf("failed to fetch providers from Catwalk: %w", err) } default: - content, err := os.ReadFile(pathOrUrl) + content, err := os.ReadFile(pathOrURL) if err != nil { return fmt.Errorf("failed to read file: %w", err) } @@ -108,61 +111,61 @@ func UpdateProviders(pathOrUrl string) error { return fmt.Errorf("failed to save providers to cache: %w", err) } - slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath) + slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath) return nil } +// Providers returns the list of providers, taking into account cached results +// and whether or not auto update is enabled. +// +// It will: +// 1. if auto update is disabled, it'll return the embedded providers at the +// time of release. +// 2. load the cached providers +// 3. try to get the fresh list of providers, and return either this new list, +// the cached list, or the embedded list if all others fail. func Providers(cfg *Config) ([]catwalk.Provider, error) { providerOnce.Do(func() { catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) client := catwalk.NewWithURL(catwalkURL) path := providerCacheFileData() - autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate - providerList, providerErr = loadProviders(autoUpdateDisabled, client, path) - }) - return providerList, providerErr -} - -func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) { - catwalkGetAndSave := func() ([]catwalk.Provider, error) { - providers, err := client.GetProviders() - if err != nil { - return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err) + if cfg.Options.DisableProviderAutoUpdate { + slog.Info("Using embedded Catwalk providers") + providerList, providerErr = embedded.GetAll(), nil + return } - if len(providers) == 0 { - return nil, fmt.Errorf("empty providers list from catwalk") - } - if err := saveProvidersInCache(path, providers); err != nil { - return nil, err - } - return providers, nil - } - - switch { - case autoUpdateDisabled: - slog.Warn("Providers auto-update is disabled") - if _, err := os.Stat(path); err == nil { - slog.Warn("Using locally cached providers") - return loadProvidersFromCache(path) + cached, etag, cachedErr := loadProvidersFromCache(path) + if len(cached) == 0 || cachedErr != nil { + // if cached file is empty, default to embedded providers + cached = embedded.GetAll() } - slog.Warn("Saving embedded providers to cache") - providers := embedded.GetAll() - if err := saveProvidersInCache(path, providers); err != nil { - return nil, err + providerList, providerErr = loadProviders(client, etag, path) + if errors.Is(providerErr, catwalk.ErrNotModified) { + slog.Info("Catwalk providers not modified") + providerList, providerErr = cached, nil } - return providers, nil - - default: - slog.Info("Fetching providers from Catwalk.", "path", path) + }) + if providerErr != nil { + catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) + return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck + } + return providerList, nil +} - providers, err := catwalkGetAndSave() - if err != nil { - catwalkUrl := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) - return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck - } - return providers, nil +func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) { + slog.Info("Fetching providers from Catwalk.", "path", path) + providers, err := client.GetProviders(context.Background(), etag) + if err != nil { + return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err) + } + if len(providers) == 0 { + return nil, errors.New("empty providers list from catwalk") } + if err := saveProvidersInCache(path, providers); err != nil { + return nil, err + } + return providers, nil } diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index f3691c320ad4e3509b327374c8ce7f5285c39590..5e889f82a4997d9f5761e6fc4a01ec6d54a0623a 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -1,8 +1,7 @@ package config import ( - "encoding/json" - "os" + "context" "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -11,37 +10,22 @@ import ( type emptyProviderClient struct{} -func (m *emptyProviderClient) GetProviders() ([]catwalk.Provider, error) { +func (m *emptyProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) { return []catwalk.Provider{}, nil } +// TestProvider_loadProvidersEmptyResult tests that loadProviders returns an +// error when the client returns an empty list. This ensures we don't cache +// empty provider lists. func TestProvider_loadProvidersEmptyResult(t *testing.T) { client := &emptyProviderClient{} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) - require.Contains(t, err.Error(), "Crush was unable to fetch an updated list of providers") + providers, err := loadProviders(client, "", tmpPath) + require.Contains(t, err.Error(), "empty providers list from catwalk") require.Empty(t, providers) require.Len(t, providers, 0) // Check that no cache file was created for empty results require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results") } - -func TestProvider_loadProvidersEmptyCache(t *testing.T) { - client := &mockProviderClient{shouldFail: false} - tmpPath := t.TempDir() + "/providers.json" - - // Create an empty cache file - emptyProviders := []catwalk.Provider{} - data, err := json.Marshal(emptyProviders) - require.NoError(t, err) - require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) - - // Should refresh and get real providers instead of using empty cache - providers, err := loadProviders(false, client, tmpPath) - require.NoError(t, err) - require.NotNil(t, providers) - require.Len(t, providers, 1) - require.Equal(t, "Mock", providers[0].Name) -} diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 1262b60ef42050b9061c9f7c6be4dc431efe3548..f101dd8de8ef624ed041ff88115c6c286f902659 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,9 +1,11 @@ package config import ( + "context" "encoding/json" "errors" "os" + "sync" "testing" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -11,10 +13,14 @@ import ( ) type mockProviderClient struct { - shouldFail bool + shouldFail bool + shouldReturnErr error } -func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { +func (m *mockProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) { + if m.shouldReturnErr != nil { + return nil, m.shouldReturnErr + } if m.shouldFail { return nil, errors.New("failed to load providers") } @@ -25,10 +31,16 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { }, nil } +func resetProviderState() { + providerOnce = sync.Once{} + providerList = nil + providerErr = nil +} + func TestProvider_loadProvidersNoIssues(t *testing.T) { client := &mockProviderClient{shouldFail: false} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) + providers, err := loadProviders(client, "", tmpPath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) @@ -39,35 +51,94 @@ func TestProvider_loadProvidersNoIssues(t *testing.T) { require.False(t, fileInfo.IsDir(), "Expected a file, not a directory") } -func TestProvider_loadProvidersWithIssues(t *testing.T) { - client := &mockProviderClient{shouldFail: true} - tmpPath := t.TempDir() + "/providers.json" - // store providers to a temporary file - oldProviders := []catwalk.Provider{ - { - Name: "OldProvider", +func TestProvider_DisableAutoUpdate(t *testing.T) { + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + resetProviderState() + defer resetProviderState() + + cfg := &Config{ + Options: &Options{ + DisableProviderAutoUpdate: true, }, } - data, err := json.Marshal(oldProviders) - if err != nil { - t.Fatalf("Failed to marshal old providers: %v", err) - } - err = os.WriteFile(tmpPath, data, 0o644) - if err != nil { - t.Fatalf("Failed to write old providers to file: %v", err) + providers, err := Providers(cfg) + require.NoError(t, err) + require.NotNil(t, providers) + require.Greater(t, len(providers), 5, "Expected embedded providers") +} + +func TestProvider_WithValidCache(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + resetProviderState() + defer resetProviderState() + + cachePath := tmpDir + "/crush/providers.json" + require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) + cachedProviders := []catwalk.Provider{ + {Name: "Cached"}, } - providers, err := loadProviders(true, client, tmpPath) + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + + mockClient := &mockProviderClient{shouldFail: false} + + providers, err := loadProviders(mockClient, "", cachePath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) - require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails") + require.Equal(t, "Mock", providers[0].Name, "Expected fresh provider from fetch") +} + +func TestProvider_NotModifiedUsesCached(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + resetProviderState() + defer resetProviderState() + + cachePath := tmpDir + "/crush/providers.json" + require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) + cachedProviders := []catwalk.Provider{ + {Name: "Cached"}, + } + data, err := json.Marshal(cachedProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + + mockClient := &mockProviderClient{shouldReturnErr: catwalk.ErrNotModified} + providers, err := loadProviders(mockClient, "", cachePath) + require.ErrorIs(t, err, catwalk.ErrNotModified) + require.Nil(t, providers) +} + +func TestProvider_EmptyCacheDefaultsToEmbedded(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + resetProviderState() + defer resetProviderState() + + cachePath := tmpDir + "/crush/providers.json" + require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755)) + emptyProviders := []catwalk.Provider{} + data, err := json.Marshal(emptyProviders) + require.NoError(t, err) + require.NoError(t, os.WriteFile(cachePath, data, 0o644)) + + cached, _, err := loadProvidersFromCache(cachePath) + require.NoError(t, err) + require.Empty(t, cached, "Expected empty cache") } func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(false, client, tmpPath) + providers, err := loadProviders(client, "", tmpPath) require.Error(t, err) require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") }