provider.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"runtime"
 13	"strings"
 14	"sync"
 15
 16	"github.com/charmbracelet/catwalk/pkg/catwalk"
 17	"github.com/charmbracelet/catwalk/pkg/embedded"
 18	"github.com/charmbracelet/crush/internal/home"
 19)
 20
 21type ProviderClient interface {
 22	GetProviders(context.Context, string) ([]catwalk.Provider, error)
 23}
 24
 25var (
 26	providerOnce sync.Once
 27	providerList []catwalk.Provider
 28	providerErr  error
 29)
 30
 31// file to cache provider data
 32func providerCacheFileData() string {
 33	xdgDataHome := os.Getenv("XDG_DATA_HOME")
 34	if xdgDataHome != "" {
 35		return filepath.Join(xdgDataHome, appName, "providers.json")
 36	}
 37
 38	// return the path to the main data directory
 39	// for windows, it should be in `%LOCALAPPDATA%/crush/`
 40	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
 41	if runtime.GOOS == "windows" {
 42		localAppData := os.Getenv("LOCALAPPDATA")
 43		if localAppData == "" {
 44			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
 45		}
 46		return filepath.Join(localAppData, appName, "providers.json")
 47	}
 48
 49	return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
 50}
 51
 52func saveProvidersInCache(path string, providers []catwalk.Provider) error {
 53	slog.Info("Saving provider data to disk", "path", path)
 54	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 55		return fmt.Errorf("failed to create directory for provider cache: %w", err)
 56	}
 57
 58	data, err := json.Marshal(providers)
 59	if err != nil {
 60		return fmt.Errorf("failed to marshal provider data: %w", err)
 61	}
 62
 63	if err := os.WriteFile(path, data, 0o644); err != nil {
 64		return fmt.Errorf("failed to write provider data to cache: %w", err)
 65	}
 66	return nil
 67}
 68
 69func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) {
 70	data, err := os.ReadFile(path)
 71	if err != nil {
 72		return nil, "", fmt.Errorf("failed to read provider cache file: %w", err)
 73	}
 74
 75	var providers []catwalk.Provider
 76	if err := json.Unmarshal(data, &providers); err != nil {
 77		return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
 78	}
 79
 80	return providers, catwalk.Etag(data), nil
 81}
 82
 83func UpdateProviders(pathOrURL string) error {
 84	var providers []catwalk.Provider
 85	pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 86
 87	switch {
 88	case pathOrURL == "embedded":
 89		providers = embedded.GetAll()
 90	case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
 91		var err error
 92		providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
 93		if err != nil {
 94			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
 95		}
 96	default:
 97		content, err := os.ReadFile(pathOrURL)
 98		if err != nil {
 99			return fmt.Errorf("failed to read file: %w", err)
100		}
101		if err := json.Unmarshal(content, &providers); err != nil {
102			return fmt.Errorf("failed to unmarshal provider data: %w", err)
103		}
104		if len(providers) == 0 {
105			return fmt.Errorf("no providers found in the provided source")
106		}
107	}
108
109	cachePath := providerCacheFileData()
110	if err := saveProvidersInCache(cachePath, providers); err != nil {
111		return fmt.Errorf("failed to save providers to cache: %w", err)
112	}
113
114	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath)
115	return nil
116}
117
118// Providers returns the list of providers, taking into account cached results
119// and whether or not auto update is enabled.
120//
121// It will:
122// 1. if auto update is disabled, it'll return the embedded providers at the
123// time of release.
124// 2. load the cached providers
125// 3. try to get the fresh list of providers, and return either this new list,
126// the cached list, or the embedded list if all others fail.
127func Providers(cfg *Config) ([]catwalk.Provider, error) {
128	providerOnce.Do(func() {
129		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
130		client := catwalk.NewWithURL(catwalkURL)
131		path := providerCacheFileData()
132
133		if cfg.Options.DisableProviderAutoUpdate {
134			slog.Info("Using embedded Catwalk providers")
135			providerList, providerErr = embedded.GetAll(), nil
136			return
137		}
138
139		cached, etag, cachedErr := loadProvidersFromCache(path)
140		if len(cached) == 0 || cachedErr != nil {
141			// if cached file is empty, default to embedded providers
142			cached = embedded.GetAll()
143		}
144
145		providerList, providerErr = loadProviders(client, etag, path)
146		if errors.Is(providerErr, catwalk.ErrNotModified) {
147			slog.Info("Catwalk providers not modified")
148			providerList, providerErr = cached, nil
149		}
150	})
151	if providerErr != nil {
152		catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
153		return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck
154	}
155	return providerList, nil
156}
157
158func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) {
159	slog.Info("Fetching providers from Catwalk.", "path", path)
160	providers, err := client.GetProviders(context.Background(), etag)
161	if err != nil {
162		return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
163	}
164	if len(providers) == 0 {
165		return nil, errors.New("empty providers list from catwalk")
166	}
167	if err := saveProvidersInCache(path, providers); err != nil {
168		return nil, err
169	}
170	return providers, nil
171}