Detailed changes
@@ -77,7 +77,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
// NOTE(tauraamui): current entrypoint for invoking providers fetch for
// the rest of the app as a whole.
// Load known providers, this loads the config from catwalk
- providers, err := Providers(cfg)
+ providers, err := Providers(cfg.Options.DisableProviderAutoUpdate)
if err != nil {
return nil, err
}
@@ -119,18 +119,19 @@ func UpdateProviders(pathOrUrl string) error {
return nil
}
-func Providers(cfg *Config) ([]catwalk.Provider, error) {
- providerMu.Lock()
- if !initialized {
- catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
- client := catwalk.NewWithURL(catwalkURL)
- path := providerCacheFileData()
+func Providers(autoUpdateDisabled bool) ([]catwalk.Provider, error) {
+ catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
+ client := catwalk.NewWithURL(catwalkURL)
+ return ProvidersWithClient(autoUpdateDisabled, client, providerCacheFileData())
+}
- autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
- providerList, providerErr = loadProviders(autoUpdateDisabled, client, path, cfg)
+func ProvidersWithClient(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
+ if !initialized {
+ providerMu.Lock()
+ providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
initialized = true
+ providerMu.Unlock()
}
- providerMu.Unlock()
providerMu.RLock()
defer providerMu.RUnlock()
@@ -156,7 +157,7 @@ func reloadProviders(path string) {
slog.Info("Providers reloaded successfully", "count", len(providers))
}
-func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string, cfg *Config) ([]catwalk.Provider, error) {
+func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
cacheIsStale, cacheExists := isCacheStale(path)
catwalkGetAndSave := func() ([]catwalk.Provider, error) {
@@ -19,8 +19,7 @@ func TestProvider_loadProvidersEmptyResult(t *testing.T) {
client := &emptyProviderClient{}
tmpPath := t.TempDir() + "/providers.json"
- cfg := &Config{}
- providers, err := loadProviders(false, client, tmpPath, cfg)
+ providers, err := loadProviders(false, client, tmpPath)
require.Contains(t, err.Error(), "Crush was unable to fetch an updated list of providers")
require.Empty(t, providers)
require.Len(t, providers, 0)
@@ -40,8 +39,7 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) {
require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
// Should refresh and get real providers instead of using empty cache
- cfg := &Config{}
- providers, err := loadProviders(false, client, tmpPath, cfg)
+ providers, err := loadProviders(false, client, tmpPath)
require.NoError(t, err)
require.NotNil(t, providers)
require.Len(t, providers, 1)
@@ -399,7 +399,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
cfg := config.Get()
- providers, err := config.Providers(cfg)
+ providers, err := config.Providers(cfg.Options.DisableProviderAutoUpdate)
if err != nil {
return nil, err
}
@@ -50,7 +50,7 @@ func (m *ModelListComponent) Init() tea.Cmd {
var cmds []tea.Cmd
if len(m.providers) == 0 {
cfg := config.Get()
- providers, err := config.Providers(cfg)
+ providers, err := config.Providers(cfg.Options.DisableProviderAutoUpdate)
filteredProviders := []catwalk.Provider{}
for _, p := range providers {
hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
@@ -120,7 +120,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
// First, add any configured providers that are not in the known providers list
// These should appear at the top of the list
- knownProviders, err := config.Providers(cfg)
+ knownProviders, err := config.Providers(cfg.Options.DisableProviderAutoUpdate)
if err != nil {
return util.ReportError(err)
}
@@ -353,7 +353,7 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
cfg := config.Get()
- providers, err := config.Providers(cfg)
+ providers, err := config.Providers(cfg.Options.DisableProviderAutoUpdate)
if err != nil {
return nil, err
}