1package config
2
3import (
4 "cmp"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/catwalk/pkg/embedded"
17 "github.com/charmbracelet/crush/internal/home"
18)
19
// ProviderClient is the minimal interface loadProviders needs from a provider
// source: a single call that returns the full provider list or an error.
// Providers passes the client built by catwalk.NewWithURL; accepting an
// interface lets other implementations be supplied to loadProviders.
type ProviderClient interface {
	// GetProviders returns every provider known to the source, or an error
	// if the list could not be obtained.
	GetProviders() ([]catwalk.Provider, error)
}
23
// Package-level memoization of the result computed by Providers.
var (
	// providerOnce guards the single execution of the load in Providers.
	providerOnce sync.Once
	// providerList holds the providers produced by that single load.
	providerList []catwalk.Provider
	// providerErr holds the error produced by that single load, if any.
	providerErr error
)
29
30// file to cache provider data
31func providerCacheFileData() string {
32 xdgDataHome := os.Getenv("XDG_DATA_HOME")
33 if xdgDataHome != "" {
34 return filepath.Join(xdgDataHome, appName, "providers.json")
35 }
36
37 // return the path to the main data directory
38 // for windows, it should be in `%LOCALAPPDATA%/crush/`
39 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
40 if runtime.GOOS == "windows" {
41 localAppData := os.Getenv("LOCALAPPDATA")
42 if localAppData == "" {
43 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
44 }
45 return filepath.Join(localAppData, appName, "providers.json")
46 }
47
48 return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
49}
50
51func saveProvidersInCache(path string, providers []catwalk.Provider) error {
52 slog.Info("Saving cached provider data", "path", path)
53 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
54 return fmt.Errorf("failed to create directory for provider cache: %w", err)
55 }
56
57 data, err := json.MarshalIndent(providers, "", " ")
58 if err != nil {
59 return fmt.Errorf("failed to marshal provider data: %w", err)
60 }
61
62 if err := os.WriteFile(path, data, 0o644); err != nil {
63 return fmt.Errorf("failed to write provider data to cache: %w", err)
64 }
65 return nil
66}
67
68func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
69 data, err := os.ReadFile(path)
70 if err != nil {
71 return nil, fmt.Errorf("failed to read provider cache file: %w", err)
72 }
73
74 var providers []catwalk.Provider
75 if err := json.Unmarshal(data, &providers); err != nil {
76 return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
77 }
78 return providers, nil
79}
80
81func UpdateProviders(pathOrUrl string) error {
82 var providers []catwalk.Provider
83 pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
84
85 switch {
86 case pathOrUrl == "embedded":
87 providers = embedded.GetAll()
88 case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
89 var err error
90 providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
91 if err != nil {
92 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
93 }
94 default:
95 content, err := os.ReadFile(pathOrUrl)
96 if err != nil {
97 return fmt.Errorf("failed to read file: %w", err)
98 }
99 if err := json.Unmarshal(content, &providers); err != nil {
100 return fmt.Errorf("failed to unmarshal provider data: %w", err)
101 }
102 if len(providers) == 0 {
103 return fmt.Errorf("no providers found in the provided source")
104 }
105 }
106
107 cachePath := providerCacheFileData()
108 if err := saveProvidersInCache(cachePath, providers); err != nil {
109 return fmt.Errorf("failed to save providers to cache: %w", err)
110 }
111
112 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
113 return nil
114}
115
116func Providers(cfg *Config) ([]catwalk.Provider, error) {
117 providerOnce.Do(func() {
118 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
119 client := catwalk.NewWithURL(catwalkURL)
120 path := providerCacheFileData()
121
122 autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
123 providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
124 })
125 return providerList, providerErr
126}
127
128func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
129 _, cacheExists := isCacheStale(path)
130
131 catwalkGetAndSave := func() ([]catwalk.Provider, error) {
132 providers, err := client.GetProviders()
133 if err != nil {
134 return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
135 }
136 if len(providers) == 0 {
137 return nil, fmt.Errorf("empty providers list from catwalk")
138 }
139 if err := saveProvidersInCache(path, providers); err != nil {
140 return nil, err
141 }
142 return providers, nil
143 }
144 switch {
145 case autoUpdateDisabled:
146 slog.Warn("Providers auto-update is disabled")
147
148 if cacheExists {
149 slog.Warn("Using locally cached providers")
150 return loadProvidersFromCache(path)
151 }
152
153 slog.Warn("Saving embedded providers to cache")
154 providers := embedded.GetAll()
155 if err := saveProvidersInCache(path, providers); err != nil {
156 return nil, err
157 }
158 return providers, nil
159
160 default:
161 slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
162
163 providers, err := catwalkGetAndSave()
164 if err != nil {
165 catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
166 return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
167 }
168 return providers, nil
169 }
170}
171
// isCacheStale inspects the file at path and reports two things: whether it
// should be considered stale, and whether it exists.
//
// A file that cannot be stat'ed, for any reason, is reported as stale and
// non-existent. An existing file is stale when its modification time is more
// than 24 hours in the past.
func isCacheStale(path string) (stale, exists bool) {
	const maxAge = 24 * time.Hour

	fi, statErr := os.Stat(path)
	if statErr != nil {
		return true, false
	}

	age := time.Since(fi.ModTime())
	return age > maxAge, true
}