1package config
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "os"
11 "path/filepath"
12 "runtime"
13 "slices"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/charmbracelet/catwalk/pkg/catwalk"
19 "github.com/charmbracelet/catwalk/pkg/embedded"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/x/etag"
24)
25
// syncer abstracts a source that can produce a value of type T on demand,
// honoring the supplied context.
type syncer[T any] interface {
	// Get returns the current value or an error.
	Get(ctx context.Context) (T, error)
}
29
30var (
31 providerOnce sync.Once
32 providerList []catwalk.Provider
33 providerErr error
34)
35
36// file to cache provider data
37func cachePathFor(name string) string {
38 xdgDataHome := os.Getenv("XDG_DATA_HOME")
39 if xdgDataHome != "" {
40 return filepath.Join(xdgDataHome, appName, name+".json")
41 }
42
43 // return the path to the main data directory
44 // for windows, it should be in `%LOCALAPPDATA%/crush/`
45 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
46 if runtime.GOOS == "windows" {
47 localAppData := os.Getenv("LOCALAPPDATA")
48 if localAppData == "" {
49 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
50 }
51 return filepath.Join(localAppData, appName, name+".json")
52 }
53
54 return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
55}
56
57// UpdateProviders updates the Catwalk providers list from a specified source.
58func UpdateProviders(pathOrURL string) error {
59 var providers []catwalk.Provider
60 pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
61
62 switch {
63 case pathOrURL == "embedded":
64 providers = embedded.GetAll()
65 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
66 var err error
67 providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
68 if err != nil {
69 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
70 }
71 default:
72 content, err := os.ReadFile(pathOrURL)
73 if err != nil {
74 return fmt.Errorf("failed to read file: %w", err)
75 }
76 if err := json.Unmarshal(content, &providers); err != nil {
77 return fmt.Errorf("failed to unmarshal provider data: %w", err)
78 }
79 if len(providers) == 0 {
80 return fmt.Errorf("no providers found in the provided source")
81 }
82 }
83
84 if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
85 return fmt.Errorf("failed to save providers to cache: %w", err)
86 }
87
88 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
89 return nil
90}
91
92// UpdateHyper updates the Hyper provider information from a specified URL.
93func UpdateHyper(pathOrURL string) error {
94 if !hyper.Enabled() {
95 return fmt.Errorf("hyper not enabled")
96 }
97 var provider catwalk.Provider
98 pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
99
100 switch {
101 case pathOrURL == "embedded":
102 provider = hyper.Embedded()
103 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
104 client := realHyperClient{baseURL: pathOrURL}
105 var err error
106 provider, err = client.Get(context.Background(), "")
107 if err != nil {
108 return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
109 }
110 default:
111 content, err := os.ReadFile(pathOrURL)
112 if err != nil {
113 return fmt.Errorf("failed to read file: %w", err)
114 }
115 if err := json.Unmarshal(content, &provider); err != nil {
116 return fmt.Errorf("failed to unmarshal provider data: %w", err)
117 }
118 }
119
120 if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
121 return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
122 }
123
124 slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
125 return nil
126}
127
128var (
129 catwalkSyncer = &catwalkSync{}
130 hyperSyncer = &hyperSync{}
131)
132
133// Providers returns the list of providers, taking into account cached results
134// and whether or not auto update is enabled.
135//
136// It will:
137// 1. if auto update is disabled, it'll return the embedded providers at the
138// time of release.
139// 2. load the cached providers
140// 3. try to get the fresh list of providers, and return either this new list,
141// the cached list, or the embedded list if all others fail.
142func Providers(cfg *Config) ([]catwalk.Provider, error) {
143 providerOnce.Do(func() {
144 var wg sync.WaitGroup
145 var errs []error
146 providers := csync.NewSlice[catwalk.Provider]()
147 autoupdate := !cfg.Options.DisableProviderAutoUpdate
148
149 ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
150 defer cancel()
151
152 wg.Go(func() {
153 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
154 client := catwalk.NewWithURL(catwalkURL)
155 path := cachePathFor("providers")
156 catwalkSyncer.Init(client, path, autoupdate)
157
158 items, err := catwalkSyncer.Get(ctx)
159 if err != nil {
160 catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
161 errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, providerErr)) //nolint:staticcheck
162 return
163 }
164 providers.Append(items...)
165 })
166
167 wg.Go(func() {
168 if !hyper.Enabled() {
169 return
170 }
171 path := cachePathFor("hyper")
172 hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
173
174 item, err := hyperSyncer.Get(ctx)
175 if err != nil {
176 errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
177 return
178 }
179 providers.Append(item)
180 })
181
182 wg.Wait()
183
184 providerList = slices.Collect(providers.Seq())
185 providerErr = errors.Join(errs...)
186 })
187 return providerList, providerErr
188}
189
// cache persists a single value of type V as JSON in the file at path.
type cache[V any] struct {
	path string
}

// newCache returns a cache backed by the file at path. It performs no I/O.
func newCache[V any](path string) cache[V] {
	var c cache[V]
	c.path = path
	return c
}
197
198func (c cache[T]) Get() (T, string, error) {
199 var v T
200 data, err := os.ReadFile(c.path)
201 if err != nil {
202 return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
203 }
204
205 if err := json.Unmarshal(data, &v); err != nil {
206 return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
207 }
208
209 return v, etag.Of(data), nil
210}
211
212func (c cache[T]) Store(v T) error {
213 slog.Info("Saving provider data to disk", "path", c.path)
214 if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
215 return fmt.Errorf("failed to create directory for provider cache: %w", err)
216 }
217
218 data, err := json.Marshal(v)
219 if err != nil {
220 return fmt.Errorf("failed to marshal provider data: %w", err)
221 }
222
223 if err := os.WriteFile(c.path, data, 0o644); err != nil {
224 return fmt.Errorf("failed to write provider data to cache: %w", err)
225 }
226 return nil
227}