provider.go

  1package config
  2
  3import (
  4	"cmp"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"sync"
 12	"time"
 13
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/catwalk/pkg/embedded"
 16	"github.com/charmbracelet/crush/internal/home"
 17)
 18
 19type ProviderClient interface {
 20	GetProviders() ([]catwalk.Provider, error)
 21}
 22
 23var (
 24	providerOnce sync.Once
 25	providerList []catwalk.Provider
 26	providerErr  error
 27)
 28
 29// file to cache provider data
 30func providerCacheFileData() string {
 31	xdgDataHome := os.Getenv("XDG_DATA_HOME")
 32	if xdgDataHome != "" {
 33		return filepath.Join(xdgDataHome, appName, "providers.json")
 34	}
 35
 36	// return the path to the main data directory
 37	// for windows, it should be in `%LOCALAPPDATA%/crush/`
 38	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
 39	if runtime.GOOS == "windows" {
 40		localAppData := os.Getenv("LOCALAPPDATA")
 41		if localAppData == "" {
 42			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
 43		}
 44		return filepath.Join(localAppData, appName, "providers.json")
 45	}
 46
 47	return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
 48}
 49
 50func saveProvidersInCache(path string, providers []catwalk.Provider) error {
 51	slog.Info("Saving cached provider data", "path", path)
 52	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 53		return fmt.Errorf("failed to create directory for provider cache: %w", err)
 54	}
 55
 56	data, err := json.MarshalIndent(providers, "", "  ")
 57	if err != nil {
 58		return fmt.Errorf("failed to marshal provider data: %w", err)
 59	}
 60
 61	if err := os.WriteFile(path, data, 0o644); err != nil {
 62		return fmt.Errorf("failed to write provider data to cache: %w", err)
 63	}
 64	return nil
 65}
 66
 67func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
 68	data, err := os.ReadFile(path)
 69	if err != nil {
 70		return nil, fmt.Errorf("failed to read provider cache file: %w", err)
 71	}
 72
 73	var providers []catwalk.Provider
 74	if err := json.Unmarshal(data, &providers); err != nil {
 75		return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
 76	}
 77	return providers, nil
 78}
 79
 80func Providers(cfg *Config) ([]catwalk.Provider, error) {
 81	providerOnce.Do(func() {
 82		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 83		client := catwalk.NewWithURL(catwalkURL)
 84		path := providerCacheFileData()
 85
 86		autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
 87		providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
 88	})
 89	return providerList, providerErr
 90}
 91
 92func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
 93	cacheIsStale, cacheExists := isCacheStale(path)
 94
 95	catwalkGetAndSave := func() ([]catwalk.Provider, error) {
 96		providers, err := client.GetProviders()
 97		if err != nil {
 98			return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
 99		}
100		if len(providers) == 0 {
101			return nil, fmt.Errorf("empty providers list from catwalk")
102		}
103		if err := saveProvidersInCache(path, providers); err != nil {
104			return nil, err
105		}
106		return providers, nil
107	}
108
109	backgroundCacheUpdate := func() {
110		go func() {
111			slog.Info("Updating providers cache in background", "path", path)
112
113			providers, err := client.GetProviders()
114			if err != nil {
115				slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
116				return
117			}
118			if len(providers) == 0 {
119				slog.Error("Empty providers list from Catwalk")
120				return
121			}
122			if err := saveProvidersInCache(path, providers); err != nil {
123				slog.Error("Failed to update providers.json in background", "error", err)
124			}
125		}()
126	}
127
128	switch {
129	case autoUpdateDisabled:
130		slog.Warn("Providers auto-update is disabled")
131
132		if cacheExists {
133			slog.Warn("Using locally cached providers")
134			return loadProvidersFromCache(path)
135		}
136
137		slog.Warn("Saving embedded providers to cache")
138		providers := embedded.GetAll()
139		if err := saveProvidersInCache(path, providers); err != nil {
140			return nil, err
141		}
142		return providers, nil
143
144	case cacheExists && !cacheIsStale:
145		slog.Info("Recent providers cache is available.", "path", path)
146
147		providers, err := loadProvidersFromCache(path)
148		if err != nil {
149			return nil, err
150		}
151		if len(providers) == 0 {
152			return catwalkGetAndSave()
153		}
154		backgroundCacheUpdate()
155		return providers, nil
156
157	default:
158		slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
159
160		providers, err := catwalkGetAndSave()
161		if err != nil {
162			catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
163			return nil, fmt.Errorf("crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use embedded version from the time of this Crush release. %w", catwalkUrl, err)
164		}
165		return providers, nil
166	}
167}
168
// isCacheStale stats the cache file at path and reports two things:
//   - exists: whether os.Stat succeeded. Any Stat error, not only "file does
//     not exist", counts as a missing cache.
//   - stale: true when the file is missing, or when its modification time is
//     more than 24 hours in the past.
func isCacheStale(path string) (stale, exists bool) {
	const maxCacheAge = 24 * time.Hour

	info, statErr := os.Stat(path)
	exists = statErr == nil
	if !exists {
		return true, exists
	}
	return time.Since(info.ModTime()) > maxCacheAge, exists
}