1package config
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "os"
11 "path/filepath"
12 "runtime"
13 "slices"
14 "strings"
15 "sync"
16 "time"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/catwalk/pkg/embedded"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/x/etag"
24)
25
// syncer is the minimal contract for a source that can produce a value of
// type T on demand. The caller supplies the context for the lookup.
type syncer[T any] interface {
	Get(ctx context.Context) (T, error)
}
29
30var (
31 providerOnce sync.Once
32 providerList []catwalk.Provider
33 providerErr error
34)
35
36// file to cache provider data
37func cachePathFor(name string) string {
38 xdgDataHome := os.Getenv("XDG_DATA_HOME")
39 if xdgDataHome != "" {
40 return filepath.Join(xdgDataHome, appName, name+".json")
41 }
42
43 // return the path to the main data directory
44 // for windows, it should be in `%LOCALAPPDATA%/crush/`
45 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
46 if runtime.GOOS == "windows" {
47 localAppData := os.Getenv("LOCALAPPDATA")
48 if localAppData == "" {
49 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
50 }
51 return filepath.Join(localAppData, appName, name+".json")
52 }
53
54 return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
55}
56
57// UpdateProviders updates the Catwalk providers list from a specified source.
58func UpdateProviders(pathOrURL string) error {
59 var providers []catwalk.Provider
60 pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
61
62 switch {
63 case pathOrURL == "embedded":
64 providers = embedded.GetAll()
65 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
66 var err error
67 providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
68 if err != nil {
69 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
70 }
71 default:
72 content, err := os.ReadFile(pathOrURL)
73 if err != nil {
74 return fmt.Errorf("failed to read file: %w", err)
75 }
76 if err := json.Unmarshal(content, &providers); err != nil {
77 return fmt.Errorf("failed to unmarshal provider data: %w", err)
78 }
79 if len(providers) == 0 {
80 return fmt.Errorf("no providers found in the provided source")
81 }
82 }
83
84 if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
85 return fmt.Errorf("failed to save providers to cache: %w", err)
86 }
87
88 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
89 return nil
90}
91
92// UpdateHyper updates the Hyper provider information from a specified URL.
93func UpdateHyper(pathOrURL string) error {
94 var provider catwalk.Provider
95 pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
96
97 switch {
98 case pathOrURL == "embedded":
99 provider = hyper.Embedded()
100 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
101 client := realHyperClient{baseURL: pathOrURL}
102 var err error
103 provider, err = client.Get(context.Background(), "")
104 if err != nil {
105 return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
106 }
107 default:
108 content, err := os.ReadFile(pathOrURL)
109 if err != nil {
110 return fmt.Errorf("failed to read file: %w", err)
111 }
112 if err := json.Unmarshal(content, &provider); err != nil {
113 return fmt.Errorf("failed to unmarshal provider data: %w", err)
114 }
115 }
116
117 if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
118 return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
119 }
120
121 slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
122 return nil
123}
124
125var (
126 catwalkSyncer = &catwalkSync{}
127 hyperSyncer = &hyperSync{}
128)
129
130// Providers returns the list of providers, taking into account cached results
131// and whether or not auto update is enabled.
132//
133// It will:
134// 1. if auto update is disabled, it'll return the embedded providers at the
135// time of release.
136// 2. load the cached providers
137// 3. try to get the fresh list of providers, and return either this new list,
138// the cached list, or the embedded list if all others fail.
139func Providers(cfg *Config) ([]catwalk.Provider, error) {
140 providerOnce.Do(func() {
141 var wg sync.WaitGroup
142 var errs []error
143 providers := csync.NewSlice[catwalk.Provider]()
144 autoupdate := !cfg.Options.DisableProviderAutoUpdate
145 customProvidersOnly := cfg.Options.DisableDefaultProviders
146
147 ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
148 defer cancel()
149
150 wg.Go(func() {
151 if customProvidersOnly {
152 return
153 }
154 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
155 client := catwalk.NewWithURL(catwalkURL)
156 path := cachePathFor("providers")
157 catwalkSyncer.Init(client, path, autoupdate)
158
159 items, err := catwalkSyncer.Get(ctx)
160 if err != nil {
161 catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
162 errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, err)) //nolint:staticcheck
163 return
164 }
165 providers.Append(items...)
166 })
167
168 wg.Go(func() {
169 if customProvidersOnly {
170 return
171 }
172 path := cachePathFor("hyper")
173 hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
174
175 item, err := hyperSyncer.Get(ctx)
176 if err != nil {
177 errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
178 return
179 }
180 providers.Append(item)
181 })
182
183 wg.Wait()
184
185 providerList = slices.Collect(providers.Seq())
186 providerErr = errors.Join(errs...)
187 })
188 return providerList, providerErr
189}
190
// cache is a typed handle to a JSON file on disk that holds one value of
// type T. Only the file path is kept in memory. The file is read on Get and
// written on Store.
type cache[T any] struct {
	path string
}

// newCache returns a cache backed by the file at path. It does not touch the
// file system, so the file does not need to exist yet.
func newCache[T any](path string) cache[T] {
	return cache[T]{path}
}
198
199func (c cache[T]) Get() (T, string, error) {
200 var v T
201 data, err := os.ReadFile(c.path)
202 if err != nil {
203 return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
204 }
205
206 if err := json.Unmarshal(data, &v); err != nil {
207 return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
208 }
209
210 return v, etag.Of(data), nil
211}
212
213func (c cache[T]) Store(v T) error {
214 slog.Info("Saving provider data to disk", "path", c.path)
215 if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
216 return fmt.Errorf("failed to create directory for provider cache: %w", err)
217 }
218
219 data, err := json.Marshal(v)
220 if err != nil {
221 return fmt.Errorf("failed to marshal provider data: %w", err)
222 }
223
224 if err := os.WriteFile(c.path, data, 0o644); err != nil {
225 return fmt.Errorf("failed to write provider data to cache: %w", err)
226 }
227 return nil
228}