provider.go

  1package config
  2
  3import (
  4	"cmp"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/catwalk/pkg/embedded"
 17	"github.com/charmbracelet/crush/internal/home"
 18)
 19
 20type ProviderClient interface {
 21	GetProviders() ([]catwalk.Provider, error)
 22}
 23
 24var (
 25	providerOnce sync.Once
 26	providerList []catwalk.Provider
 27	providerErr  error
 28)
 29
 30// file to cache provider data
 31func providerCacheFileData() string {
 32	xdgDataHome := os.Getenv("XDG_DATA_HOME")
 33	if xdgDataHome != "" {
 34		return filepath.Join(xdgDataHome, appName, "providers.json")
 35	}
 36
 37	// return the path to the main data directory
 38	// for windows, it should be in `%LOCALAPPDATA%/crush/`
 39	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
 40	if runtime.GOOS == "windows" {
 41		localAppData := os.Getenv("LOCALAPPDATA")
 42		if localAppData == "" {
 43			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
 44		}
 45		return filepath.Join(localAppData, appName, "providers.json")
 46	}
 47
 48	return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
 49}
 50
 51func saveProvidersInCache(path string, providers []catwalk.Provider) error {
 52	slog.Info("Saving cached provider data", "path", path)
 53	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 54		return fmt.Errorf("failed to create directory for provider cache: %w", err)
 55	}
 56
 57	data, err := json.MarshalIndent(providers, "", "  ")
 58	if err != nil {
 59		return fmt.Errorf("failed to marshal provider data: %w", err)
 60	}
 61
 62	if err := os.WriteFile(path, data, 0o644); err != nil {
 63		return fmt.Errorf("failed to write provider data to cache: %w", err)
 64	}
 65	return nil
 66}
 67
 68func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
 69	data, err := os.ReadFile(path)
 70	if err != nil {
 71		return nil, fmt.Errorf("failed to read provider cache file: %w", err)
 72	}
 73
 74	var providers []catwalk.Provider
 75	if err := json.Unmarshal(data, &providers); err != nil {
 76		return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
 77	}
 78	return providers, nil
 79}
 80
 81// NOTE(tauraamui) <REF#1>: there seems to be duplication of logic for updating
 82// the providers cache/from catwalk, in that this method is only invoked/used by the
 83// catwalk CLI command, when it should probably be the exact same behaviour that
 84// crush carries out internally as well when it does the sync internally/automatically
 85// see (REF#2)
 86func UpdateProviders(pathOrUrl string) error {
 87	var providers []catwalk.Provider
 88	pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 89
 90	switch {
 91	case pathOrUrl == "embedded":
 92		providers = embedded.GetAll()
 93	case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
 94		var err error
 95		providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
 96		if err != nil {
 97			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
 98		}
 99	default:
100		content, err := os.ReadFile(pathOrUrl)
101		if err != nil {
102			return fmt.Errorf("failed to read file: %w", err)
103		}
104		if err := json.Unmarshal(content, &providers); err != nil {
105			return fmt.Errorf("failed to unmarshal provider data: %w", err)
106		}
107		if len(providers) == 0 {
108			return fmt.Errorf("no providers found in the provided source")
109		}
110	}
111
112	cachePath := providerCacheFileData()
113	if err := saveProvidersInCache(cachePath, providers); err != nil {
114		return fmt.Errorf("failed to save providers to cache: %w", err)
115	}
116
117	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
118	return nil
119}
120
121// NOTE(tauraamui) <REF#2>: see note (REF#1), basically this looks like some logic
122// should be shared/consolidated.
123func Providers(cfg *Config) ([]catwalk.Provider, error) {
124	providerOnce.Do(func() {
125		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
126		client := catwalk.NewWithURL(catwalkURL)
127		path := providerCacheFileData()
128
129		autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
130		providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
131	})
132	return providerList, providerErr
133}
134
135func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
136	cacheIsStale, cacheExists := isCacheStale(path)
137
138	catwalkGetAndSave := func() ([]catwalk.Provider, error) {
139		providers, err := client.GetProviders()
140		if err != nil {
141			return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
142		}
143		if len(providers) == 0 {
144			return nil, fmt.Errorf("empty providers list from catwalk")
145		}
146		if err := saveProvidersInCache(path, providers); err != nil {
147			return nil, err
148		}
149		return providers, nil
150	}
151
152	backgroundCacheUpdate := func() {
153		go func() {
154			slog.Info("Updating providers cache in background", "path", path)
155
156			providers, err := client.GetProviders()
157			if err != nil {
158				slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
159				return
160			}
161			if len(providers) == 0 {
162				slog.Error("Empty providers list from Catwalk")
163				return
164			}
165			if err := saveProvidersInCache(path, providers); err != nil {
166				slog.Error("Failed to update providers.json in background", "error", err)
167			}
168		}()
169	}
170
171	switch {
172	case autoUpdateDisabled:
173		slog.Warn("Providers auto-update is disabled")
174
175		if cacheExists {
176			slog.Warn("Using locally cached providers")
177			return loadProvidersFromCache(path)
178		}
179
180		slog.Warn("Saving embedded providers to cache")
181		providers := embedded.GetAll()
182		if err := saveProvidersInCache(path, providers); err != nil {
183			return nil, err
184		}
185		return providers, nil
186
187	case cacheExists && !cacheIsStale:
188		slog.Info("Recent providers cache is available.", "path", path)
189
190		providers, err := loadProvidersFromCache(path)
191		if err != nil {
192			return nil, err
193		}
194		if len(providers) == 0 {
195			return catwalkGetAndSave()
196		}
197		backgroundCacheUpdate()
198		return providers, nil
199
200	default:
201		slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
202
203		providers, err := catwalkGetAndSave()
204		if err != nil {
205			catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
206			return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
207		}
208		return providers, nil
209	}
210}
211
// isCacheStale inspects the cache file at path. It reports exists=true when
// os.Stat succeeds, and stale=true when the file was last modified more than
// 24 hours ago. A path that cannot be stat'ed (missing, permission denied,
// etc.) is reported as stale and not existing.
func isCacheStale(path string) (stale, exists bool) {
	const maxAge = 24 * time.Hour

	info, statErr := os.Stat(path)
	if statErr == nil {
		exists = true
		stale = time.Since(info.ModTime()) > maxAge
		return stale, exists
	}
	return true, false
}