1package config
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "os"
11 "path/filepath"
12 "runtime"
13 "slices"
14 "strings"
15 "sync"
16 "time"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/catwalk/pkg/embedded"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/x/etag"
24)
25
// syncer abstracts a source that can produce a value of type V when asked,
// honoring the supplied context.
type syncer[V any] interface {
	// Get returns the current value or an error.
	Get(ctx context.Context) (V, error)
}
29
30var (
31 providerOnce sync.Once
32 providerList []catwalk.Provider
33 providerErr error
34)
35
36// file to cache provider data
37func cachePathFor(name string) string {
38 xdgDataHome := os.Getenv("XDG_DATA_HOME")
39 if xdgDataHome != "" {
40 return filepath.Join(xdgDataHome, appName, name+".json")
41 }
42
43 // return the path to the main data directory
44 // for windows, it should be in `%LOCALAPPDATA%/crush/`
45 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
46 if runtime.GOOS == "windows" {
47 localAppData := os.Getenv("LOCALAPPDATA")
48 if localAppData == "" {
49 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
50 }
51 return filepath.Join(localAppData, appName, name+".json")
52 }
53
54 return filepath.Join(home.Dir(), ".local", "share", appName, name+".json")
55}
56
57// UpdateProviders updates the Catwalk providers list from a specified source.
58func UpdateProviders(pathOrURL string) error {
59 var providers []catwalk.Provider
60 pathOrURL = cmp.Or(pathOrURL, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
61
62 switch {
63 case pathOrURL == "embedded":
64 providers = embedded.GetAll()
65 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
66 var err error
67 providers, err = catwalk.NewWithURL(pathOrURL).GetProviders(context.Background(), "")
68 if err != nil {
69 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
70 }
71 default:
72 content, err := os.ReadFile(pathOrURL)
73 if err != nil {
74 return fmt.Errorf("failed to read file: %w", err)
75 }
76 if err := json.Unmarshal(content, &providers); err != nil {
77 return fmt.Errorf("failed to unmarshal provider data: %w", err)
78 }
79 if len(providers) == 0 {
80 return fmt.Errorf("no providers found in the provided source")
81 }
82 }
83
84 if err := newCache[[]catwalk.Provider](cachePathFor("providers")).Store(providers); err != nil {
85 return fmt.Errorf("failed to save providers to cache: %w", err)
86 }
87
88 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor)
89 return nil
90}
91
92// UpdateHyper updates the Hyper provider information from a specified URL.
93func UpdateHyper(pathOrURL string) error {
94 if !hyper.Enabled() {
95 return fmt.Errorf("hyper not enabled")
96 }
97 var provider catwalk.Provider
98 pathOrURL = cmp.Or(pathOrURL, hyper.BaseURL())
99
100 switch {
101 case pathOrURL == "embedded":
102 provider = hyper.Embedded()
103 case strings.HasPrefix(pathOrURL, "http://") || strings.HasPrefix(pathOrURL, "https://"):
104 client := realHyperClient{baseURL: pathOrURL}
105 var err error
106 provider, err = client.Get(context.Background(), "")
107 if err != nil {
108 return fmt.Errorf("failed to fetch provider from Hyper: %w", err)
109 }
110 default:
111 content, err := os.ReadFile(pathOrURL)
112 if err != nil {
113 return fmt.Errorf("failed to read file: %w", err)
114 }
115 if err := json.Unmarshal(content, &provider); err != nil {
116 return fmt.Errorf("failed to unmarshal provider data: %w", err)
117 }
118 }
119
120 if err := newCache[catwalk.Provider](cachePathFor("hyper")).Store(provider); err != nil {
121 return fmt.Errorf("failed to save Hyper provider to cache: %w", err)
122 }
123
124 slog.Info("Hyper provider updated successfully", "from", pathOrURL, "to", cachePathFor("hyper"))
125 return nil
126}
127
128var (
129 catwalkSyncer = &catwalkSync{}
130 hyperSyncer = &hyperSync{}
131)
132
133// Providers returns the list of providers, taking into account cached results
134// and whether or not auto update is enabled.
135//
136// It will:
137// 1. if auto update is disabled, it'll return the embedded providers at the
138// time of release.
139// 2. load the cached providers
140// 3. try to get the fresh list of providers, and return either this new list,
141// the cached list, or the embedded list if all others fail.
142func Providers(cfg *Config) ([]catwalk.Provider, error) {
143 providerOnce.Do(func() {
144 var wg sync.WaitGroup
145 var errs []error
146 providers := csync.NewSlice[catwalk.Provider]()
147 autoupdate := !cfg.Options.DisableProviderAutoUpdate
148 customProvidersOnly := cfg.Options.DisableDefaultProviders
149
150 ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
151 defer cancel()
152
153 wg.Go(func() {
154 if customProvidersOnly {
155 return
156 }
157 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
158 client := catwalk.NewWithURL(catwalkURL)
159 path := cachePathFor("providers")
160 catwalkSyncer.Init(client, path, autoupdate)
161
162 items, err := catwalkSyncer.Get(ctx)
163 if err != nil {
164 catwalkURL := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
165 errs = append(errs, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help.\n\nCause: %w", catwalkURL, err)) //nolint:staticcheck
166 return
167 }
168 providers.Append(items...)
169 })
170
171 wg.Go(func() {
172 if customProvidersOnly || !hyper.Enabled() {
173 return
174 }
175 path := cachePathFor("hyper")
176 hyperSyncer.Init(realHyperClient{baseURL: hyper.BaseURL()}, path, autoupdate)
177
178 item, err := hyperSyncer.Get(ctx)
179 if err != nil {
180 errs = append(errs, fmt.Errorf("Crush was unable to fetch updated information from Hyper: %w", err)) //nolint:staticcheck
181 return
182 }
183 providers.Append(item)
184 })
185
186 wg.Wait()
187
188 providerList = slices.Collect(providers.Seq())
189 providerErr = errors.Join(errs...)
190 })
191 return providerList, providerErr
192}
193
194type cache[T any] struct {
195 path string
196}
197
198func newCache[T any](path string) cache[T] {
199 return cache[T]{path: path}
200}
201
202func (c cache[T]) Get() (T, string, error) {
203 var v T
204 data, err := os.ReadFile(c.path)
205 if err != nil {
206 return v, "", fmt.Errorf("failed to read provider cache file: %w", err)
207 }
208
209 if err := json.Unmarshal(data, &v); err != nil {
210 return v, "", fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
211 }
212
213 return v, etag.Of(data), nil
214}
215
216func (c cache[T]) Store(v T) error {
217 slog.Info("Saving provider data to disk", "path", c.path)
218 if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
219 return fmt.Errorf("failed to create directory for provider cache: %w", err)
220 }
221
222 data, err := json.Marshal(v)
223 if err != nil {
224 return fmt.Errorf("failed to marshal provider data: %w", err)
225 }
226
227 if err := os.WriteFile(c.path, data, 0o644); err != nil {
228 return fmt.Errorf("failed to write provider data to cache: %w", err)
229 }
230 return nil
231}