provider.go

  1package config
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"runtime"
 13	"slices"
 14	"strings"
 15	"sync"
 16	"time"
 17
 18	"github.com/charmbracelet/catwalk/pkg/catwalk"
 19	"github.com/charmbracelet/catwalk/pkg/embedded"
 20	"github.com/charmbracelet/crush/internal/agent/hyper"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/home"
 23	"github.com/charmbracelet/x/etag"
 24)
 25
 26type syncer[T any] interface {
 27	Get(context.Context) (T, error)
 28}
 29
 30var (
 31	providerOnce sync.Once
 32	providerList []catwalk.Provider
 33	providerErr  error
 34)
 35
 36// file to cache provider data
 37func cachePathFor(name string) string {
 38	xdgDataHome := os.Getenv("XDG_DATA_HOME")
 39	if xdgDataHome != "" {
 40		return filepath.Join(xdgDataHome, appName, name+".json")
 41	}
 42
 43	// return the path to the main data directory
 44	// for windows, it should be in `%LOCALAPPDATA%/crush/`
 45	// for linux and macOS, it should be in `$HOME/.local/share/crush/`
 46	if runtime.GOOS == "windows" {
 47		localAppData := os.Getenv("LOCALAPPDATA")
 48		if localAppData == "" {
 49			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
 50		}
 51		return filepath.Join(localAppData, appName, name+".json")
 52	}
 53
 54	return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
 55}
 56
 57// UpdateProviders updates the Catwalk providers list from a specified source.
 58func UpdateProviders(pathOrURL string) error {
 59	var providers []catwalk.Provider
 60	pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
 61
 62	switch {
 63	case pathOrURL == "embedded":
 64		providers = embedded.GetAll()
 65	case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
 66		var err error
 67		providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
 68		if err != nil {
 69			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
 70		}
 71	default:
 72		content, err := os.ReadFile(pathOrURL)
 73		if err != nil {
 74			return fmt.Errorf("failed to read file: %w", err)
 75		}
 76		if err := json.Unmarshal(content, &providers); err != nil {
 77			return fmt.Errorf("failed to unmarshal provider data: %w", err)
 78		}
 79		if len(providers) == 0 {
 80			return fmt.Errorf("no providers found in the provided source")
 81		}
 82	}
 83
 84	if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
 85		return fmt.Errorf("failed to save providers to cache: %w", err)
 86	}
 87
 88	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
 89	return nil
 90}
 91
 92// UpdateHyper updates the Hyper provider information from a specified URL.
 93func UpdateHyper(pathOrURL string) error {
 94	if !hyper.Enabled() {
 95		return fmt.Errorf("hyper not enabled")
 96	}
 97	var provider catwalk.Provider
 98	pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
 99
100	switch {
101	case pathOrURL == "embedded":
102		provider = hyper.Embedded()
103	case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
104		client := realHyperClient{baseURL: pathOrURL}
105		var err error
106		provider, err = client.Get(context.Background(), "")
107		if err != nil {
108			return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
109		}
110	default:
111		content, err := os.ReadFile(pathOrURL)
112		if err != nil {
113			return fmt.Errorf("failed to read file: %w", err)
114		}
115		if err := json.Unmarshal(content, &provider); err != nil {
116			return fmt.Errorf("failed to unmarshal provider data: %w", err)
117		}
118	}
119
120	if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
121		return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
122	}
123
124	slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
125	return nil
126}
127
128var (
129	catwalkSyncer = &catwalkSync{}
130	hyperSyncer   = &hyperSync{}
131)
132
133// Providers returns the list of providers, taking into account cached results
134// and whether or not auto update is enabled.
135//
136// It will:
137// 1. if auto update is disabled, it'll return the embedded providers at the
138// time of release.
139// 2. load the cached providers
140// 3. try to get the fresh list of providers, and return either this new list,
141// the cached list, or the embedded list if all others fail.
142func Providers(cfg *Config) ([]catwalk.Provider, error) {
143	providerOnce.Do(func() {
144		var wg sync.WaitGroup
145		var errs []error
146		providers := csync.NewSlice[catwalk.Provider]()
147		autoupdate := !cfg.Options.DisableProviderAutoUpdate
148
149		ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
150		defer cancel()
151
152		wg.Go(func() {
153			catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
154			client := catwalk.NewWithURL(catwalkURL)
155			path := cachePathFor("providers")
156			catwalkSyncer.Init(client, path, autoupdate)
157
158			items, err := catwalkSyncer.Get(ctx)
159			if err != nil {
160				catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
161				errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr)) //nolint:staticcheck
162				return
163			}
164			providers.Append(items...)
165		})
166
167		wg.Go(func() {
168			if !hyper.Enabled() {
169				return
170			}
171			path := cachePathFor("hyper")
172			hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
173
174			item, err := hyperSyncer.Get(ctx)
175			if err != nil {
176				errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
177				return
178			}
179			providers.Append(item)
180		})
181
182		wg.Wait()
183
184		providerList = slices.Collect(providers.Seq())
185		providerErr = errors.Join(errs...)
186	})
187	return providerList, providerErr
188}
189
// cache is a handle to a single JSON file on disk that stores one value of
// type T.
type cache[T any] struct {
	// path is the location of the backing JSON file.
	path string
}

// newCache returns a cache for values of type T backed by the file at path.
// It does not touch the filesystem.
func newCache[T any](path string) cache[T] {
	var c cache[T]
	c.path = path
	return c
}
197
198func (c cache[T]) Get() (T, string, error) {
199	var v T
200	data, err := os.ReadFile(c.path)
201	if err != nil {
202		return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
203	}
204
205	if err := json.Unmarshal(data, &v); err != nil {
206		return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
207	}
208
209	return v, etag.Of(data), nil
210}
211
212func (c cache[T]) Store(v T) error {
213	slog.Info("Saving provider data to disk", "path", c.path)
214	if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
215		return fmt.Errorf("failed to create directory for provider cache: %w", err)
216	}
217
218	data, err := json.Marshal(v)
219	if err != nil {
220		return fmt.Errorf("failed to marshal provider data: %w", err)
221	}
222
223	if err := os.WriteFile(c.path, data, 0o644); err != nil {
224		return fmt.Errorf("failed to write provider data to cache: %w", err)
225	}
226	return nil
227}