1package config
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "os"
11 "path/filepath"
12 "runtime"
13 "slices"
14 "strings"
15 "sync"
16 "time"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/catwalk/pkg/embedded"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/x/etag"
24)
25
// syncer is implemented by any source that can produce a value of type T on
// request.
type syncer[T any] interface {
	// Get returns the value, or an error when it cannot be obtained. The
	// context is supplied by the caller.
	Get(context.Context) (T, error)
}
29
// Memoized result of Providers: providerOnce guards the one-time load, and
// providerList / providerErr hold the outcome returned by every call.
var (
	providerOnce sync.Once
	providerList []catwalk.Provider
	providerErr  error
)
35
36// file to cache provider data
37func cachePathFor(name string) string {
38 xdgDataHome := os.Getenv("XDG_DATA_HOME")
39 if xdgDataHome != "" {
40 return filepath.Join(xdgDataHome, appName, name+".json")
41 }
42
43 // return the path to the main data directory
44 // for windows, it should be in `%LOCALAPPDATA%/crush/`
45 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
46 if runtime.GOOS == "windows" {
47 localAppData := os.Getenv("LOCALAPPDATA")
48 if localAppData == "" {
49 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
50 }
51 return filepath.Join(localAppData, appName, name+".json")
52 }
53
54 return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
55}
56
57// UpdateProviders updates the Catwalk providers list from a specified source.
58func UpdateProviders(pathOrURL string) error {
59 var providers []catwalk.Provider
60 pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
61
62 switch {
63 case pathOrURL == "embedded":
64 providers = embedded.GetAll()
65 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
66 var err error
67 providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
68 if err != nil {
69 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
70 }
71 default:
72 content, err := os.ReadFile(pathOrURL)
73 if err != nil {
74 return fmt.Errorf("failed to read file: %w", err)
75 }
76 if err := json.Unmarshal(content, &providers); err != nil {
77 return fmt.Errorf("failed to unmarshal provider data: %w", err)
78 }
79 if len(providers) == 0 {
80 return fmt.Errorf("no providers found in the provided source")
81 }
82 }
83
84 if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
85 return fmt.Errorf("failed to save providers to cache: %w", err)
86 }
87
88 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
89 return nil
90}
91
92// UpdateHyper updates the Hyper provider information from a specified URL.
93func UpdateHyper(pathOrURL string) error {
94 var provider catwalk.Provider
95 pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
96
97 switch {
98 case pathOrURL == "embedded":
99 provider = hyper.Embedded()
100 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
101 client := realHyperClient{baseURL: pathOrURL}
102 var err error
103 provider, err = client.Get(context.Background(), "")
104 if err != nil {
105 return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
106 }
107 default:
108 content, err := os.ReadFile(pathOrURL)
109 if err != nil {
110 return fmt.Errorf("failed to read file: %w", err)
111 }
112 if err := json.Unmarshal(content, &provider); err != nil {
113 return fmt.Errorf("failed to unmarshal provider data: %w", err)
114 }
115 }
116
117 if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
118 return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
119 }
120
121 slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
122 return nil
123}
124
// Shared syncer instances used by Providers, which calls Init on each one
// before calling Get.
var (
	catwalkSyncer = &catwalkSync{}
	hyperSyncer   = &hyperSync{}
)
129
130// Providers returns the list of providers, taking into account cached results
131// and whether or not auto update is enabled.
132//
133// It will:
134// 1. if auto update is disabled, it'll return the embedded providers at the
135// time of release.
136// 2. load the cached providers
137// 3. try to get the fresh list of providers, and return either this new list,
138// the cached list, or the embedded list if all others fail.
139func Providers(cfg *Config) ([]catwalk.Provider, error) {
140 providerOnce.Do(func() {
141 var wg sync.WaitGroup
142 var errs []error
143 providers := csync.NewSlice[catwalk.Provider]()
144 autoupdate := !cfg.Options.DisableProviderAutoUpdate
145 customProvidersOnly := cfg.Options.DisableDefaultProviders
146
147 ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
148 defer cancel()
149
150 var hyperProvider catwalk.Provider
151 var hyperFound bool
152
153 wg.Go(func() {
154 if customProvidersOnly {
155 return
156 }
157 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
158 client := catwalk.NewWithURL(catwalkURL)
159 path := cachePathFor("providers")
160 catwalkSyncer.Init(client, path, autoupdate)
161
162 items, err := catwalkSyncer.Get(ctx)
163 if err != nil {
164 catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
165 errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, err)) //nolint:staticcheck
166 return
167 }
168 providers.Append(items...)
169 })
170
171 wg.Go(func() {
172 if customProvidersOnly {
173 return
174 }
175 path := cachePathFor("hyper")
176 hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
177
178 item, err := hyperSyncer.Get(ctx)
179 if err != nil {
180 errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
181 return
182 }
183 hyperProvider = item
184 hyperFound = true
185 })
186
187 wg.Wait()
188
189 if hyperFound {
190 providerList = append([]catwalk.Provider{hyperProvider}, slices.Collect(providers.Seq())...)
191 } else {
192 providerList = slices.Collect(providers.Seq())
193 }
194 providerErr = errors.Join(errs...)
195 })
196 return providerList, providerErr
197}
198
// cache is a handle on a file at path that stores a single JSON-encoded value
// of type T.
type cache[T any] struct {
	path string
}

// newCache returns a cache for values of type T backed by the file at path.
// It only records the path; the file system is not touched.
func newCache[T any](path string) cache[T] {
	var c cache[T]
	c.path = path
	return c
}
206
207func (c cache[T]) Get() (T, string, error) {
208 var v T
209 data, err := os.ReadFile(c.path)
210 if err != nil {
211 return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
212 }
213
214 if err := json.Unmarshal(data, &v); err != nil {
215 return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
216 }
217
218 return v, etag.Of(data), nil
219}
220
221func (c cache[T]) Store(v T) error {
222 slog.Info("Saving provider data to disk", "path", c.path)
223 if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
224 return fmt.Errorf("failed to create directory for provider cache: %w", err)
225 }
226
227 data, err := json.Marshal(v)
228 if err != nil {
229 return fmt.Errorf("failed to marshal provider data: %w", err)
230 }
231
232 if err := os.WriteFile(c.path, data, 0o644); err != nil {
233 return fmt.Errorf("failed to write provider data to cache: %w", err)
234 }
235 return nil
236}