provider.go

  1package config
  2
  3import (
  4	"cmp"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/catwalk/pkg/embedded"
 16	"github.com/charmbracelet/crush/internal/home"
 17	"github.com/charmbracelet/crush/internal/version"
 18)
 19
 20type ProviderClient interface {
 21	GetProviders() ([]catwalk.Provider, error)
 22}
 23
 24// file to cache provider data
 25func providerCacheFileData() string {
 26	xdgDataHome := os.Getenv("XDG_DATA_HOME")
 27	if xdgDataHome != "" {
 28		return filepath.Join(xdgDataHome, version.AppName, "providers.json")
 29	}
 30
 31	// return the path to the main data directory
 32	// for windows, it should be in `%LOCALAPPDATA%/crush/`
 33	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
 34	if runtime.GOOS == "windows" {
 35		localAppData := os.Getenv("LOCALAPPDATA")
 36		if localAppData == "" {
 37			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
 38		}
 39		return filepath.Join(localAppData, version.AppName, "providers.json")
 40	}
 41
 42	return filepath.Join(home.Dir(), ".local", "share", version.AppName, "providers.json")
 43}
 44
 45func saveProvidersInCache(path string, providers []catwalk.Provider) error {
 46	slog.Info("Saving cached provider data", "path", path)
 47	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
 48		return fmt.Errorf("failed to create directory for provider cache: %w", err)
 49	}
 50
 51	data, err := json.MarshalIndent(providers, "", "  ")
 52	if err != nil {
 53		return fmt.Errorf("failed to marshal provider data: %w", err)
 54	}
 55
 56	if err := os.WriteFile(path, data, 0o644); err != nil {
 57		return fmt.Errorf("failed to write provider data to cache: %w", err)
 58	}
 59	return nil
 60}
 61
 62func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
 63	data, err := os.ReadFile(path)
 64	if err != nil {
 65		return nil, fmt.Errorf("failed to read provider cache file: %w", err)
 66	}
 67
 68	var providers []catwalk.Provider
 69	if err := json.Unmarshal(data, &providers); err != nil {
 70		return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
 71	}
 72	return providers, nil
 73}
 74
 75func UpdateProviders(pathOrUrl string) error {
 76	var providers []catwalk.Provider
 77	pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 78
 79	switch {
 80	case pathOrUrl == "embedded":
 81		providers = embedded.GetAll()
 82	case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
 83		var err error
 84		providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
 85		if err != nil {
 86			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
 87		}
 88	default:
 89		content, err := os.ReadFile(pathOrUrl)
 90		if err != nil {
 91			return fmt.Errorf("failed to read file: %w", err)
 92		}
 93		if err := json.Unmarshal(content, &providers); err != nil {
 94			return fmt.Errorf("failed to unmarshal provider data: %w", err)
 95		}
 96		if len(providers) == 0 {
 97			return fmt.Errorf("no providers found in the provided source")
 98		}
 99	}
100
101	cachePath := providerCacheFileData()
102	if err := saveProvidersInCache(cachePath, providers); err != nil {
103		return fmt.Errorf("failed to save providers to cache: %w", err)
104	}
105
106	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
107	return nil
108}
109
110func Providers(cfg *Config) ([]catwalk.Provider, error) {
111	catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
112	client := catwalk.NewWithURL(catwalkURL)
113	path := providerCacheFileData()
114	return loadProviders(cfg.Options.DisableProviderAutoUpdate, client, path)
115}
116
117func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
118	cacheIsStale, cacheExists := isCacheStale(path)
119
120	catwalkGetAndSave := func() ([]catwalk.Provider, error) {
121		providers, err := client.GetProviders()
122		if err != nil {
123			return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
124		}
125		if len(providers) == 0 {
126			return nil, fmt.Errorf("empty providers list from catwalk")
127		}
128		if err := saveProvidersInCache(path, providers); err != nil {
129			return nil, err
130		}
131		return providers, nil
132	}
133
134	backgroundCacheUpdate := func() {
135		go func() {
136			slog.Info("Updating providers cache in background", "path", path)
137
138			providers, err := client.GetProviders()
139			if err != nil {
140				slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
141				return
142			}
143			if len(providers) == 0 {
144				slog.Error("Empty providers list from Catwalk")
145				return
146			}
147			if err := saveProvidersInCache(path, providers); err != nil {
148				slog.Error("Failed to update providers.json in background", "error", err)
149			}
150		}()
151	}
152
153	switch {
154	case autoUpdateDisabled:
155		slog.Warn("Providers auto-update is disabled")
156
157		if cacheExists {
158			slog.Warn("Using locally cached providers")
159			return loadProvidersFromCache(path)
160		}
161
162		slog.Warn("Saving embedded providers to cache")
163		providers := embedded.GetAll()
164		if err := saveProvidersInCache(path, providers); err != nil {
165			return nil, err
166		}
167		return providers, nil
168
169	case cacheExists && !cacheIsStale:
170		slog.Info("Recent providers cache is available.", "path", path)
171
172		providers, err := loadProvidersFromCache(path)
173		if err != nil {
174			return nil, err
175		}
176		if len(providers) == 0 {
177			return catwalkGetAndSave()
178		}
179		backgroundCacheUpdate()
180		return providers, nil
181
182	default:
183		slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
184
185		providers, err := catwalkGetAndSave()
186		if err != nil {
187			catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
188			return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
189		}
190		return providers, nil
191	}
192}
193
// isCacheStale stats path and reports two things. stale is true when the
// file's modification time is more than 24 hours ago. exists is true when
// the stat succeeded. A path that cannot be stat'ed (missing, permission
// error, ...) is reported as stale=true, exists=false.
func isCacheStale(path string) (stale, exists bool) {
	const maxCacheAge = 24 * time.Hour

	info, statErr := os.Stat(path)
	if statErr == nil {
		age := time.Since(info.ModTime())
		return age > maxCacheAge, true
	}
	return true, false
}