perf: improve startup, specifically provider updates (#1577)

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

go.mod                                 |   4 
go.sum                                 |   4 
internal/cmd/update_providers.go       |   6 
internal/config/provider.go            | 109 ++++++++++++++-------------
internal/config/provider_empty_test.go |  30 +-----
internal/config/provider_test.go       | 111 ++++++++++++++++++++++-----
6 files changed, 161 insertions(+), 103 deletions(-)

Detailed changes

go.mod 🔗

@@ -1,6 +1,6 @@
 module github.com/charmbracelet/crush
 
-go 1.25.0
+go 1.25.5
 
 require (
 	charm.land/bubbles/v2 v2.0.0-rc.1
@@ -16,7 +16,7 @@ require (
 	github.com/aymanbagabas/go-udiff v0.3.1
 	github.com/bmatcuk/doublestar/v4 v4.9.1
 	github.com/charlievieth/fastwalk v1.0.14
-	github.com/charmbracelet/catwalk v0.9.5
+	github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74
 	github.com/charmbracelet/colorprofile v0.3.3
 	github.com/charmbracelet/fang v0.4.4
 	github.com/charmbracelet/glamour/v2 v2.0.0-20251106195642-800eb8175930

go.sum 🔗

@@ -88,8 +88,8 @@ github.com/charlievieth/fastwalk v1.0.14 h1:3Eh5uaFGwHZd8EGwTjJnSpBkfwfsak9h6ICg
 github.com/charlievieth/fastwalk v1.0.14/go.mod h1:diVcUreiU1aQ4/Wu3NbxxH4/KYdKpLDojrQ1Bb2KgNY=
 github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904 h1:rwLdEpG9wE6kL69KkEKDiWprO8pQOZHZXeod6+9K+mw=
 github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251024181547-21d6f3d9a904/go.mod h1:8TIYxZxsuCqqeJ0lga/b91tBwrbjoHDC66Sq5t8N2R4=
-github.com/charmbracelet/catwalk v0.9.5 h1:QLqajLJfjGTVh2MIVIdAhww2XvPklxmu+p0Z4wrT7KU=
-github.com/charmbracelet/catwalk v0.9.5/go.mod h1:ReU4SdrLfe63jkEjWMdX2wlZMV3k9r11oQAmzN0m+KY=
+github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74 h1:pJI8ivIgSWOeCNnZFXeZzL6px1vZVS3LU4Cqk3Gx37I=
+github.com/charmbracelet/catwalk v0.9.7-0.20251208190755-350e2a004c74/go.mod h1:ReU4SdrLfe63jkEjWMdX2wlZMV3k9r11oQAmzN0m+KY=
 github.com/charmbracelet/colorprofile v0.3.3 h1:DjJzJtLP6/NZ8p7Cgjno0CKGr7wwRJGxWUwh2IyhfAI=
 github.com/charmbracelet/colorprofile v0.3.3/go.mod h1:nB1FugsAbzq284eJcjfah2nhdSLppN2NqvfotkfRYP4=
 github.com/charmbracelet/fang v0.4.4 h1:G4qKxF6or/eTPgmAolwPuRNyuci3hTUGGX1rj1YkHJY=

internal/cmd/update_providers.go 🔗

@@ -31,12 +31,12 @@ crush update-providers embedded
 		// NOTE(@andreynering): We want to skip logging output do stdout here.
 		slog.SetDefault(slog.New(slog.DiscardHandler))
 
-		var pathOrUrl string
+		var pathOrURL string
 		if len(args) > 0 {
-			pathOrUrl = args[0]
+			pathOrURL = args[0]
 		}
 
-		if err := config.UpdateProviders(pathOrUrl); err != nil {
+		if err := config.UpdateProviders(pathOrURL); err != nil {
 			return err
 		}
 

internal/config/provider.go 🔗

@@ -2,7 +2,9 @@ package config
 
 import (
 	"cmp"
+	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"log/slog"
 	"os"
@@ -17,7 +19,7 @@ import (
 )
 
 type ProviderClient interface {
-	GetProviders() ([]catwalk.Provider, error)
+	GetProviders(context.Context, string) ([]catwalk.Provider, error)
 }
 
 var (
@@ -53,7 +55,7 @@ func saveProvidersInCache(path string, providers []catwalk.Provider) error {
 		return fmt.Errorf("failed to create directory for provider cache: %w", err)
 	}
 
-	data, err := json.MarshalIndent(providers, "", "  ")
+	data, err := json.Marshal(providers)
 	if err != nil {
 		return fmt.Errorf("failed to marshal provider data: %w", err)
 	}
@@ -64,34 +66,35 @@ func saveProvidersInCache(path string, providers []catwalk.Provider) error {
 	return nil
 }
 
-func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
+func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) {
 	data, err := os.ReadFile(path)
 	if err != nil {
-		return nil, fmt.Errorf("failed to read provider cache file: %w", err)
+		return nil, "", fmt.Errorf("failed to read provider cache file: %w", err)
 	}
 
 	var providers []catwalk.Provider
 	if err := json.Unmarshal(data, &providers); err != nil {
-		return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
+		return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
 	}
-	return providers, nil
+
+	return providers, catwalk.Etag(data), nil
 }
 
-func UpdateProviders(pathOrUrl string) error {
+func UpdateProviders(pathOrURL string) error {
 	var providers []catwalk.Provider
-	pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
+	pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 
 	switch {
-	case pathOrUrl == "embedded":
+	case pathOrURL == "embedded":
 		providers = embedded.GetAll()
-	case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
+	case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
 		var err error
-		providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
+		providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
 		if err != nil {
 			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
 		}
 	default:
-		content, err := os.ReadFile(pathOrUrl)
+		content, err := os.ReadFile(pathOrURL)
 		if err != nil {
 			return fmt.Errorf("failed to read file: %w", err)
 		}
@@ -108,61 +111,61 @@ func UpdateProviders(pathOrUrl string) error {
 		return fmt.Errorf("failed to save providers to cache: %w", err)
 	}
 
-	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
+	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath)
 	return nil
 }
 
+// Providers returns the list of providers, taking into account cached results
+// and whether or not auto update is enabled.
+//
+// It will:
+// 1. if auto update is disabled, it'll return the embedded providers at the
+// time of release.
+// 2. load the cached providers
+// 3. try to get the fresh list of providers, and return either this new list,
+// the cached list, or the embedded list if all others fail.
 func Providers(cfg *Config) ([]catwalk.Provider, error) {
 	providerOnce.Do(func() {
 		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 		client := catwalk.NewWithURL(catwalkURL)
 		path := providerCacheFileData()
 
-		autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
-		providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
-	})
-	return providerList, providerErr
-}
-
-func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
-	catwalkGetAndSave := func() ([]catwalk.Provider, error) {
-		providers, err := client.GetProviders()
-		if err != nil {
-			return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
+		if cfg.Options.DisableProviderAutoUpdate {
+			slog.Info("Using embedded Catwalk providers")
+			providerList, providerErr = embedded.GetAll(), nil
+			return
 		}
-		if len(providers) == 0 {
-			return nil, fmt.Errorf("empty providers list from catwalk")
-		}
-		if err := saveProvidersInCache(path, providers); err != nil {
-			return nil, err
-		}
-		return providers, nil
-	}
-
-	switch {
-	case autoUpdateDisabled:
-		slog.Warn("Providers auto-update is disabled")
 
-		if _, err := os.Stat(path); err == nil {
-			slog.Warn("Using locally cached providers")
-			return loadProvidersFromCache(path)
+		cached, etag, cachedErr := loadProvidersFromCache(path)
+		if len(cached) == 0 || cachedErr != nil {
+			// if cached file is empty, default to embedded providers
+			cached = embedded.GetAll()
 		}
 
-		slog.Warn("Saving embedded providers to cache")
-		providers := embedded.GetAll()
-		if err := saveProvidersInCache(path, providers); err != nil {
-			return nil, err
+		providerList, providerErr = loadProviders(client, etag, path)
+		if errors.Is(providerErr, catwalk.ErrNotModified) {
+			slog.Info("Catwalk providers not modified")
+			providerList, providerErr = cached, nil
 		}
-		return providers, nil
-
-	default:
-		slog.Info("Fetching providers from Catwalk.", "path", path)
+	})
+	if providerErr != nil {
+		catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
+		return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck
+	}
+	return providerList, nil
+}
 
-		providers, err := catwalkGetAndSave()
-		if err != nil {
-			catwalkUrl := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
-			return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
-		}
-		return providers, nil
+func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) {
+	slog.Info("Fetching providers from Catwalk.", "path", path)
+	providers, err := client.GetProviders(context.Background(), etag)
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
+	}
+	if len(providers) == 0 {
+		return nil, errors.New("empty providers list from catwalk")
 	}
+	if err := saveProvidersInCache(path, providers); err != nil {
+		return nil, err
+	}
+	return providers, nil
 }

internal/config/provider_empty_test.go 🔗

@@ -1,8 +1,7 @@
 package config
 
 import (
-	"encoding/json"
-	"os"
+	"context"
 	"testing"
 
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -11,37 +10,22 @@ import (
 
 type emptyProviderClient struct{}
 
-func (m *emptyProviderClient) GetProviders() ([]catwalk.Provider, error) {
+func (m *emptyProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) {
 	return []catwalk.Provider{}, nil
 }
 
+// TestProvider_loadProvidersEmptyResult tests that loadProviders returns an
+// error when the client returns an empty list. This ensures we don't cache
+// empty provider lists.
 func TestProvider_loadProvidersEmptyResult(t *testing.T) {
 	client := &emptyProviderClient{}
 	tmpPath := t.TempDir() + "/providers.json"
 
-	providers, err := loadProviders(false, client, tmpPath)
-	require.Contains(t, err.Error(), "Crush was unable to fetch an updated list of providers")
+	providers, err := loadProviders(client, "", tmpPath)
+	require.Contains(t, err.Error(), "empty providers list from catwalk")
 	require.Empty(t, providers)
 	require.Len(t, providers, 0)
 
 	// Check that no cache file was created for empty results
 	require.NoFileExists(t, tmpPath, "Cache file should not exist for empty results")
 }
-
-func TestProvider_loadProvidersEmptyCache(t *testing.T) {
-	client := &mockProviderClient{shouldFail: false}
-	tmpPath := t.TempDir() + "/providers.json"
-
-	// Create an empty cache file
-	emptyProviders := []catwalk.Provider{}
-	data, err := json.Marshal(emptyProviders)
-	require.NoError(t, err)
-	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
-
-	// Should refresh and get real providers instead of using empty cache
-	providers, err := loadProviders(false, client, tmpPath)
-	require.NoError(t, err)
-	require.NotNil(t, providers)
-	require.Len(t, providers, 1)
-	require.Equal(t, "Mock", providers[0].Name)
-}

internal/config/provider_test.go 🔗

@@ -1,9 +1,11 @@
 package config
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"os"
+	"sync"
 	"testing"
 
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -11,10 +13,14 @@ import (
 )
 
 type mockProviderClient struct {
-	shouldFail bool
+	shouldFail      bool
+	shouldReturnErr error
 }
 
-func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
+func (m *mockProviderClient) GetProviders(context.Context, string) ([]catwalk.Provider, error) {
+	if m.shouldReturnErr != nil {
+		return nil, m.shouldReturnErr
+	}
 	if m.shouldFail {
 		return nil, errors.New("failed to load providers")
 	}
@@ -25,10 +31,16 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
 	}, nil
 }
 
+func resetProviderState() {
+	providerOnce = sync.Once{}
+	providerList = nil
+	providerErr = nil
+}
+
 func TestProvider_loadProvidersNoIssues(t *testing.T) {
 	client := &mockProviderClient{shouldFail: false}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(false, client, tmpPath)
+	providers, err := loadProviders(client, "", tmpPath)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)
@@ -39,35 +51,94 @@ func TestProvider_loadProvidersNoIssues(t *testing.T) {
 	require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
 }
 
-func TestProvider_loadProvidersWithIssues(t *testing.T) {
-	client := &mockProviderClient{shouldFail: true}
-	tmpPath := t.TempDir() + "/providers.json"
-	// store providers to a temporary file
-	oldProviders := []catwalk.Provider{
-		{
-			Name: "OldProvider",
+func TestProvider_DisableAutoUpdate(t *testing.T) {
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+
+	resetProviderState()
+	defer resetProviderState()
+
+	cfg := &Config{
+		Options: &Options{
+			DisableProviderAutoUpdate: true,
 		},
 	}
-	data, err := json.Marshal(oldProviders)
-	if err != nil {
-		t.Fatalf("Failed to marshal old providers: %v", err)
-	}
 
-	err = os.WriteFile(tmpPath, data, 0o644)
-	if err != nil {
-		t.Fatalf("Failed to write old providers to file: %v", err)
+	providers, err := Providers(cfg)
+	require.NoError(t, err)
+	require.NotNil(t, providers)
+	require.Greater(t, len(providers), 5, "Expected embedded providers")
+}
+
+func TestProvider_WithValidCache(t *testing.T) {
+	tmpDir := t.TempDir()
+	t.Setenv("XDG_DATA_HOME", tmpDir)
+
+	resetProviderState()
+	defer resetProviderState()
+
+	cachePath := tmpDir + "/crush/providers.json"
+	require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
+	cachedProviders := []catwalk.Provider{
+		{Name: "Cached"},
 	}
-	providers, err := loadProviders(true, client, tmpPath)
+	data, err := json.Marshal(cachedProviders)
+	require.NoError(t, err)
+	require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+
+	mockClient := &mockProviderClient{shouldFail: false}
+
+	providers, err := loadProviders(mockClient, "", cachePath)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)
-	require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
+	require.Equal(t, "Mock", providers[0].Name, "Expected fresh provider from fetch")
+}
+
+func TestProvider_NotModifiedUsesCached(t *testing.T) {
+	tmpDir := t.TempDir()
+	t.Setenv("XDG_DATA_HOME", tmpDir)
+
+	resetProviderState()
+	defer resetProviderState()
+
+	cachePath := tmpDir + "/crush/providers.json"
+	require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
+	cachedProviders := []catwalk.Provider{
+		{Name: "Cached"},
+	}
+	data, err := json.Marshal(cachedProviders)
+	require.NoError(t, err)
+	require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+
+	mockClient := &mockProviderClient{shouldReturnErr: catwalk.ErrNotModified}
+	providers, err := loadProviders(mockClient, "", cachePath)
+	require.ErrorIs(t, err, catwalk.ErrNotModified)
+	require.Nil(t, providers)
+}
+
+func TestProvider_EmptyCacheDefaultsToEmbedded(t *testing.T) {
+	tmpDir := t.TempDir()
+	t.Setenv("XDG_DATA_HOME", tmpDir)
+
+	resetProviderState()
+	defer resetProviderState()
+
+	cachePath := tmpDir + "/crush/providers.json"
+	require.NoError(t, os.MkdirAll(tmpDir+"/crush", 0o755))
+	emptyProviders := []catwalk.Provider{}
+	data, err := json.Marshal(emptyProviders)
+	require.NoError(t, err)
+	require.NoError(t, os.WriteFile(cachePath, data, 0o644))
+
+	cached, _, err := loadProvidersFromCache(cachePath)
+	require.NoError(t, err)
+	require.Empty(t, cached, "Expected empty cache")
 }
 
 func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
 	client := &mockProviderClient{shouldFail: true}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(false, client, tmpPath)
+	providers, err := loadProviders(client, "", tmpPath)
 	require.Error(t, err)
 	require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
 }