feat: crush run --model, and crush models (#1889)

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

go.mod                        |   3 
go.sum                        |   2 
internal/app/app.go           |  98 +++++++++++++++++
internal/app/provider.go      |  95 ++++++++++++++++
internal/app/provider_test.go | 210 +++++++++++++++++++++++++++++++++++++
internal/cmd/models.go        | 110 +++++++++++++++++++
internal/cmd/run.go           |   6 
7 files changed, 521 insertions(+), 3 deletions(-)

Detailed changes

go.mod 🔗

@@ -30,6 +30,7 @@ require (
 	github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f
 	github.com/charmbracelet/x/exp/ordered v0.1.0
 	github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff
+	github.com/charmbracelet/x/exp/strings v0.0.0-20260119114936-fd556377ea59
 	github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b
 	github.com/charmbracelet/x/term v0.2.2
 	github.com/denisbrodbeck/machineid v1.0.1
@@ -38,6 +39,7 @@ require (
 	github.com/invopop/jsonschema v0.13.0
 	github.com/joho/godotenv v1.5.1
 	github.com/lucasb-eyer/go-colorful v1.3.0
+	github.com/mattn/go-isatty v0.0.20
 	github.com/modelcontextprotocol/go-sdk v1.2.0
 	github.com/muesli/termenv v0.16.0
 	github.com/ncruces/go-sqlite3 v0.30.4
@@ -135,7 +137,6 @@ require (
 	github.com/klauspost/cpuid/v2 v2.0.9 // indirect
 	github.com/klauspost/pgzip v1.2.6 // indirect
 	github.com/mailru/easyjson v0.7.7 // indirect
-	github.com/mattn/go-isatty v0.0.20 // indirect
 	github.com/mattn/go-runewidth v0.0.19 // indirect
 	github.com/mfridman/interpolate v0.0.2 // indirect
 	github.com/microcosm-cc/bluemonday v1.0.27 // indirect

go.sum 🔗

@@ -118,6 +118,8 @@ github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sB
 github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
 github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff h1:Uwr+/JS+qnRcO/++xjYEDtW7x+P5E4+4cBiOHTt2Xfk=
 github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
+github.com/charmbracelet/x/exp/strings v0.0.0-20260119114936-fd556377ea59 h1:cvPMInXNmK/CHjQU8eXC/oSnGfEKpQmndsEykh03bt0=
+github.com/charmbracelet/x/exp/strings v0.0.0-20260119114936-fd556377ea59/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
 github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
 github.com/charmbracelet/x/json v0.2.0/go.mod h1:opFIflx2YgXgi49xVUu8gEQ21teFAxyMwvOiZhIvWNM=
 github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b h1:5ye9hzBKH623bMVz5auIuY6K21loCdxpRmFle2O9R/8=

internal/app/app.go 🔗

@@ -17,6 +17,7 @@ import (
 	tea "charm.land/bubbletea/v2"
 	"charm.land/fantasy"
 	"charm.land/lipgloss/v2"
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/agent"
 	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/config"
@@ -131,12 +132,18 @@ func (app *App) Config() *config.Config {
 
 // RunNonInteractive runs the application in non-interactive mode with the
 // given prompt, printing to stdout.
-func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt string, quiet bool) error {
+func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, quiet bool) error {
 	slog.Info("Running in non-interactive mode")
 
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
+	if largeModel != "" || smallModel != "" {
+		if err := app.overrideModelsForNonInteractive(ctx, largeModel, smallModel); err != nil {
+			return fmt.Errorf("failed to override models: %w", err)
+		}
+	}
+
 	var (
 		spinner   *format.Spinner
 		stdoutTTY bool
@@ -299,6 +306,95 @@ func (app *App) UpdateAgentModel(ctx context.Context) error {
 	return app.AgentCoordinator.UpdateModels(ctx)
 }
 
+// overrideModelsForNonInteractive parses the model strings and temporarily
+// overrides the model configurations, then rebuilds the agent.
+// Format: "model-name" (searches all providers) or "provider/model-name".
+// Model matching is case-insensitive.
+// If largeModel is provided but smallModel is not, the small model defaults to
+// the provider's default small model.
+func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, smallModel string) error {
+	providers := app.config.Providers.Copy()
+
+	largeMatches, smallMatches, err := findModels(providers, largeModel, smallModel)
+	if err != nil {
+		return err
+	}
+
+	var largeProviderID string
+
+	// Override large model.
+	if largeModel != "" {
+		found, err := validateMatches(largeMatches, largeModel, "large")
+		if err != nil {
+			return err
+		}
+		largeProviderID = found.provider
+		slog.Info("Overriding large model for non-interactive run", "provider", found.provider, "model", found.modelID)
+		app.config.Models[config.SelectedModelTypeLarge] = config.SelectedModel{
+			Provider: found.provider,
+			Model:    found.modelID,
+		}
+	}
+
+	// Override small model.
+	switch {
+	case smallModel != "":
+		found, err := validateMatches(smallMatches, smallModel, "small")
+		if err != nil {
+			return err
+		}
+		slog.Info("Overriding small model for non-interactive run", "provider", found.provider, "model", found.modelID)
+		app.config.Models[config.SelectedModelTypeSmall] = config.SelectedModel{
+			Provider: found.provider,
+			Model:    found.modelID,
+		}
+
+	case largeModel != "":
+		// No small model specified, but large model was - use provider's default.
+		smallCfg := app.getDefaultSmallModel(largeProviderID)
+		app.config.Models[config.SelectedModelTypeSmall] = smallCfg
+	}
+
+	return app.AgentCoordinator.UpdateModels(ctx)
+}
+
+// provider. Falls back to the large model if no default is found.
+func (app *App) getDefaultSmallModel(providerID string) config.SelectedModel {
+	cfg := app.config
+	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
+
+	// Find the provider in the known providers list to get its default small model.
+	knownProviders, _ := config.Providers(cfg)
+	var knownProvider *catwalk.Provider
+	for _, p := range knownProviders {
+		if string(p.ID) == providerID {
+			knownProvider = &p
+			break
+		}
+	}
+
+	// For unknown/local providers, use the large model as small.
+	if knownProvider == nil {
+		slog.Warn("Using large model as small model for unknown provider", "provider", providerID, "model", largeModelCfg.Model)
+		return largeModelCfg
+	}
+
+	defaultSmallModelID := knownProvider.DefaultSmallModelID
+	model := cfg.GetModel(providerID, defaultSmallModelID)
+	if model == nil {
+		slog.Warn("Default small model not found, using large model", "provider", providerID, "model", largeModelCfg.Model)
+		return largeModelCfg
+	}
+
+	slog.Info("Using provider default small model", "provider", providerID, "model", defaultSmallModelID)
+	return config.SelectedModel{
+		Provider:        providerID,
+		Model:           defaultSmallModelID,
+		MaxTokens:       model.DefaultMaxTokens,
+		ReasoningEffort: model.DefaultReasoningEffort,
+	}
+}
+
 func (app *App) setupEvents() {
 	ctx, cancel := context.WithCancel(app.globalCtx)
 	app.eventsCtx = ctx

internal/app/provider.go 🔗

@@ -0,0 +1,95 @@
+package app
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/config"
+	xstrings "github.com/charmbracelet/x/exp/strings"
+)
+
+// parseModelStr parses a model string into provider filter and model ID.
+// Format: "model-name" or "provider/model-name" or "synthetic/moonshot/kimi-k2".
+// This function only checks if the first component is a valid provider name; if not,
+// it treats the entire string as a model ID (which may contain slashes).
+func parseModelStr(providers map[string]config.ProviderConfig, modelStr string) (providerFilter, modelID string) {
+	parts := strings.Split(modelStr, "/")
+	if len(parts) == 1 {
+		return "", parts[0]
+	}
+	// Check if the first part is a valid provider name
+	if _, ok := providers[parts[0]]; ok {
+		return parts[0], strings.Join(parts[1:], "/")
+	}
+
+	// First part is not a valid provider, treat entire string as model ID
+	return "", modelStr
+}
+
+// modelMatch represents a found model.
+type modelMatch struct {
+	provider string
+	modelID  string
+}
+
+func findModels(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch, error) {
+	largeProviderFilter, largeModelID := parseModelStr(providers, largeModel)
+	smallProviderFilter, smallModelID := parseModelStr(providers, smallModel)
+
+	// Validate provider filters exist.
+	for _, pf := range []struct {
+		filter, label string
+	}{
+		{largeProviderFilter, "large"},
+		{smallProviderFilter, "small"},
+	} {
+		if pf.filter != "" {
+			if _, ok := providers[pf.filter]; !ok {
+				return nil, nil, fmt.Errorf("%s model: provider %q not found in configuration. Use 'crush models' to list available models", pf.label, pf.filter)
+			}
+		}
+	}
+
+	// Find matching models in a single pass.
+	var largeMatches, smallMatches []modelMatch
+	for name, provider := range providers {
+		if provider.Disable {
+			continue
+		}
+		for _, m := range provider.Models {
+			if filter(largeModelID, largeProviderFilter, m.ID, name) {
+				largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
+			}
+			if filter(smallModelID, smallProviderFilter, m.ID, name) {
+				smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
+			}
+		}
+	}
+
+	return largeMatches, smallMatches, nil
+}
+
+func filter(modelFilter, providerFilter, model, provider string) bool {
+	return modelFilter != "" && model == modelFilter &&
+		(providerFilter == "" || provider == providerFilter)
+}
+
+// Validate and return a single match.
+func validateMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
+	switch {
+	case len(matches) == 0:
+		return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
+	case len(matches) > 1:
+		names := make([]string, len(matches))
+		for i, m := range matches {
+			names[i] = m.provider
+		}
+		return modelMatch{}, fmt.Errorf(
+			"%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
+			label,
+			modelID,
+			xstrings.EnglishJoin(names, true),
+		)
+	}
+	return matches[0], nil
+}

internal/app/provider_test.go 🔗

@@ -0,0 +1,210 @@
+package app
+
+import (
+	"testing"
+
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/stretchr/testify/require"
+)
+
+func TestParseModelStr(t *testing.T) {
+	tests := []struct {
+		name            string
+		modelStr        string
+		expectedFilter  string
+		expectedModelID string
+		setupProviders  func() map[string]config.ProviderConfig
+	}{
+		{
+			name:            "simple model with no slashes",
+			modelStr:        "gpt-4o",
+			expectedFilter:  "",
+			expectedModelID: "gpt-4o",
+			setupProviders:  setupMockProviders,
+		},
+		{
+			name:            "valid provider and model",
+			modelStr:        "openai/gpt-4o",
+			expectedFilter:  "openai",
+			expectedModelID: "gpt-4o",
+			setupProviders:  setupMockProviders,
+		},
+		{
+			name:            "model with multiple slashes and first part is invalid provider",
+			modelStr:        "moonshot/kimi-k2",
+			expectedFilter:  "",
+			expectedModelID: "moonshot/kimi-k2",
+			setupProviders:  setupMockProviders,
+		},
+		{
+			name:            "full path with valid provider and model with slashes",
+			modelStr:        "synthetic/moonshot/kimi-k2",
+			expectedFilter:  "synthetic",
+			expectedModelID: "moonshot/kimi-k2",
+			setupProviders:  setupMockProvidersWithSlashes,
+		},
+		{
+			name:            "empty model string",
+			modelStr:        "",
+			expectedFilter:  "",
+			expectedModelID: "",
+			setupProviders:  setupMockProviders,
+		},
+		{
+			name:            "model with trailing slash but valid provider",
+			modelStr:        "openai/",
+			expectedFilter:  "openai",
+			expectedModelID: "",
+			setupProviders:  setupMockProviders,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			providers := tt.setupProviders()
+			filter, modelID := parseModelStr(providers, tt.modelStr)
+
+			require.Equal(t, tt.expectedFilter, filter, "provider filter mismatch")
+			require.Equal(t, tt.expectedModelID, modelID, "model ID mismatch")
+		})
+	}
+}
+
+func setupMockProviders() map[string]config.ProviderConfig {
+	return map[string]config.ProviderConfig{
+		"openai": {
+			ID:     "openai",
+			Name:   "OpenAI",
+			Models: []catwalk.Model{{ID: "gpt-4o"}, {ID: "gpt-4o-mini"}},
+		},
+		"anthropic": {
+			ID:     "anthropic",
+			Name:   "Anthropic",
+			Models: []catwalk.Model{{ID: "claude-3-sonnet"}, {ID: "claude-3-opus"}},
+		},
+	}
+}
+
+func setupMockProvidersWithSlashes() map[string]config.ProviderConfig {
+	return map[string]config.ProviderConfig{
+		"synthetic": {
+			ID:   "synthetic",
+			Name: "Synthetic",
+			Models: []catwalk.Model{
+				{ID: "moonshot/kimi-k2"},
+				{ID: "deepseek/deepseek-chat"},
+			},
+		},
+		"openai": {
+			ID:     "openai",
+			Name:   "OpenAI",
+			Models: []catwalk.Model{{ID: "gpt-4o"}},
+		},
+	}
+}
+
+func TestFindModels(t *testing.T) {
+	tests := []struct {
+		name             string
+		modelStr         string
+		expectedProvider string
+		expectedModelID  string
+		expectError      bool
+		errorContains    string
+		setupProviders   func() map[string]config.ProviderConfig
+	}{
+		{
+			name:             "simple model found in one provider",
+			modelStr:         "gpt-4o",
+			expectedProvider: "openai",
+			expectedModelID:  "gpt-4o",
+			expectError:      false,
+			setupProviders:   setupMockProviders,
+		},
+		{
+			name:             "model with slashes in ID",
+			modelStr:         "moonshot/kimi-k2",
+			expectedProvider: "synthetic",
+			expectedModelID:  "moonshot/kimi-k2",
+			expectError:      false,
+			setupProviders:   setupMockProvidersWithSlashes,
+		},
+		{
+			name:             "provider and model with slashes in ID",
+			modelStr:         "synthetic/moonshot/kimi-k2",
+			expectedProvider: "synthetic",
+			expectedModelID:  "moonshot/kimi-k2",
+			expectError:      false,
+			setupProviders:   setupMockProvidersWithSlashes,
+		},
+		{
+			name:           "model not found",
+			modelStr:       "nonexistent-model",
+			expectError:    true,
+			errorContains:  "not found",
+			setupProviders: setupMockProviders,
+		},
+		{
+			name:           "invalid provider specified",
+			modelStr:       "nonexistent-provider/gpt-4o",
+			expectError:    true,
+			errorContains:  "provider",
+			setupProviders: setupMockProviders,
+		},
+		{
+			name:          "model found in multiple providers without provider filter",
+			modelStr:      "shared-model",
+			expectError:   true,
+			errorContains: "multiple providers",
+			setupProviders: func() map[string]config.ProviderConfig {
+				return map[string]config.ProviderConfig{
+					"openai": {
+						ID:     "openai",
+						Models: []catwalk.Model{{ID: "shared-model"}},
+					},
+					"anthropic": {
+						ID:     "anthropic",
+						Models: []catwalk.Model{{ID: "shared-model"}},
+					},
+				}
+			},
+		},
+		{
+			name:           "empty model string",
+			modelStr:       "",
+			expectError:    true,
+			errorContains:  "not found",
+			setupProviders: setupMockProviders,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			providers := tt.setupProviders()
+
+			// Use findModels with the model as "large" and empty "small".
+			matches, _, err := findModels(providers, tt.modelStr, "")
+			if err != nil {
+				if tt.expectError {
+					require.Contains(t, err.Error(), tt.errorContains)
+				} else {
+					require.NoError(t, err)
+				}
+				return
+			}
+
+			// Validate the matches.
+			match, err := validateMatches(matches, tt.modelStr, "large")
+
+			if tt.expectError {
+				require.Error(t, err)
+				require.Contains(t, err.Error(), tt.errorContains)
+			} else {
+				require.NoError(t, err)
+				require.Equal(t, tt.expectedProvider, match.provider)
+				require.Equal(t, tt.expectedModelID, match.modelID)
+			}
+		})
+	}
+}

internal/cmd/models.go 🔗

@@ -0,0 +1,110 @@
+package cmd
+
+import (
+	"fmt"
+	"os"
+	"slices"
+	"sort"
+	"strings"
+
+	"charm.land/lipgloss/v2/tree"
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/mattn/go-isatty"
+	"github.com/spf13/cobra"
+)
+
+var modelsCmd = &cobra.Command{
+	Use:   "models",
+	Short: "List all available models from configured providers",
+	Long:  `List all available models from configured providers. Shows provider name and model IDs.`,
+	Example: `# List all available models
+crush models
+
+# Search models
+crush models gpt5`,
+	Args: cobra.ArbitraryArgs,
+	RunE: func(cmd *cobra.Command, args []string) error {
+		cwd, err := ResolveCwd(cmd)
+		if err != nil {
+			return err
+		}
+
+		dataDir, _ := cmd.Flags().GetString("data-dir")
+		debug, _ := cmd.Flags().GetBool("debug")
+
+		cfg, err := config.Init(cwd, dataDir, debug)
+		if err != nil {
+			return err
+		}
+
+		if !cfg.IsConfigured() {
+			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
+		}
+
+		term := strings.ToLower(strings.Join(args, " "))
+		filter := func(p config.ProviderConfig, m catwalk.Model) bool {
+			for _, s := range []string{p.ID, p.Name, m.ID, m.Name} {
+				if term == "" || strings.Contains(strings.ToLower(s), term) {
+					return true
+				}
+			}
+			return false
+		}
+
+		var providerIDs []string
+		providerModels := make(map[string][]string)
+
+		for providerID, provider := range cfg.Providers.Seq2() {
+			if provider.Disable {
+				continue
+			}
+			var found bool
+			for _, model := range provider.Models {
+				if !filter(provider, model) {
+					continue
+				}
+				providerModels[providerID] = append(providerModels[providerID], model.ID)
+				found = true
+			}
+			if !found {
+				continue
+			}
+			slices.Sort(providerModels[providerID])
+			providerIDs = append(providerIDs, providerID)
+		}
+		sort.Strings(providerIDs)
+
+		if len(providerIDs) == 0 && len(args) == 0 {
+			return fmt.Errorf("no enabled providers found")
+		}
+		if len(providerIDs) == 0 {
+			return fmt.Errorf("no enabled providers found matching %q", term)
+		}
+
+		if !isatty.IsTerminal(os.Stdout.Fd()) {
+			for _, providerID := range providerIDs {
+				for _, modelID := range providerModels[providerID] {
+					fmt.Println(providerID + "/" + modelID)
+				}
+			}
+			return nil
+		}
+
+		t := tree.New()
+		for _, providerID := range providerIDs {
+			providerNode := tree.Root(providerID)
+			for _, modelID := range providerModels[providerID] {
+				providerNode.Child(modelID)
+			}
+			t.Child(providerNode)
+		}
+
+		cmd.Println(t)
+		return nil
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(modelsCmd)
+}

internal/cmd/run.go 🔗

@@ -32,6 +32,8 @@ crush run --quiet "Generate a README for this project"
   `,
 	RunE: func(cmd *cobra.Command, args []string) error {
 		quiet, _ := cmd.Flags().GetBool("quiet")
+		largeModel, _ := cmd.Flags().GetString("model")
+		smallModel, _ := cmd.Flags().GetString("small-model")
 
 		// Cancel on SIGINT or SIGTERM.
 		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
@@ -62,7 +64,7 @@ crush run --quiet "Generate a README for this project"
 		event.SetNonInteractive(true)
 		event.AppInitialized()
 
-		return app.RunNonInteractive(ctx, os.Stdout, prompt, quiet)
+		return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet)
 	},
 	PostRun: func(cmd *cobra.Command, args []string) {
 		event.AppExited()
@@ -71,4 +73,6 @@ crush run --quiet "Generate a README for this project"
 
 func init() {
 	runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
+	runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
+	runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
 }