models.go

  1package cmd
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"sort"
  8	"strings"
  9
 10	"charm.land/catwalk/pkg/catwalk"
 11	"charm.land/lipgloss/v2/tree"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/mattn/go-isatty"
 14	"github.com/spf13/cobra"
 15)
 16
 17var modelsCmd = &cobra.Command{
 18	Use:   "models",
 19	Short: "List all available models from configured providers",
 20	Long:  `List all available models from configured providers. Shows provider name and model IDs.`,
 21	Example: `# List all available models
 22crush models
 23
 24# Search models
 25crush models gpt5`,
 26	Args: cobra.ArbitraryArgs,
 27	RunE: func(cmd *cobra.Command, args []string) error {
 28		cwd, err := ResolveCwd(cmd)
 29		if err != nil {
 30			return err
 31		}
 32
 33		dataDir, _ := cmd.Flags().GetString("data-dir")
 34		debug, _ := cmd.Flags().GetBool("debug")
 35
 36		cfg, err := config.Init(cwd, dataDir, debug)
 37		if err != nil {
 38			return err
 39		}
 40
 41		if !cfg.IsConfigured() {
 42			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
 43		}
 44
 45		term := strings.ToLower(strings.Join(args, " "))
 46		filter := func(p config.ProviderConfig, m catwalk.Model) bool {
 47			for _, s := range []string{p.ID, p.Name, m.ID, m.Name} {
 48				if term == "" || strings.Contains(strings.ToLower(s), term) {
 49					return true
 50				}
 51			}
 52			return false
 53		}
 54
 55		var providerIDs []string
 56		providerModels := make(map[string][]string)
 57
 58		for providerID, provider := range cfg.Providers.Seq2() {
 59			if provider.Disable {
 60				continue
 61			}
 62			var found bool
 63			for _, model := range provider.Models {
 64				if !filter(provider, model) {
 65					continue
 66				}
 67				providerModels[providerID] = append(providerModels[providerID], model.ID)
 68				found = true
 69			}
 70			if !found {
 71				continue
 72			}
 73			slices.Sort(providerModels[providerID])
 74			providerIDs = append(providerIDs, providerID)
 75		}
 76		sort.Strings(providerIDs)
 77
 78		if len(providerIDs) == 0 && len(args) == 0 {
 79			return fmt.Errorf("no enabled providers found")
 80		}
 81		if len(providerIDs) == 0 {
 82			return fmt.Errorf("no enabled providers found matching %q", term)
 83		}
 84
 85		if !isatty.IsTerminal(os.Stdout.Fd()) {
 86			for _, providerID := range providerIDs {
 87				for _, modelID := range providerModels[providerID] {
 88					fmt.Println(providerID + "/" + modelID)
 89				}
 90			}
 91			return nil
 92		}
 93
 94		t := tree.New()
 95		for _, providerID := range providerIDs {
 96			providerNode := tree.Root(providerID)
 97			for _, modelID := range providerModels[providerID] {
 98				providerNode.Child(modelID)
 99			}
100			t.Child(providerNode)
101		}
102
103		cmd.Println(t)
104		return nil
105	},
106}
107
108func init() {
109	rootCmd.AddCommand(modelsCmd)
110}