feat: add ability to disable providers auto-update from catwalk

Andrey Nering created

Change summary

internal/config/config.go                        |  15 +-
internal/config/load.go                          |  15 +
internal/config/provider.go                      | 120 ++++++++++++------
internal/config/provider_empty_test.go           |   6 
internal/config/provider_test.go                 |   6 
internal/tui/components/chat/splash/splash.go    |   3 
internal/tui/components/dialogs/models/list.go   |   5 
internal/tui/components/dialogs/models/models.go |   3 
schema.json                                      |   5 
9 files changed, 116 insertions(+), 62 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -137,13 +137,14 @@ type Permissions struct {
 }
 
 type Options struct {
-	ContextPaths         []string    `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
-	TUI                  *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
-	Debug                bool        `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
-	DebugLSP             bool        `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
-	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
-	DataDirectory        string      `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
-	DisabledTools        []string    `json:"disabled_tools" jsonschema:"description=Tools to disable"`
+	ContextPaths              []string    `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
+	TUI                       *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
+	Debug                     bool        `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
+	DebugLSP                  bool        `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
+	DisableAutoSummarize      bool        `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
+	DataDirectory             string      `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
+	DisabledTools             []string    `json:"disabled_tools" jsonschema:"description=Tools to disable"`
+	DisableProviderAutoUpdate bool        `json:"disable_provider_auto_update,omitempty" jsonschema:"description=Disable providers auto-update,default=false"`
 }
 
 type MCPs map[string]MCPConfig

internal/config/load.go 🔗

@@ -10,6 +10,7 @@ import (
 	"path/filepath"
 	"runtime"
 	"slices"
+	"strconv"
 	"strings"
 
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -66,9 +67,9 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 	)
 
 	// Load known providers, this loads the config from catwalk
-	providers, err := Providers()
-	if err != nil || len(providers) == 0 {
-		return nil, fmt.Errorf("failed to load providers: %w", err)
+	providers, err := Providers(cfg)
+	if err != nil {
+		return nil, err
 	}
 	cfg.knownProviders = providers
 
@@ -76,7 +77,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 	// Configure providers
 	valueResolver := NewShellVariableResolver(env)
 	cfg.resolver = valueResolver
-	if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
+	if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil {
 		return nil, fmt.Errorf("failed to configure providers: %w", err)
 	}
 
@@ -85,7 +86,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
 		return cfg, nil
 	}
 
-	if err := cfg.configureSelectedModels(providers); err != nil {
+	if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil {
 		return nil, fmt.Errorf("failed to configure selected models: %w", err)
 	}
 	cfg.SetupAgents()
@@ -340,6 +341,10 @@ func (c *Config) setDefaults(workingDir, dataDir string) {
 	c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...)
 	slices.Sort(c.Options.ContextPaths)
 	c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
+
+	if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok {
+		c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str)
+	}
 }
 
 var defaultLSPFileTypes = map[string][]string{

internal/config/provider.go 🔗

@@ -12,6 +12,7 @@ import (
 	"time"
 
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/catwalk/pkg/embedded"
 	"github.com/charmbracelet/crush/internal/home"
 )
 
@@ -22,6 +23,7 @@ type ProviderClient interface {
 var (
 	providerOnce sync.Once
 	providerList []catwalk.Provider
+	providerErr  error
 )
 
 // file to cache provider data
@@ -75,55 +77,93 @@ func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
 	return providers, nil
 }
 
-func Providers() ([]catwalk.Provider, error) {
-	catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
-	client := catwalk.NewWithURL(catwalkURL)
-	path := providerCacheFileData()
-	return loadProvidersOnce(client, path)
-}
-
-func loadProvidersOnce(client ProviderClient, path string) ([]catwalk.Provider, error) {
-	var err error
+func Providers(cfg *Config) ([]catwalk.Provider, error) {
 	providerOnce.Do(func() {
-		providerList, err = loadProviders(client, path)
+		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
+		client := catwalk.NewWithURL(catwalkURL)
+		path := providerCacheFileData()
+
+		autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
+		providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
 	})
-	if err != nil {
-		return nil, err
-	}
-	return providerList, nil
+	return providerList, providerErr
 }
 
-func loadProviders(client ProviderClient, path string) (providerList []catwalk.Provider, err error) {
-	// if cache is not stale, load from it
-	stale, exists := isCacheStale(path)
-	if !stale {
-		slog.Info("Using cached provider data", "path", path)
-		providerList, err = loadProvidersFromCache(path)
-		if len(providerList) > 0 && err == nil {
-			go func() {
-				slog.Info("Updating provider cache in background", "path", path)
-				updated, uerr := client.GetProviders()
-				if len(updated) > 0 && uerr == nil {
-					_ = saveProvidersInCache(path, updated)
-				}
-			}()
-			return
+func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
+	cacheIsStale, cacheExists := isCacheStale(path)
+
+	catwalkGetAndSave := func() ([]catwalk.Provider, error) {
+		providers, err := client.GetProviders()
+		if err != nil {
+			return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
+		}
+		if len(providers) == 0 {
+			return nil, fmt.Errorf("empty providers list from catwalk")
+		}
+		if err := saveProvidersInCache(path, providers); err != nil {
+			return nil, err
 		}
+		return providers, nil
 	}
 
-	slog.Info("Getting live provider data", "path", path)
-	providerList, err = client.GetProviders()
-	if len(providerList) > 0 && err == nil {
-		err = saveProvidersInCache(path, providerList)
-		return
+	backgroundCacheUpdate := func() {
+		go func() {
+			slog.Info("Updating providers cache in background", "path", path)
+
+			providers, err := client.GetProviders()
+			if err != nil {
+				slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
+				return
+			}
+			if len(providers) == 0 {
+				slog.Error("Empty providers list from Catwalk")
+				return
+			}
+			if err := saveProvidersInCache(path, providers); err != nil {
+				slog.Error("Failed to update providers.json in background", "error", err)
+			}
+		}()
 	}
-	if !exists {
-		err = fmt.Errorf("failed to load providers")
-		return
+
+	switch {
+	case autoUpdateDisabled:
+		slog.Warn("Providers auto-update is disabled")
+
+		if cacheExists {
+			slog.Warn("Using locally cached providers")
+			return loadProvidersFromCache(path)
+		}
+
+		slog.Warn("Saving embedded providers to cache")
+		providers := embedded.GetAll()
+		if err := saveProvidersInCache(path, providers); err != nil {
+			return nil, err
+		}
+		return providers, nil
+
+	case cacheExists && !cacheIsStale:
+		slog.Info("Recent providers cache is available.", "path", path)
+
+		providers, err := loadProvidersFromCache(path)
+		if err != nil {
+			return nil, err
+		}
+		if len(providers) == 0 {
+			return catwalkGetAndSave()
+		}
+		backgroundCacheUpdate()
+		return providers, nil
+
+	default:
+		slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
+
+		providers, err := catwalkGetAndSave()
+		if err != nil {
+			catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
+			return nil, fmt.Errorf("crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use embedded version from the time of this Crush release. %w", catwalkUrl, err)
+		}
+		return providers, nil
 	}
-	slog.Info("Loading provider data from cache", "path", path)
-	providerList, err = loadProvidersFromCache(path)
-	return
 }
 
 func isCacheStale(path string) (stale, exists bool) {

internal/config/provider_empty_test.go 🔗

@@ -19,8 +19,8 @@ func TestProvider_loadProvidersEmptyResult(t *testing.T) {
 	client := &emptyProviderClient{}
 	tmpPath := t.TempDir() + "/providers.json"
 
-	providers, err := loadProviders(client, tmpPath)
-	require.EqualError(t, err, "failed to load providers")
+	providers, err := loadProviders(false, client, tmpPath)
+	require.Contains(t, err.Error(), "crush was unable to fetch an updated list of providers")
 	require.Empty(t, providers)
 	require.Len(t, providers, 0)
 
@@ -39,7 +39,7 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) {
 	require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
 
 	// Should refresh and get real providers instead of using empty cache
-	providers, err := loadProviders(client, tmpPath)
+	providers, err := loadProviders(false, client, tmpPath)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)

internal/config/provider_test.go 🔗

@@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
 func TestProvider_loadProvidersNoIssues(t *testing.T) {
 	client := &mockProviderClient{shouldFail: false}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(client, tmpPath)
+	providers, err := loadProviders(false, client, tmpPath)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)
@@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to write old providers to file: %v", err)
 	}
-	providers, err := loadProviders(client, tmpPath)
+	providers, err := loadProviders(false, client, tmpPath)
 	require.NoError(t, err)
 	require.NotNil(t, providers)
 	require.Len(t, providers, 1)
@@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
 func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
 	client := &mockProviderClient{shouldFail: true}
 	tmpPath := t.TempDir() + "/providers.json"
-	providers, err := loadProviders(client, tmpPath)
+	providers, err := loadProviders(false, client, tmpPath)
 	require.Error(t, err)
 	require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
 }

internal/tui/components/chat/splash/splash.go 🔗

@@ -397,7 +397,8 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
 }
 
 func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
-	providers, err := config.Providers()
+	cfg := config.Get()
+	providers, err := config.Providers(cfg)
 	if err != nil {
 		return nil, err
 	}

internal/tui/components/dialogs/models/list.go 🔗

@@ -49,7 +49,8 @@ func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldRe
 func (m *ModelListComponent) Init() tea.Cmd {
 	var cmds []tea.Cmd
 	if len(m.providers) == 0 {
-		providers, err := config.Providers()
+		cfg := config.Get()
+		providers, err := config.Providers(cfg)
 		filteredProviders := []catwalk.Provider{}
 		for _, p := range providers {
 			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
@@ -119,7 +120,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
 
 	// First, add any configured providers that are not in the known providers list
 	// These should appear at the top of the list
-	knownProviders, err := config.Providers()
+	knownProviders, err := config.Providers(cfg)
 	if err != nil {
 		return util.ReportError(err)
 	}

internal/tui/components/dialogs/models/models.go 🔗

@@ -352,7 +352,8 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
 }
 
 func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
-	providers, err := config.Providers()
+	cfg := config.Get()
+	providers, err := config.Providers(cfg)
 	if err != nil {
 		return nil, err
 	}

schema.json 🔗

@@ -278,6 +278,11 @@
           },
           "type": "array",
           "description": "Tools to disable"
+        },
+        "disable_provider_auto_update": {
+          "type": "boolean",
+          "description": "Disable providers auto-update",
+          "default": false
         }
       },
       "additionalProperties": false,