provider.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"runtime"
 13	"slices"
 14	"strings"
 15	"sync"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19	"charm.land/catwalk/pkg/embedded"
 20	"github.com/charmbracelet/crush/internal/agent/hyper"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/home"
 23	"github.com/charmbracelet/x/etag"
 24)
 25
 26type syncer[T any] interface {
 27	Get(context.Context) (T, error)
 28}
 29
 30var (
 31	providerOnce sync.Once
 32	providerList []catwalk.Provider
 33	providerErr  error
 34)
 35
 36// file to cache provider data
 37func cachePathFor(name string) string {
 38	xdgDataHome := os.Getenv("XDG_DATA_HOME")
 39	if xdgDataHome != "" {
 40		return filepath.Join(xdgDataHome, appName, name+".json")
 41	}
 42
 43	// return the path to the main data directory
 44	// for windows, it should be in `%LOCALAPPDATA%/crush/`
 45	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
 46	if runtime.GOOS == "windows" {
 47		localAppData := os.Getenv("LOCALAPPDATA")
 48		if localAppData == "" {
 49			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
 50		}
 51		return filepath.Join(localAppData, appName, name+".json")
 52	}
 53
 54	return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
 55}
 56
 57// UpdateProviders updates the Catwalk providers list from a specified source.
 58func UpdateProviders(pathOrURL string) error {
 59	var providers []catwalk.Provider
 60	pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 61
 62	switch {
 63	case pathOrURL == "embedded":
 64		providers = embedded.GetAll()
 65	case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
 66		var err error
 67		providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
 68		if err != nil {
 69			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
 70		}
 71	default:
 72		content, err := os.ReadFile(pathOrURL)
 73		if err != nil {
 74			return fmt.Errorf("failed to read file: %w", err)
 75		}
 76		if err := json.Unmarshal(content, &providers); err != nil {
 77			return fmt.Errorf("failed to unmarshal provider data: %w", err)
 78		}
 79		if len(providers) == 0 {
 80			return fmt.Errorf("no providers found in the provided source")
 81		}
 82	}
 83
 84	if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
 85		return fmt.Errorf("failed to save providers to cache: %w", err)
 86	}
 87
 88	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
 89	return nil
 90}
 91
 92// UpdateHyper updates the Hyper provider information from a specified URL.
 93func UpdateHyper(pathOrURL string) error {
 94	var provider catwalk.Provider
 95	pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
 96
 97	switch {
 98	case pathOrURL == "embedded":
 99		provider = hyper.Embedded()
100	case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
101		client := realHyperClient{baseURL: pathOrURL}
102		var err error
103		provider, err = client.Get(context.Background(), "")
104		if err != nil {
105			return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
106		}
107	default:
108		content, err := os.ReadFile(pathOrURL)
109		if err != nil {
110			return fmt.Errorf("failed to read file: %w", err)
111		}
112		if err := json.Unmarshal(content, &provider); err != nil {
113			return fmt.Errorf("failed to unmarshal provider data: %w", err)
114		}
115	}
116
117	if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
118		return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
119	}
120
121	slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
122	return nil
123}
124
125var (
126	catwalkSyncer = &catwalkSync{}
127	hyperSyncer   = &hyperSync{}
128)
129
130// Providers returns the list of providers, taking into account cached results
131// and whether or not auto update is enabled.
132//
133// It will:
134// 1. if auto update is disabled, it'll return the embedded providers at the
135// time of release.
136// 2. load the cached providers
137// 3. try to get the fresh list of providers, and return either this new list,
138// the cached list, or the embedded list if all others fail.
139func Providers(cfg *Config) ([]catwalk.Provider, error) {
140	providerOnce.Do(func() {
141		var wg sync.WaitGroup
142		var errs []error
143		providers := csync.NewSlice[catwalk.Provider]()
144		autoupdate := !cfg.Options.DisableProviderAutoUpdate
145		customProvidersOnly := cfg.Options.DisableDefaultProviders
146
147		ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
148		defer cancel()
149
150		wg.Go(func() {
151			if customProvidersOnly {
152				return
153			}
154			catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
155			client := catwalk.NewWithURL(catwalkURL)
156			path := cachePathFor("providers")
157			catwalkSyncer.Init(client, path, autoupdate)
158
159			items, err := catwalkSyncer.Get(ctx)
160			if err != nil {
161				catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
162				errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, err)) //nolint:staticcheck
163				return
164			}
165			providers.Append(items...)
166		})
167
168		wg.Go(func() {
169			if customProvidersOnly {
170				return
171			}
172			path := cachePathFor("hyper")
173			hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
174
175			item, err := hyperSyncer.Get(ctx)
176			if err != nil {
177				errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
178				return
179			}
180			providers.Append(item)
181		})
182
183		wg.Wait()
184
185		providerList = slices.Collect(providers.Seq())
186		providerErr = errors.Join(errs...)
187	})
188	return providerList, providerErr
189}
190
// cache persists a single value of type T as JSON in one file on disk.
type cache[T any] struct {
	path string // location of the backing JSON file
}

// newCache returns a cache backed by the file at path. Constructing it does
// not touch the filesystem.
func newCache[T any](path string) cache[T] {
	var c cache[T]
	c.path = path
	return c
}
198
199func (c cache[T]) Get() (T, string, error) {
200	var v T
201	data, err := os.ReadFile(c.path)
202	if err != nil {
203		return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
204	}
205
206	if err := json.Unmarshal(data, &v); err != nil {
207		return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
208	}
209
210	return v, etag.Of(data), nil
211}
212
213func (c cache[T]) Store(v T) error {
214	slog.Info("Saving provider data to disk", "path", c.path)
215	if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
216		return fmt.Errorf("failed to create directory for provider cache: %w", err)
217	}
218
219	data, err := json.Marshal(v)
220	if err != nil {
221		return fmt.Errorf("failed to marshal provider data: %w", err)
222	}
223
224	if err := os.WriteFile(c.path, data, 0o644); err != nil {
225		return fmt.Errorf("failed to write provider data to cache: %w", err)
226	}
227	return nil
228}