Detailed changes
@@ -137,13 +137,14 @@ type Permissions struct {
}
type Options struct {
- ContextPaths []string `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
- TUI *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
- Debug bool `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
- DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
- DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
- DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
- DisabledTools []string `json:"disabled_tools" jsonschema:"description=Tools to disable"`
+ ContextPaths []string `json:"context_paths,omitempty" jsonschema:"description=Paths to files containing context information for the AI,example=.cursorrules,example=CRUSH.md"`
+ TUI *TUIOptions `json:"tui,omitempty" jsonschema:"description=Terminal user interface options"`
+ Debug bool `json:"debug,omitempty" jsonschema:"description=Enable debug logging,default=false"`
+ DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"description=Enable debug logging for LSP servers,default=false"`
+ DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"description=Disable automatic conversation summarization,default=false"`
+ DataDirectory string `json:"data_directory,omitempty" jsonschema:"description=Directory for storing application data (relative to working directory),default=.crush,example=.crush"` // Relative to the cwd
+ DisabledTools []string `json:"disabled_tools" jsonschema:"description=Tools to disable"`
+ DisableProviderAutoUpdate bool `json:"disable_provider_auto_update,omitempty" jsonschema:"description=Disable providers auto-update,default=false"`
}
type MCPs map[string]MCPConfig
@@ -10,6 +10,7 @@ import (
"path/filepath"
"runtime"
"slices"
+ "strconv"
"strings"
"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -66,9 +67,9 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
)
// Load known providers, this loads the config from catwalk
- providers, err := Providers()
- if err != nil || len(providers) == 0 {
- return nil, fmt.Errorf("failed to load providers: %w", err)
+ providers, err := Providers(cfg)
+ if err != nil {
+ return nil, err
}
cfg.knownProviders = providers
@@ -76,7 +77,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
// Configure providers
valueResolver := NewShellVariableResolver(env)
cfg.resolver = valueResolver
- if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
+ if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil {
return nil, fmt.Errorf("failed to configure providers: %w", err)
}
@@ -85,7 +86,7 @@ func Load(workingDir, dataDir string, debug bool) (*Config, error) {
return cfg, nil
}
- if err := cfg.configureSelectedModels(providers); err != nil {
+ if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil {
return nil, fmt.Errorf("failed to configure selected models: %w", err)
}
cfg.SetupAgents()
@@ -340,6 +341,10 @@ func (c *Config) setDefaults(workingDir, dataDir string) {
c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...)
slices.Sort(c.Options.ContextPaths)
c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
+
+ if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok {
+ c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str)
+ }
}
var defaultLSPFileTypes = map[string][]string{
@@ -12,6 +12,7 @@ import (
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/catwalk/pkg/embedded"
"github.com/charmbracelet/crush/internal/home"
)
@@ -22,6 +23,7 @@ type ProviderClient interface {
var (
providerOnce sync.Once
providerList []catwalk.Provider
+ providerErr error
)
// file to cache provider data
@@ -75,55 +77,93 @@ func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
return providers, nil
}
-func Providers() ([]catwalk.Provider, error) {
- catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
- client := catwalk.NewWithURL(catwalkURL)
- path := providerCacheFileData()
- return loadProvidersOnce(client, path)
-}
-
-func loadProvidersOnce(client ProviderClient, path string) ([]catwalk.Provider, error) {
- var err error
+func Providers(cfg *Config) ([]catwalk.Provider, error) {
providerOnce.Do(func() {
- providerList, err = loadProviders(client, path)
+ catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)
+ client := catwalk.NewWithURL(catwalkURL)
+ path := providerCacheFileData()
+
+ autoUpdateDisabled := cfg.Options.DisableProviderAutoUpdate
+ providerList, providerErr = loadProviders(autoUpdateDisabled, client, path)
})
- if err != nil {
- return nil, err
- }
- return providerList, nil
+ return providerList, providerErr
}
-func loadProviders(client ProviderClient, path string) (providerList []catwalk.Provider, err error) {
- // if cache is not stale, load from it
- stale, exists := isCacheStale(path)
- if !stale {
- slog.Info("Using cached provider data", "path", path)
- providerList, err = loadProvidersFromCache(path)
- if len(providerList) > 0 && err == nil {
- go func() {
- slog.Info("Updating provider cache in background", "path", path)
- updated, uerr := client.GetProviders()
- if len(updated) > 0 && uerr == nil {
- _ = saveProvidersInCache(path, updated)
- }
- }()
- return
+func loadProviders(autoUpdateDisabled bool, client ProviderClient, path string) ([]catwalk.Provider, error) {
+ cacheIsStale, cacheExists := isCacheStale(path)
+
+ catwalkGetAndSave := func() ([]catwalk.Provider, error) {
+ providers, err := client.GetProviders()
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch providers from catwalk: %w", err)
+ }
+ if len(providers) == 0 {
+ return nil, fmt.Errorf("empty providers list from catwalk")
+ }
+ if err := saveProvidersInCache(path, providers); err != nil {
+ return nil, err
}
+ return providers, nil
}
- slog.Info("Getting live provider data", "path", path)
- providerList, err = client.GetProviders()
- if len(providerList) > 0 && err == nil {
- err = saveProvidersInCache(path, providerList)
- return
+ backgroundCacheUpdate := func() {
+ go func() {
+ slog.Info("Updating providers cache in background", "path", path)
+
+ providers, err := client.GetProviders()
+ if err != nil {
+ slog.Error("Failed to fetch providers in background from Catwalk", "error", err)
+ return
+ }
+ if len(providers) == 0 {
+ slog.Error("Empty providers list from Catwalk")
+ return
+ }
+ if err := saveProvidersInCache(path, providers); err != nil {
+ slog.Error("Failed to update providers.json in background", "error", err)
+ }
+ }()
}
- if !exists {
- err = fmt.Errorf("failed to load providers")
- return
+
+ switch {
+ case autoUpdateDisabled:
+ slog.Warn("Providers auto-update is disabled")
+
+ if cacheExists {
+ slog.Warn("Using locally cached providers")
+ return loadProvidersFromCache(path)
+ }
+
+ slog.Warn("Saving embedded providers to cache")
+ providers := embedded.GetAll()
+ if err := saveProvidersInCache(path, providers); err != nil {
+ return nil, err
+ }
+ return providers, nil
+
+ case cacheExists && !cacheIsStale:
+ slog.Info("Recent providers cache is available.", "path", path)
+
+ providers, err := loadProvidersFromCache(path)
+ if err != nil {
+ return nil, err
+ }
+ if len(providers) == 0 {
+ return catwalkGetAndSave()
+ }
+ backgroundCacheUpdate()
+ return providers, nil
+
+ default:
+ slog.Info("Cache is not available or is stale. Fetching providers from Catwalk.", "path", path)
+
+ providers, err := catwalkGetAndSave()
+ if err != nil {
+ catwalkUrl := fmt.Sprintf("%s/providers", cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL))
+ return nil, fmt.Errorf("crush was unable to fetch an updated list of providers from %s. Consider setting CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1 to use embedded version from the time of this Crush release. %w", catwalkUrl, err)
+ }
+ return providers, nil
}
- slog.Info("Loading provider data from cache", "path", path)
- providerList, err = loadProvidersFromCache(path)
- return
}
func isCacheStale(path string) (stale, exists bool) {
@@ -19,8 +19,8 @@ func TestProvider_loadProvidersEmptyResult(t *testing.T) {
client := &emptyProviderClient{}
tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, tmpPath)
- require.EqualError(t, err, "failed to load providers")
+ providers, err := loadProviders(false, client, tmpPath)
+ require.Contains(t, err.Error(), "crush was unable to fetch an updated list of providers")
require.Empty(t, providers)
require.Len(t, providers, 0)
@@ -39,7 +39,7 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) {
require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
// Should refresh and get real providers instead of using empty cache
- providers, err := loadProviders(client, tmpPath)
+ providers, err := loadProviders(false, client, tmpPath)
require.NoError(t, err)
require.NotNil(t, providers)
require.Len(t, providers, 1)
@@ -28,7 +28,7 @@ func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
func TestProvider_loadProvidersNoIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: false}
tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, tmpPath)
+ providers, err := loadProviders(false, client, tmpPath)
require.NoError(t, err)
require.NotNil(t, providers)
require.Len(t, providers, 1)
@@ -57,7 +57,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
if err != nil {
t.Fatalf("Failed to write old providers to file: %v", err)
}
- providers, err := loadProviders(client, tmpPath)
+ providers, err := loadProviders(false, client, tmpPath)
require.NoError(t, err)
require.NotNil(t, providers)
require.Len(t, providers, 1)
@@ -67,7 +67,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
client := &mockProviderClient{shouldFail: true}
tmpPath := t.TempDir() + "/providers.json"
- providers, err := loadProviders(client, tmpPath)
+ providers, err := loadProviders(false, client, tmpPath)
require.Error(t, err)
require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
}
@@ -397,7 +397,8 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
}
func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
- providers, err := config.Providers()
+ cfg := config.Get()
+ providers, err := config.Providers(cfg)
if err != nil {
return nil, err
}
@@ -49,7 +49,8 @@ func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldRe
func (m *ModelListComponent) Init() tea.Cmd {
var cmds []tea.Cmd
if len(m.providers) == 0 {
- providers, err := config.Providers()
+ cfg := config.Get()
+ providers, err := config.Providers(cfg)
filteredProviders := []catwalk.Provider{}
for _, p := range providers {
hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
@@ -119,7 +120,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
// First, add any configured providers that are not in the known providers list
// These should appear at the top of the list
- knownProviders, err := config.Providers()
+ knownProviders, err := config.Providers(cfg)
if err != nil {
return util.ReportError(err)
}
@@ -352,7 +352,8 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
}
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
- providers, err := config.Providers()
+ cfg := config.Get()
+ providers, err := config.Providers(cfg)
if err != nil {
return nil, err
}
@@ -278,6 +278,11 @@
},
"type": "array",
"description": "Tools to disable"
+ },
+ "disable_provider_auto_update": {
+ "type": "boolean",
+ "description": "Disable providers auto-update",
+ "default": false
}
},
"additionalProperties": false,