1package config
2
3import (
4 "cmp"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "sync"
13
14 "git.secluded.site/crush/internal/home"
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/catwalk/pkg/embedded"
17)
18
// ProviderClient is the source loadProviders fetches the provider list from.
// Providers supplies a client built with catwalk.NewWithURL; the interface
// presumably exists so a fake can be substituted in tests — TODO confirm.
type ProviderClient interface {
	// GetProviders returns the full provider list or an error.
	GetProviders() ([]catwalk.Provider, error)
}
22
// Process-wide memoisation for Providers: providerOnce guards a single
// resolution whose outcome is stored in providerList / providerErr and then
// returned unchanged by every call to Providers (errors included).
var (
	providerOnce sync.Once
	providerList []catwalk.Provider
	providerErr  error
)
28
29// file to cache provider data
30func providerCacheFileData() string {
31 xdgDataHome := os.Getenv("XDG_DATA_HOME")
32 if xdgDataHome != "" {
33 return filepath.Join(xdgDataHome, appName, "providers.json")
34 }
35
36 // return the path to the main data directory
37 // for windows, it should be in `%LOCALAPPDATA%/crush/`
38 // for linux and macOS, it should be in `$HOME/.local/share/crush/`
39 if runtime.GOOS == "windows" {
40 localAppData := os.Getenv("LOCALAPPDATA")
41 if localAppData == "" {
42 localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
43 }
44 return filepath.Join(localAppData, appName, "providers.json")
45 }
46
47 return filepath.Join(home.Dir(), ".local", "share", appName, "providers.json")
48}
49
50func saveProvidersInCache(path string, providers []catwalk.Provider) error {
51 slog.Info("Saving provider data to disk", "path", path)
52 if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
53 return fmt.Errorf("failed to create directory for provider cache: %w", err)
54 }
55
56 data, err := json.MarshalIndent(providers, "", " ")
57 if err != nil {
58 return fmt.Errorf("failed to marshal provider data: %w", err)
59 }
60
61 if err := os.WriteFile(path, data, 0o644); err != nil {
62 return fmt.Errorf("failed to write provider data to cache: %w", err)
63 }
64 return nil
65}
66
67func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
68 data, err := os.ReadFile(path)
69 if err != nil {
70 return nil, fmt.Errorf("failed to read provider cache file: %w", err)
71 }
72
73 var providers []catwalk.Provider
74 if err := json.Unmarshal(data, &providers); err != nil {
75 return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
76 }
77 return providers, nil
78}
79
80func UpdateProviders(pathOrUrl string) error {
81 var providers []catwalk.Provider
82 pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
83
84 switch {
85 case pathOrUrl == "embedded":
86 providers = embedded.GetAll()
87 case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
88 var err error
89 providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
90 if err != nil {
91 return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
92 }
93 default:
94 content, err := os.ReadFile(pathOrUrl)
95 if err != nil {
96 return fmt.Errorf("failed to read file: %w", err)
97 }
98 if err := json.Unmarshal(content, &providers); err != nil {
99 return fmt.Errorf("failed to unmarshal provider data: %w", err)
100 }
101 if len(providers) == 0 {
102 return fmt.Errorf("no providers found in the provided source")
103 }
104 }
105
106 cachePath := providerCacheFileData()
107 if err := saveProvidersInCache(cachePath, providers); err != nil {
108 return fmt.Errorf("failed to save providers to cache: %w", err)
109 }
110
111 slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
112 return nil
113}
114
115func Providers(cfg *Config) ([]catwalk.Provider, error) {
116 providerOnce.Do(func() {
117 catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
118 client := catwalk.NewWithURL(catwalkURL)
119 path := providerCacheFileData()
120
121 autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
122 providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
123 })
124 return providerList, providerErr
125}
126
127func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
128 catwalkGetAndSave := func() ([]catwalk.Provider, error) {
129 providers, err := client.GetProviders()
130 if err != nil {
131 return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
132 }
133 if len(providers) == 0 {
134 return nil, fmt.Errorf("empty providers list from catwalk")
135 }
136 if err := saveProvidersInCache(path, providers); err != nil {
137 return nil, err
138 }
139 return providers, nil
140 }
141
142 switch {
143 case autoUpdateDisabled:
144 slog.Warn("Providers auto-update is disabled")
145
146 if _, err := os.Stat(path); err == nil {
147 slog.Warn("Using locally cached providers")
148 return loadProvidersFromCache(path)
149 }
150
151 slog.Warn("Saving embedded providers to cache")
152 providers := embedded.GetAll()
153 if err := saveProvidersInCache(path, providers); err != nil {
154 return nil, err
155 }
156 return providers, nil
157
158 default:
159 slog.Info("Fetching providers from Catwalk.", "path", path)
160
161 providers, err := catwalkGetAndSave()
162 if err != nil {
163 catwalkUrl := fmt.Sprintf("%s/v2/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
164 return nil, fmt.Errorf("Crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use the embedded providers bundled at the time of this Crush release. You can also update providers manually. For more info see crush update-providers --help. %w", catwalkUrl, err) //nolint:staticcheck
165 }
166 return providers, nil
167 }
168}