1package config
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "os"
11 "path/filepath"
12 "runtime"
13 "strings"
14 "sync"
15
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/catwalk/pkg/embedded"
18 "github.com/charmbracelet/crush/internal/home"
19)
20
21type ProviderClient interface {
22 GetProviders(context.Context, string) ([]catwalk.Provider, error)
23}
24
25var (
26 providerOnce sync.Once
27 providerList []catwalk.Provider
28 providerErr error
29)
30
31// file to cache provider data
32func providerCacheFileData() string {
33 xdgDataHome := os.Getenv("XDG_DATA_HOME")
34 if xdgDataHome != "" {
35 return filepath.Join(xdgDataHome, appName, "providers.json")
36 }
37
38 // return the path to the main data directory
39 // for windows, it should be in `%LOCALAPPDATA%/crush/`
40 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
41 if runtime.GOOS == "windows" {
42 localAppData := os.Getenv("LOCALAPPDATA")
43 if localAppData == "" {
44 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
45 }
46 return filepath.Join(localAppData, appName, "providers.json")
47 }
48
49 return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
50}
51
52func saveProvidersInCache(path string, providers []catwalk.Provider) error {
53 slog.Info("Saving provider data to disk", "path", path)
54 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
55 return fmt.Errorf("failed to create directory for provider cache: %w", err)
56 }
57
58 data, err := json.Marshal(providers)
59 if err != nil {
60 return fmt.Errorf("failed to marshal provider data: %w", err)
61 }
62
63 if err := os.WriteFile(path, data, 0o644); err != nil {
64 return fmt.Errorf("failed to write provider data to cache: %w", err)
65 }
66 return nil
67}
68
69func loadProvidersFromCache(path string) ([]catwalk.Provider, string, error) {
70 data, err := os.ReadFile(path)
71 if err != nil {
72 return nil, "", fmt.Errorf("failed to read provider cache file: %w", err)
73 }
74
75 var providers []catwalk.Provider
76 if err := json.Unmarshal(data, &providers); err != nil {
77 return nil, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
78 }
79
80 return providers, catwalk.Etag(data), nil
81}
82
83func UpdateProviders(pathOrURL string) error {
84 var providers []catwalk.Provider
85 pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
86
87 switch {
88 case pathOrURL == "embedded":
89 providers = embedded.GetAll()
90 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
91 var err error
92 providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
93 if err != nil {
94 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
95 }
96 default:
97 content, err := os.ReadFile(pathOrURL)
98 if err != nil {
99 return fmt.Errorf("failed to read file: %w", err)
100 }
101 if err := json.Unmarshal(content, &providers); err != nil {
102 return fmt.Errorf("failed to unmarshal provider data: %w", err)
103 }
104 if len(providers) == 0 {
105 return fmt.Errorf("no providers found in the provided source")
106 }
107 }
108
109 cachePath := providerCacheFileData()
110 if err := saveProvidersInCache(cachePath, providers); err != nil {
111 return fmt.Errorf("failed to save providers to cache: %w", err)
112 }
113
114 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePath)
115 return nil
116}
117
118// Providers returns the list of providers, taking into account cached results
119// and whether or not auto update is enabled.
120//
121// It will:
122// 1. if auto update is disabled, it'll return the embedded providers at the
123// time of release.
124// 2. load the cached providers
125// 3. try to get the fresh list of providers, and return either this new list,
126// the cached list, or the embedded list if all others fail.
127func Providers(cfg *Config) ([]catwalk.Provider, error) {
128 providerOnce.Do(func() {
129 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
130 client := catwalk.NewWithURL(catwalkURL)
131 path := providerCacheFileData()
132
133 if cfg.Options.DisableProviderAutoUpdate {
134 slog.Info("Using embedded Catwalk providers")
135 providerList, providerErr = embedded.GetAll(), nil
136 return
137 }
138
139 cached, etag, cachedErr := loadProvidersFromCache(path)
140 if len(cached) == 0 || cachedErr != nil {
141 // if cached file is empty, default to embedded providers
142 cached = embedded.GetAll()
143 }
144
145 providerList, providerErr = loadProviders(client, etag, path)
146 if errors.Is(providerErr, catwalk.ErrNotModified) {
147 slog.Info("Catwalk providers not modified")
148 providerList, providerErr = cached, nil
149 }
150 })
151 if providerErr != nil {
152 catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
153 return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr) //nolint:staticcheck
154 }
155 return providerList, nil
156}
157
158func loadProviders(client ProviderClient, etag, path string) ([]catwalk.Provider, error) {
159 slog.Info("Fetching providers from Catwalk.", "path", path)
160 providers, err := client.GetProviders(context.Background(), etag)
161 if err != nil {
162 return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
163 }
164 if len(providers) == 0 {
165 return nil, errors.New("empty providers list from catwalk")
166 }
167 if err := saveProvidersInCache(path, providers); err != nil {
168 return nil, err
169 }
170 return providers, nil
171}