1package config
2
3import (
4 "cmp"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/catwalk/pkg/embedded"
17 "github.com/charmbracelet/crush/internal/home"
18)
19
20type ProviderClient interface {
21 GetProviders() ([]catwalk.Provider, error)
22}
23
24var (
25 providerOnce sync.Once
26 providerList []catwalk.Provider
27 providerErr error
28)
29
30// file to cache provider data
31func providerCacheFileData() string {
32 xdgDataHome := os.Getenv("XDG_DATA_HOME")
33 if xdgDataHome != "" {
34 return filepath.Join(xdgDataHome, appName, "providers.json")
35 }
36
37 // return the path to the main data directory
38 // for windows, it should be in `%LOCALAPPDATA%/crush/`
39 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
40 if runtime.GOOS == "windows" {
41 localAppData := os.Getenv("LOCALAPPDATA")
42 if localAppData == "" {
43 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
44 }
45 return filepath.Join(localAppData, appName, "providers.json")
46 }
47
48 return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
49}
50
51func saveProvidersInCache(path string, providers []catwalk.Provider) error {
52 slog.Info("Saving cached provider data", "path", path)
53 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
54 return fmt.Errorf("failed to create directory for provider cache: %w", err)
55 }
56
57 data, err := json.MarshalIndent(providers, "", " ")
58 if err != nil {
59 return fmt.Errorf("failed to marshal provider data: %w", err)
60 }
61
62 if err := os.WriteFile(path, data, 0o644); err != nil {
63 return fmt.Errorf("failed to write provider data to cache: %w", err)
64 }
65 return nil
66}
67
68func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
69 data, err := os.ReadFile(path)
70 if err != nil {
71 return nil, fmt.Errorf("failed to read provider cache file: %w", err)
72 }
73
74 var providers []catwalk.Provider
75 if err := json.Unmarshal(data, &providers); err != nil {
76 return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
77 }
78 return providers, nil
79}
80
81func UpdateProviders(pathOrUrl string) error {
82 var providers []catwalk.Provider
83 pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
84
85 switch {
86 case pathOrUrl == "embedded":
87 providers = embedded.GetAll()
88 case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
89 var err error
90 providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
91 if err != nil {
92 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
93 }
94 default:
95 content, err := os.ReadFile(pathOrUrl)
96 if err != nil {
97 return fmt.Errorf("failed to read file: %w", err)
98 }
99 if err := json.Unmarshal(content, &providers); err != nil {
100 return fmt.Errorf("failed to unmarshal provider data: %w", err)
101 }
102 if len(providers) == 0 {
103 return fmt.Errorf("no providers found in the provided source")
104 }
105 }
106
107 cachePath := providerCacheFileData()
108 if err := saveProvidersInCache(cachePath, providers); err != nil {
109 return fmt.Errorf("failed to save providers to cache: %w", err)
110 }
111
112 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
113 return nil
114}
115
116func Providers(cfg *Config) ([]catwalk.Provider, error) {
117 providerOnce.Do(func() {
118 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
119 client := catwalk.NewWithURL(catwalkURL)
120 path := providerCacheFileData()
121
122 autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
123 providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
124 })
125 return providerList, providerErr
126}
127
128func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
129 cacheIsStale, cacheExists := isCacheStale(path)
130
131 catwalkGetAndSave := func() ([]catwalk.Provider, error) {
132 providers, err := client.GetProviders()
133 if err != nil {
134 return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
135 }
136 if len(providers) == 0 {
137 return nil, fmt.Errorf("empty providers list from catwalk")
138 }
139 if err := saveProvidersInCache(path, providers); err != nil {
140 return nil, err
141 }
142 return providers, nil
143 }
144
145 backgroundCacheUpdate := func() {
146 go func() {
147 slog.Info("Updating providers cache in background", "path", path)
148
149 providers, err := client.GetProviders()
150 if err != nil {
151 slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
152 return
153 }
154 if len(providers) == 0 {
155 slog.Error("Empty providers list from Catwalk")
156 return
157 }
158 if err := saveProvidersInCache(path, providers); err != nil {
159 slog.Error("Failed to update providers.json in background", "error", err)
160 }
161 }()
162 }
163
164 switch {
165 case autoUpdateDisabled:
166 slog.Warn("Providers auto-update is disabled")
167
168 if cacheExists {
169 slog.Warn("Using locally cached providers")
170 return loadProvidersFromCache(path)
171 }
172
173 slog.Warn("Saving embedded providers to cache")
174 providers := embedded.GetAll()
175 if err := saveProvidersInCache(path, providers); err != nil {
176 return nil, err
177 }
178 return providers, nil
179
180 case cacheExists && !cacheIsStale:
181 slog.Info("Recent providers cache is available.", "path", path)
182
183 providers, err := loadProvidersFromCache(path)
184 if err != nil {
185 return nil, err
186 }
187 if len(providers) == 0 {
188 return catwalkGetAndSave()
189 }
190 backgroundCacheUpdate()
191 return providers, nil
192
193 default:
194 slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
195
196 providers, err := catwalkGetAndSave()
197 if err != nil {
198 catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
199 return nil, fmt.Errorf("crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use embedded version from the time of this Crush release. %w", catwalkUrl, err)
200 }
201 return providers, nil
202 }
203}
204
// isCacheStale reports whether the file at path exists and whether it was
// last modified more than 24 hours ago. A path that cannot be stat'ed, for
// any reason, is reported as stale and nonexistent.
func isCacheStale(path string) (stale, exists bool) {
	const maxCacheAge = 24 * time.Hour

	info, statErr := os.Stat(path)
	if statErr == nil {
		return time.Since(info.ModTime()) > maxCacheAge, true
	}
	return true, false
}