feat: add `crush update-providers` command

Andrey Nering created

Change summary

internal/cmd/root.go             |  1 
internal/cmd/update_providers.go | 60 ++++++++++++++++++++++++++++++++++
internal/config/provider.go      | 36 ++++++++++++++++++++
3 files changed, 97 insertions(+)

Detailed changes

internal/cmd/root.go 🔗

@@ -28,6 +28,7 @@ func init() {
 	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 
 	rootCmd.AddCommand(runCmd)
+	rootCmd.AddCommand(updateProvidersCmd)
 }
 
 var rootCmd = &cobra.Command{

internal/cmd/update_providers.go 🔗

@@ -0,0 +1,60 @@
+package cmd
+
+import (
+	"fmt"
+	"log/slog"
+
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/lipgloss/v2"
+	"github.com/charmbracelet/x/exp/charmtone"
+	"github.com/spf13/cobra"
+)
+
+var updateProvidersCmd = &cobra.Command{
+	Use:   "update-providers [path-or-url]",
+	Short: "Update providers",
+	Long:  `Update the list of providers from a specified local path or remote URL.`,
+	Example: `
+# Update providers remotely from Catwalk
+crush update-providers
+
+# Update providers from a custom URL
+crush update-providers https://example.com/
+
+# Update providers from a local file
+crush update-providers /path/to/local-providers.json
+
+# Update providers from embedded version
+crush update-providers embedded
+`,
+	RunE: func(cmd *cobra.Command, args []string) error {
+		// NOTE(@andreynering): We want to skip logging output do stdout here.
+		slog.SetDefault(slog.New(slog.DiscardHandler))
+
+		var pathOrUrl string
+		if len(args) > 0 {
+			pathOrUrl = args[0]
+		}
+
+		if err := config.UpdateProviders(pathOrUrl); err != nil {
+			return err
+		}
+
+		// NOTE(@andreynering): This style is more-or-less copied from Fang's
+		// error message, adapted for success.
+		headerStyle := lipgloss.NewStyle().
+			Foreground(charmtone.Butter).
+			Background(charmtone.Guac).
+			Bold(true).
+			Padding(0, 1).
+			Margin(1).
+			MarginLeft(2).
+			SetString("SUCCESS")
+		textStyle := lipgloss.NewStyle().
+			MarginLeft(2).
+			SetString("Providers updated successfully.")
+
+		fmt.Printf("%s\n%s\n\n", headerStyle.Render(), textStyle.Render())
+		return nil
+	},
+}

internal/config/provider.go 🔗

@@ -8,6 +8,7 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
+	"strings"
 	"sync"
 	"time"
 
@@ -77,6 +78,41 @@ func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
 	return providers, nil
 }
 
+func UpdateProviders(pathOrUrl string) error {
+	var providers []catwalk.Provider
+	pathOrUrl = cmp.Or(pathOrUrl, os.Getenv("CATWALK_URL"), defaultCatwalkURL)
+
+	switch {
+	case pathOrUrl == "embedded":
+		providers = embedded.GetAll()
+	case strings.HasPrefix(pathOrUrl, "http://") || strings.HasPrefix(pathOrUrl, "https://"):
+		var err error
+		providers, err = catwalk.NewWithURL(pathOrUrl).GetProviders()
+		if err != nil {
+			return fmt.Errorf("failed to fetch providers from Catwalk: %w", err)
+		}
+	default:
+		content, err := os.ReadFile(pathOrUrl)
+		if err != nil {
+			return fmt.Errorf("failed to read file: %w", err)
+		}
+		if err := json.Unmarshal(content, &providers); err != nil {
+			return fmt.Errorf("failed to unmarshal provider data: %w", err)
+		}
+		if len(providers) == 0 {
+			return fmt.Errorf("no providers found in the provided source")
+		}
+	}
+
+	cachePath := providerCacheFileData()
+	if err := saveProvidersInCache(cachePath, providers); err != nil {
+		return fmt.Errorf("failed to save providers to cache: %w", err)
+	}
+
+	slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrUrl, "to", cachePath)
+	return nil
+}
+
 func Providers(cfg *Config) ([]catwalk.Provider, error) {
 	providerOnce.Do(func() {
 		catwalkURL := cmp.Or(os.Getenv("CATWALK_URL"), defaultCatwalkURL)