From 230e4de83c23292ab42598dbffa327405bcb3c42 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Wed, 10 Sep 2025 17:44:04 -0300 Subject: [PATCH] feat: add ability to disable providers auto-update from catwalk --- internal/config/config.go | 15 ++- internal/config/load.go | 15 ++- internal/config/provider.go | 120 ++++++++++++------ internal/config/provider_empty_test.go | 6 +- internal/config/provider_test.go | 6 +- internal/tui/components/chat/splash/splash.go | 3 +- .../tui/components/dialogs/models/list.go | 5 +- .../tui/components/dialogs/models/models.go | 3 +- schema.json | 5 + 9 files changed, 116 insertions(+), 62 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 4e42a56e361c81feca31cd95bd778d14c312cd20..17ed626838cb555db163ee6c4db47d9d1be61b2a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -137,13 +137,14 @@ type Permissions struct { } type Options struct { - ContextPaths []string `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"` - TUI *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"` - Debug bool `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"` - DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"` - DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"` - DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd - DisabledTools []string `json:"disabled_tools" jsonschema:"description=Tools to disable"` + ContextPaths []string `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"` + TUI *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"` + Debug bool `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"` + DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"` + DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"` + DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd + DisabledTools []string `json:"disabled_tools" jsonschema:"description=Tools to disable"` + DisableProviderAutoUpdate bool `json:"disable_provider_auto_update,omitempty" jsonschema:"description=Disable providers auto-update,default=false"` } type MCPs map[string]MCPConfig diff --git a/internal/config/load.go b/internal/config/load.go index a703a049c7697be9209d3994c857ff0548f60b8b..9e1c9d0f7b739d7d6bdd974657b6efb5ea52d2ee 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "slices" + "strconv" "strings" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -66,9 +67,9 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { ) // Load known providers, this loads the config from catwalk - providers, err := Providers() - if err != nil || len(providers) == 0 { - return nil, fmt.Errorf("failed to load providers: %w", err) + providers, err := Providers(cfg) + if err != nil { + return nil, err } cfg.knownProviders = providers @@ -76,7 +77,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { // Configure providers valueResolver := NewShellVariableResolver(env) cfg.resolver = valueResolver - if err := cfg.configureProviders(env, valueResolver, providers); err != nil { + if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } @@ -85,7 +86,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) { return cfg, nil } - if err := cfg.configureSelectedModels(providers); err != nil { + if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } cfg.SetupAgents() @@ -340,6 +341,10 @@ func (c *Config) setDefaults(workingDir, dataDir string) { c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...) slices.Sort(c.Options.ContextPaths) c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths) + + if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok { + c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str) + } } var defaultLSPFileTypes = map[string][]string{ diff --git a/internal/config/provider.go b/internal/config/provider.go index 68ede5095506b21dc4d744e309aaa836917345e5..2248c8949a9880a4f555db8c2c5098742a5772b0 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -12,6 +12,7 @@ import ( "time" "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/catwalk/pkg/embedded" "github.com/charmbracelet/crush/internal/home" ) @@ -22,6 +23,7 @@ type ProviderClient interface { var ( providerOnce sync.Once providerList []catwalk.Provider + providerErr error ) // file to cache provider data @@ -75,55 +77,93 @@ func loadProvidersFromCache(path string) ([]catwalk.Provider, error) { return providers, nil } -func Providers() ([]catwalk.Provider, error) { - catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) - client := catwalk.NewWithURL(catwalkURL) - path := providerCacheFileData() - return loadProvidersOnce(client, path) -} - -func loadProvidersOnce(client ProviderClient, path string) ([]catwalk.Provider, error) { - var err error +func Providers(cfg *Config) ([]catwalk.Provider, error) { providerOnce.Do(func() { - providerList, err = loadProviders(client, path) + catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL) + client := catwalk.NewWithURL(catwalkURL) + path := providerCacheFileData() + + autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate + providerList, providerErr = loadProviders(autoUpdateDisabled, client, path) }) - if err != nil { - return nil, err - } - return providerList, nil + return providerList, providerErr } -func loadProviders(client ProviderClient, path string) (providerList []catwalk.Provider, err error) { - // if cache is not stale, load from it - stale, exists := isCacheStale(path) - if !stale { - slog.Info("Using cached provider data", "path", path) - providerList, err = loadProvidersFromCache(path) - if len(providerList) > 0 && err == nil { - go func() { - slog.Info("Updating provider cache in background", "path", path) - updated, uerr := client.GetProviders() - if len(updated) > 0 && uerr == nil { - _ = saveProvidersInCache(path, updated) - } - }() - return +func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) { + cacheIsStale, cacheExists := isCacheStale(path) + + catwalkGetAndSave := func() ([]catwalk.Provider, error) { + providers, err := client.GetProviders() + if err != nil { + return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err) + } + if len(providers) == 0 { + return nil, fmt.Errorf("empty providers list from catwalk") + } + if err := saveProvidersInCache(path, providers); err != nil { + return nil, err } + return providers, nil } - slog.Info("Getting live provider data", "path", path) - providerList, err = client.GetProviders() - if len(providerList) > 0 && err == nil { - err = saveProvidersInCache(path, providerList) - return + backgroundCacheUpdate := func() { + go func() { + slog.Info("Updating providers cache in background", "path", path) + + providers, err := client.GetProviders() + if err != nil { + slog.Error("Failed to fetch providers in background from Catwalk", "error", err) + return + } + if len(providers) == 0 { + slog.Error("Empty providers list from Catwalk") + return + } + if err := saveProvidersInCache(path, providers); err != nil { + slog.Error("Failed to update providers.json in background", "error", err) + } + }() } - if !exists { - err = fmt.Errorf("failed to load providers") - return + + switch { + case autoUpdateDisabled: + slog.Warn("Providers auto-update is disabled") + + if cacheExists { + slog.Warn("Using locally cached providers") + return loadProvidersFromCache(path) + } + + slog.Warn("Saving embedded providers to cache") + providers := embedded.GetAll() + if err := saveProvidersInCache(path, providers); err != nil { + return nil, err + } + return providers, nil + + case cacheExists && !cacheIsStale: + slog.Info("Recent providers cache is available.", "path", path) + + providers, err := loadProvidersFromCache(path) + if err != nil { + return nil, err + } + if len(providers) == 0 { + return catwalkGetAndSave() + } + backgroundCacheUpdate() + return providers, nil + + default: + slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path) + + providers, err := catwalkGetAndSave() + if err != nil { + catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)) + return nil, fmt.Errorf("crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use embedded version from the time of this Crush release. %w", catwalkUrl, err) + } + return providers, nil } - slog.Info("Loading provider data from cache", "path", path) - providerList, err = loadProvidersFromCache(path) - return } func isCacheStale(path string) (stale, exists bool) { diff --git a/internal/config/provider_empty_test.go b/internal/config/provider_empty_test.go index cb71cabfa5a01cb16b6ef2b6708d1780e31951a9..3cd55ae7921171a580dccc91aa1d22d2f7934271 100644 --- a/internal/config/provider_empty_test.go +++ b/internal/config/provider_empty_test.go @@ -19,8 +19,8 @@ func TestProvider_loadProvidersEmptyResult(t *testing.T) { client := &emptyProviderClient{} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, tmpPath) - require.EqualError(t, err, "failed to load providers") + providers, err := loadProviders(false, client, tmpPath) + require.Contains(t, err.Error(), "crush was unable to fetch an updated list of providers") require.Empty(t, providers) require.Len(t, providers, 0) @@ -39,7 +39,7 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) { require.NoError(t, os.WriteFile(tmpPath, data, 0o644)) // Should refresh and get real providers instead of using empty cache - providers, err := loadProviders(client, tmpPath) + providers, err := loadProviders(false, client, tmpPath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index ed2568d68a840798872af60c5132707e84a5cbbf..8b499919bca666915a89d38c1e5014a911f4d2d1 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) { func TestProvider_loadProvidersNoIssues(t *testing.T) { client := &mockProviderClient{shouldFail: false} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, tmpPath) + providers, err := loadProviders(false, client, tmpPath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) @@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { if err != nil { t.Fatalf("Failed to write old providers to file: %v", err) } - providers, err := loadProviders(client, tmpPath) + providers, err := loadProviders(false, client, tmpPath) require.NoError(t, err) require.NotNil(t, providers) require.Len(t, providers, 1) @@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) { func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { client := &mockProviderClient{shouldFail: true} tmpPath := t.TempDir() + "/providers.json" - providers, err := loadProviders(client, tmpPath) + providers, err := loadProviders(false, client, tmpPath) require.Error(t, err) require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") } diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 7fa46cdd279a2cbe98a86654a23e81a49bc8aebf..b49bd862876f6b3eb880bfe732b956026421aabe 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -397,7 +397,8 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd { } func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { - providers, err := config.Providers() + cfg := config.Get() + providers, err := config.Providers(cfg) if err != nil { return nil, err } diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go index 66b55d85b299cb0bacb4cc2466c7b4146248ba05..77398c4d17d85126ab155a9e9c5b2085c0691672 100644 --- a/internal/tui/components/dialogs/models/list.go +++ b/internal/tui/components/dialogs/models/list.go @@ -49,7 +49,8 @@ func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldRe func (m *ModelListComponent) Init() tea.Cmd { var cmds []tea.Cmd if len(m.providers) == 0 { - providers, err := config.Providers() + cfg := config.Get() + providers, err := config.Providers(cfg) filteredProviders := []catwalk.Provider{} for _, p := range providers { hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$") @@ -119,7 +120,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { // First, add any configured providers that are not in the known providers list // These should appear at the top of the list - knownProviders, err := config.Providers() + knownProviders, err := config.Providers(cfg) if err != nil { return util.ReportError(err) } diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index 56d9eac17c277e8cbbb7c4349bbf420c56fb8610..7c2863706c29180cffcfb88c385a012e39df464c 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -352,7 +352,8 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { } func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) { - providers, err := config.Providers() + cfg := config.Get() + providers, err := config.Providers(cfg) if err != nil { return nil, err } diff --git a/schema.json b/schema.json index 060f9738884da739a186898d859ac5618c35b5b8..9dee9055050c8e29fb689e9700b33aa8e9842cd2 100644 --- a/schema.json +++ b/schema.json @@ -278,6 +278,11 @@ }, "type": "array", "description": "Tools to disable" + }, + "disable_provider_auto_update": { + "type": "boolean", + "description": "Disable providers auto-update", + "default": false } }, "additionalProperties": false,