1package config
2
3import (
4 "cmp"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/catwalk/pkg/embedded"
16 "github.com/charmbracelet/crush/internal/home"
17 "github.com/charmbracelet/crush/internal/version"
18)
19
20type ProviderClient interface {
21 GetProviders() ([]catwalk.Provider, error)
22}
23
24// file to cache provider data
25func providerCacheFileData() string {
26 xdgDataHome := os.Getenv("XDG_DATA_HOME")
27 if xdgDataHome != "" {
28 return filepath.Join(xdgDataHome, version.AppName, "providers.json")
29 }
30
31 // return the path to the main data directory
32 // for windows, it should be in `%LOCALAPPDATA%/crush/`
33 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
34 if runtime.GOOS == "windows" {
35 localAppData := os.Getenv("LOCALAPPDATA")
36 if localAppData == "" {
37 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
38 }
39 return filepath.Join(localAppData, version.AppName, "providers.json")
40 }
41
42 return filepath.Join(home.Dir(), ".local", "share", version.AppName, "providers.json")
43}
44
45func saveProvidersInCache(path string, providers []catwalk.Provider) error {
46 slog.Info("Saving cached provider data", "path", path)
47 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
48 return fmt.Errorf("failed to create directory for provider cache: %w", err)
49 }
50
51 data, err := json.MarshalIndent(providers, "", " ")
52 if err != nil {
53 return fmt.Errorf("failed to marshal provider data: %w", err)
54 }
55
56 if err := os.WriteFile(path, data, 0o644); err != nil {
57 return fmt.Errorf("failed to write provider data to cache: %w", err)
58 }
59 return nil
60}
61
62func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
63 data, err := os.ReadFile(path)
64 if err != nil {
65 return nil, fmt.Errorf("failed to read provider cache file: %w", err)
66 }
67
68 var providers []catwalk.Provider
69 if err := json.Unmarshal(data, &providers); err != nil {
70 return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
71 }
72 return providers, nil
73}
74
75func UpdateProviders(pathOrUrl string) error {
76 var providers []catwalk.Provider
77 pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
78
79 switch {
80 case pathOrUrl == "embedded":
81 providers = embedded.GetAll()
82 case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
83 var err error
84 providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
85 if err != nil {
86 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
87 }
88 default:
89 content, err := os.ReadFile(pathOrUrl)
90 if err != nil {
91 return fmt.Errorf("failed to read file: %w", err)
92 }
93 if err := json.Unmarshal(content, &providers); err != nil {
94 return fmt.Errorf("failed to unmarshal provider data: %w", err)
95 }
96 if len(providers) == 0 {
97 return fmt.Errorf("no providers found in the provided source")
98 }
99 }
100
101 cachePath := providerCacheFileData()
102 if err := saveProvidersInCache(cachePath, providers); err != nil {
103 return fmt.Errorf("failed to save providers to cache: %w", err)
104 }
105
106 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
107 return nil
108}
109
110func Providers(cfg *Config) ([]catwalk.Provider, error) {
111 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
112 client := catwalk.NewWithURL(catwalkURL)
113 path := providerCacheFileData()
114 return loadProviders(cfg.Options.DisableProviderAutoUpdate, client, path)
115}
116
117func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
118 cacheIsStale, cacheExists := isCacheStale(path)
119
120 catwalkGetAndSave := func() ([]catwalk.Provider, error) {
121 providers, err := client.GetProviders()
122 if err != nil {
123 return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
124 }
125 if len(providers) == 0 {
126 return nil, fmt.Errorf("empty providers list from catwalk")
127 }
128 if err := saveProvidersInCache(path, providers); err != nil {
129 return nil, err
130 }
131 return providers, nil
132 }
133
134 backgroundCacheUpdate := func() {
135 go func() {
136 slog.Info("Updating providers cache in background", "path", path)
137
138 providers, err := client.GetProviders()
139 if err != nil {
140 slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
141 return
142 }
143 if len(providers) == 0 {
144 slog.Error("Empty providers list from Catwalk")
145 return
146 }
147 if err := saveProvidersInCache(path, providers); err != nil {
148 slog.Error("Failed to update providers.json in background", "error", err)
149 }
150 }()
151 }
152
153 switch {
154 case autoUpdateDisabled:
155 slog.Warn("Providers auto-update is disabled")
156
157 if cacheExists {
158 slog.Warn("Using locally cached providers")
159 return loadProvidersFromCache(path)
160 }
161
162 slog.Warn("Saving embedded providers to cache")
163 providers := embedded.GetAll()
164 if err := saveProvidersInCache(path, providers); err != nil {
165 return nil, err
166 }
167 return providers, nil
168
169 case cacheExists && !cacheIsStale:
170 slog.Info("Recent providers cache is available.", "path", path)
171
172 providers, err := loadProvidersFromCache(path)
173 if err != nil {
174 return nil, err
175 }
176 if len(providers) == 0 {
177 return catwalkGetAndSave()
178 }
179 backgroundCacheUpdate()
180 return providers, nil
181
182 default:
183 slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
184
185 providers, err := catwalkGetAndSave()
186 if err != nil {
187 catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
188 return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
189 }
190 return providers, nil
191 }
192}
193
// isCacheStale reports whether the file at path is older than 24 hours and
// whether it exists at all.
//
// When the file cannot be stat'ed, for any reason, it is reported as stale
// and non-existent. Age is measured from the file's modification time.
func isCacheStale(path string) (stale, exists bool) {
	const maxAge = 24 * time.Hour

	fi, statErr := os.Stat(path)
	if statErr != nil {
		return true, false
	}

	age := time.Since(fi.ModTime())
	return age > maxAge, true
}