provider.go

  1package config
  2
  3import (
  4	"cmp"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"strings"
 12	"sync"
 13
 14	"git.secluded.site/crush/internal/home"
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/catwalk/pkg/embedded"
 17)
 18
// ProviderClient is the source loadProviders fetches the provider list from.
// Providers passes the client built by catwalk.NewWithURL; accepting an
// interface here lets callers of loadProviders supply another implementation.
type ProviderClient interface {
	// GetProviders returns the current provider list or an error.
	GetProviders() ([]catwalk.Provider, error)
}
 22
// Process-wide memoization for Providers: providerOnce guards a single
// resolution of the provider list, and every later call to Providers returns
// the stored providerList/providerErr pair unchanged — including a stored
// error, which is therefore not retried within the same process.
var (
	providerOnce sync.Once
	providerList []catwalk.Provider
	providerErr  error
)
 28
 29// file to cache provider data
 30func providerCacheFileData() string {
 31	xdgDataHome := os.Getenv("XDG_DATA_HOME")
 32	if xdgDataHome != "" {
 33		return filepath.Join(xdgDataHome, appName, "providers.json")
 34	}
 35
 36	// return the path to the main data directory
 37	// for windows, it should be in `%LOCALAPPDATA%/crush/`
 38	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
 39	if runtime.GOOS == "windows" {
 40		localAppData := os.Getenv("LOCALAPPDATA")
 41		if localAppData == "" {
 42			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
 43		}
 44		return filepath.Join(localAppData, appName, "providers.json")
 45	}
 46
 47	return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
 48}
 49
 50func saveProvidersInCache(path string, providers []catwalk.Provider) error {
 51	slog.Info("Saving provider data to disk", "path", path)
 52	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 53		return fmt.Errorf("failed to create directory for provider cache: %w", err)
 54	}
 55
 56	data, err := json.MarshalIndent(providers, "", "  ")
 57	if err != nil {
 58		return fmt.Errorf("failed to marshal provider data: %w", err)
 59	}
 60
 61	if err := os.WriteFile(path, data, 0o644); err != nil {
 62		return fmt.Errorf("failed to write provider data to cache: %w", err)
 63	}
 64	return nil
 65}
 66
 67func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
 68	data, err := os.ReadFile(path)
 69	if err != nil {
 70		return nil, fmt.Errorf("failed to read provider cache file: %w", err)
 71	}
 72
 73	var providers []catwalk.Provider
 74	if err := json.Unmarshal(data, &providers); err != nil {
 75		return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
 76	}
 77	return providers, nil
 78}
 79
 80func UpdateProviders(pathOrUrl string) error {
 81	var providers []catwalk.Provider
 82	pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 83
 84	switch {
 85	case pathOrUrl == "embedded":
 86		providers = embedded.GetAll()
 87	case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
 88		var err error
 89		providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
 90		if err != nil {
 91			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
 92		}
 93	default:
 94		content, err := os.ReadFile(pathOrUrl)
 95		if err != nil {
 96			return fmt.Errorf("failed to read file: %w", err)
 97		}
 98		if err := json.Unmarshal(content, &providers); err != nil {
 99			return fmt.Errorf("failed to unmarshal provider data: %w", err)
100		}
101		if len(providers) == 0 {
102			return fmt.Errorf("no providers found in the provided source")
103		}
104	}
105
106	cachePath := providerCacheFileData()
107	if err := saveProvidersInCache(cachePath, providers); err != nil {
108		return fmt.Errorf("failed to save providers to cache: %w", err)
109	}
110
111	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
112	return nil
113}
114
115func Providers(cfg *Config) ([]catwalk.Provider, error) {
116	providerOnce.Do(func() {
117		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
118		client := catwalk.NewWithURL(catwalkURL)
119		path := providerCacheFileData()
120
121		autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
122		providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
123	})
124	return providerList, providerErr
125}
126
127func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
128	catwalkGetAndSave := func() ([]catwalk.Provider, error) {
129		providers, err := client.GetProviders()
130		if err != nil {
131			return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
132		}
133		if len(providers) == 0 {
134			return nil, fmt.Errorf("empty providers list from catwalk")
135		}
136		if err := saveProvidersInCache(path, providers); err != nil {
137			return nil, err
138		}
139		return providers, nil
140	}
141
142	switch {
143	case autoUpdateDisabled:
144		slog.Warn("Providers auto-update is disabled")
145
146		if _, err := os.Stat(path); err == nil {
147			slog.Warn("Using locally cached providers")
148			return loadProvidersFromCache(path)
149		}
150
151		slog.Warn("Saving embedded providers to cache")
152		providers := embedded.GetAll()
153		if err := saveProvidersInCache(path, providers); err != nil {
154			return nil, err
155		}
156		return providers, nil
157
158	default:
159		slog.Info("Fetching providers from Catwalk.", "path", path)
160
161		providers, err := catwalkGetAndSave()
162		if err != nil {
163			catwalkUrl := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
164			return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
165		}
166		return providers, nil
167	}
168}