1package config
2
3import (
4 "cmp"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "sync"
12 "time"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/catwalk/pkg/embedded"
16 "github.com/charmbracelet/crush/internal/home"
17)
18
19type ProviderClient interface {
20 GetProviders() ([]catwalk.Provider, error)
21}
22
23var (
24 providerOnce sync.Once
25 providerList []catwalk.Provider
26 providerErr error
27)
28
29// file to cache provider data
30func providerCacheFileData() string {
31 xdgDataHome := os.Getenv("XDG_DATA_HOME")
32 if xdgDataHome != "" {
33 return filepath.Join(xdgDataHome, appName, "providers.json")
34 }
35
36 // return the path to the main data directory
37 // for windows, it should be in `%LOCALAPPDATA%/crush/`
38 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
39 if runtime.GOOS == "windows" {
40 localAppData := os.Getenv("LOCALAPPDATA")
41 if localAppData == "" {
42 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
43 }
44 return filepath.Join(localAppData, appName, "providers.json")
45 }
46
47 return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
48}
49
50func saveProvidersInCache(path string, providers []catwalk.Provider) error {
51 slog.Info("Saving cached provider data", "path", path)
52 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
53 return fmt.Errorf("failed to create directory for provider cache: %w", err)
54 }
55
56 data, err := json.MarshalIndent(providers, "", " ")
57 if err != nil {
58 return fmt.Errorf("failed to marshal provider data: %w", err)
59 }
60
61 if err := os.WriteFile(path, data, 0o644); err != nil {
62 return fmt.Errorf("failed to write provider data to cache: %w", err)
63 }
64 return nil
65}
66
67func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
68 data, err := os.ReadFile(path)
69 if err != nil {
70 return nil, fmt.Errorf("failed to read provider cache file: %w", err)
71 }
72
73 var providers []catwalk.Provider
74 if err := json.Unmarshal(data, &providers); err != nil {
75 return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
76 }
77 return providers, nil
78}
79
80func Providers(cfg *Config) ([]catwalk.Provider, error) {
81 providerOnce.Do(func() {
82 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
83 client := catwalk.NewWithURL(catwalkURL)
84 path := providerCacheFileData()
85
86 autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
87 providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
88 })
89 return providerList, providerErr
90}
91
92func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
93 cacheIsStale, cacheExists := isCacheStale(path)
94
95 catwalkGetAndSave := func() ([]catwalk.Provider, error) {
96 providers, err := client.GetProviders()
97 if err != nil {
98 return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
99 }
100 if len(providers) == 0 {
101 return nil, fmt.Errorf("empty providers list from catwalk")
102 }
103 if err := saveProvidersInCache(path, providers); err != nil {
104 return nil, err
105 }
106 return providers, nil
107 }
108
109 backgroundCacheUpdate := func() {
110 go func() {
111 slog.Info("Updating providers cache in background", "path", path)
112
113 providers, err := client.GetProviders()
114 if err != nil {
115 slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
116 return
117 }
118 if len(providers) == 0 {
119 slog.Error("Empty providers list from Catwalk")
120 return
121 }
122 if err := saveProvidersInCache(path, providers); err != nil {
123 slog.Error("Failed to update providers.json in background", "error", err)
124 }
125 }()
126 }
127
128 switch {
129 case autoUpdateDisabled:
130 slog.Warn("Providers auto-update is disabled")
131
132 if cacheExists {
133 slog.Warn("Using locally cached providers")
134 return loadProvidersFromCache(path)
135 }
136
137 slog.Warn("Saving embedded providers to cache")
138 providers := embedded.GetAll()
139 if err := saveProvidersInCache(path, providers); err != nil {
140 return nil, err
141 }
142 return providers, nil
143
144 case cacheExists && !cacheIsStale:
145 slog.Info("Recent providers cache is available.", "path", path)
146
147 providers, err := loadProvidersFromCache(path)
148 if err != nil {
149 return nil, err
150 }
151 if len(providers) == 0 {
152 return catwalkGetAndSave()
153 }
154 backgroundCacheUpdate()
155 return providers, nil
156
157 default:
158 slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
159
160 providers, err := catwalkGetAndSave()
161 if err != nil {
162 catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
163 return nil, fmt.Errorf("crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use embedded version from the time of this Crush release. %w", catwalkUrl, err)
164 }
165 return providers, nil
166 }
167}
168
// isCacheStale reports whether the file at path was last modified more than
// 24 hours ago (stale) and whether it could be stat'ed at all (exists).
// Any os.Stat failure, not only "file does not exist", is reported as
// stale=true, exists=false.
func isCacheStale(path string) (stale, exists bool) {
	const maxCacheAge = 24 * time.Hour

	info, statErr := os.Stat(path)
	if statErr != nil {
		return true, false
	}

	age := time.Since(info.ModTime())
	return age > maxCacheAge, true
}