diff --git a/go.mod b/go.mod index d86aeba68814867bac69e6a113d0887d75405003..e7068c241fbc3d5c78c166445b77f240d437ea2b 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f github.com/charmbracelet/x/exp/ordered v0.1.0 github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff + github.com/charmbracelet/x/exp/strings v0.0.0-20260119114936-fd556377ea59 github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b github.com/charmbracelet/x/term v0.2.2 github.com/denisbrodbeck/machineid v1.0.1 @@ -38,6 +39,7 @@ require ( github.com/invopop/jsonschema v0.13.0 github.com/joho/godotenv v1.5.1 github.com/lucasb-eyer/go-colorful v1.3.0 + github.com/mattn/go-isatty v0.0.20 github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.30.4 @@ -135,7 +137,6 @@ require ( github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/klauspost/pgzip v1.2.6 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect github.com/mfridman/interpolate v0.0.2 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect diff --git a/go.sum b/go.sum index 09a239110a64328bd17452d2e4f8bbabe2a75770..8bb193020a32665c5189fa7c16acb9fa4995bb0e 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,8 @@ github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sB github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8= github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff h1:Uwr+/JS+qnRcO/++xjYEDtW7x+P5E4+4cBiOHTt2Xfk= github.com/charmbracelet/x/exp/slice v0.0.0-20251201173703-9f73bfd934ff/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA= +github.com/charmbracelet/x/exp/strings v0.0.0-20260119114936-fd556377ea59 h1:cvPMInXNmK/CHjQU8eXC/oSnGfEKpQmndsEykh03bt0= +github.com/charmbracelet/x/exp/strings v0.0.0-20260119114936-fd556377ea59/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8= github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ= github.com/charmbracelet/x/json v0.2.0/go.mod h1:opFIflx2YgXgi49xVUu8gEQ21teFAxyMwvOiZhIvWNM= github.com/charmbracelet/x/powernap v0.0.0-20260113142046-c1fa3de7983b h1:5ye9hzBKH623bMVz5auIuY6K21loCdxpRmFle2O9R/8= diff --git a/internal/app/app.go b/internal/app/app.go index 96750a453e63c79b7363fef5b3aa9b09632de940..816f7a8b25cc945436a148110ce7d774750bb493 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -17,6 +17,7 @@ import ( tea "charm.land/bubbletea/v2" "charm.land/fantasy" "charm.land/lipgloss/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" @@ -131,12 +132,18 @@ func (app *App) Config() *config.Config { // RunNonInteractive runs the application in non-interactive mode with the // given prompt, printing to stdout. -func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt string, quiet bool) error { +func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, quiet bool) error { slog.Info("Running in non-interactive mode") ctx, cancel := context.WithCancel(ctx) defer cancel() + if largeModel != "" || smallModel != "" { + if err := app.overrideModelsForNonInteractive(ctx, largeModel, smallModel); err != nil { + return fmt.Errorf("failed to override models: %w", err) + } + } + var ( spinner *format.Spinner stdoutTTY bool @@ -299,6 +306,95 @@ func (app *App) UpdateAgentModel(ctx context.Context) error { return app.AgentCoordinator.UpdateModels(ctx) } +// overrideModelsForNonInteractive parses the model strings and temporarily +// overrides the model configurations, then rebuilds the agent. +// Format: "model-name" (searches all providers) or "provider/model-name". +// Model matching is case-insensitive. +// If largeModel is provided but smallModel is not, the small model defaults to +// the provider's default small model. +func (app *App) overrideModelsForNonInteractive(ctx context.Context, largeModel, smallModel string) error { + providers := app.config.Providers.Copy() + + largeMatches, smallMatches, err := findModels(providers, largeModel, smallModel) + if err != nil { + return err + } + + var largeProviderID string + + // Override large model. + if largeModel != "" { + found, err := validateMatches(largeMatches, largeModel, "large") + if err != nil { + return err + } + largeProviderID = found.provider + slog.Info("Overriding large model for non-interactive run", "provider", found.provider, "model", found.modelID) + app.config.Models[config.SelectedModelTypeLarge] = config.SelectedModel{ + Provider: found.provider, + Model: found.modelID, + } + } + + // Override small model. + switch { + case smallModel != "": + found, err := validateMatches(smallMatches, smallModel, "small") + if err != nil { + return err + } + slog.Info("Overriding small model for non-interactive run", "provider", found.provider, "model", found.modelID) + app.config.Models[config.SelectedModelTypeSmall] = config.SelectedModel{ + Provider: found.provider, + Model: found.modelID, + } + + case largeModel != "": + // No small model specified, but large model was - use provider's default. + smallCfg := app.getDefaultSmallModel(largeProviderID) + app.config.Models[config.SelectedModelTypeSmall] = smallCfg + } + + return app.AgentCoordinator.UpdateModels(ctx) +} + +// provider. Falls back to the large model if no default is found. +func (app *App) getDefaultSmallModel(providerID string) config.SelectedModel { + cfg := app.config + largeModelCfg := cfg.Models[config.SelectedModelTypeLarge] + + // Find the provider in the known providers list to get its default small model. + knownProviders, _ := config.Providers(cfg) + var knownProvider *catwalk.Provider + for _, p := range knownProviders { + if string(p.ID) == providerID { + knownProvider = &p + break + } + } + + // For unknown/local providers, use the large model as small. + if knownProvider == nil { + slog.Warn("Using large model as small model for unknown provider", "provider", providerID, "model", largeModelCfg.Model) + return largeModelCfg + } + + defaultSmallModelID := knownProvider.DefaultSmallModelID + model := cfg.GetModel(providerID, defaultSmallModelID) + if model == nil { + slog.Warn("Default small model not found, using large model", "provider", providerID, "model", largeModelCfg.Model) + return largeModelCfg + } + + slog.Info("Using provider default small model", "provider", providerID, "model", defaultSmallModelID) + return config.SelectedModel{ + Provider: providerID, + Model: defaultSmallModelID, + MaxTokens: model.DefaultMaxTokens, + ReasoningEffort: model.DefaultReasoningEffort, + } +} + func (app *App) setupEvents() { ctx, cancel := context.WithCancel(app.globalCtx) app.eventsCtx = ctx diff --git a/internal/app/provider.go b/internal/app/provider.go new file mode 100644 index 0000000000000000000000000000000000000000..570edadf9e1647eeeeab32107d3da3a1d3494935 --- /dev/null +++ b/internal/app/provider.go @@ -0,0 +1,95 @@ +package app + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/crush/internal/config" + xstrings "github.com/charmbracelet/x/exp/strings" +) + +// parseModelStr parses a model string into provider filter and model ID. +// Format: "model-name" or "provider/model-name" or "synthetic/moonshot/kimi-k2". +// This function only checks if the first component is a valid provider name; if not, +// it treats the entire string as a model ID (which may contain slashes). +func parseModelStr(providers map[string]config.ProviderConfig, modelStr string) (providerFilter, modelID string) { + parts := strings.Split(modelStr, "/") + if len(parts) == 1 { + return "", parts[0] + } + // Check if the first part is a valid provider name + if _, ok := providers[parts[0]]; ok { + return parts[0], strings.Join(parts[1:], "/") + } + + // First part is not a valid provider, treat entire string as model ID + return "", modelStr +} + +// modelMatch represents a found model. +type modelMatch struct { + provider string + modelID string +} + +func findModels(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch, error) { + largeProviderFilter, largeModelID := parseModelStr(providers, largeModel) + smallProviderFilter, smallModelID := parseModelStr(providers, smallModel) + + // Validate provider filters exist. + for _, pf := range []struct { + filter, label string + }{ + {largeProviderFilter, "large"}, + {smallProviderFilter, "small"}, + } { + if pf.filter != "" { + if _, ok := providers[pf.filter]; !ok { + return nil, nil, fmt.Errorf("%s model: provider %q not found in configuration. Use 'crush models' to list available models", pf.label, pf.filter) + } + } + } + + // Find matching models in a single pass. + var largeMatches, smallMatches []modelMatch + for name, provider := range providers { + if provider.Disable { + continue + } + for _, m := range provider.Models { + if filter(largeModelID, largeProviderFilter, m.ID, name) { + largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID}) + } + if filter(smallModelID, smallProviderFilter, m.ID, name) { + smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID}) + } + } + } + + return largeMatches, smallMatches, nil +} + +func filter(modelFilter, providerFilter, model, provider string) bool { + return modelFilter != "" && model == modelFilter && + (providerFilter == "" || provider == providerFilter) +} + +// Validate and return a single match. +func validateMatches(matches []modelMatch, modelID, label string) (modelMatch, error) { + switch { + case len(matches) == 0: + return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID) + case len(matches) > 1: + names := make([]string, len(matches)) + for i, m := range matches { + names[i] = m.provider + } + return modelMatch{}, fmt.Errorf( + "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format", + label, + modelID, + xstrings.EnglishJoin(names, true), + ) + } + return matches[0], nil +} diff --git a/internal/app/provider_test.go b/internal/app/provider_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c3acae64d1057f3bb8bd8f9a0cb6443dbe9731b7 --- /dev/null +++ b/internal/app/provider_test.go @@ -0,0 +1,210 @@ +package app + +import ( + "testing" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/require" +) + +func TestParseModelStr(t *testing.T) { + tests := []struct { + name string + modelStr string + expectedFilter string + expectedModelID string + setupProviders func() map[string]config.ProviderConfig + }{ + { + name: "simple model with no slashes", + modelStr: "gpt-4o", + expectedFilter: "", + expectedModelID: "gpt-4o", + setupProviders: setupMockProviders, + }, + { + name: "valid provider and model", + modelStr: "openai/gpt-4o", + expectedFilter: "openai", + expectedModelID: "gpt-4o", + setupProviders: setupMockProviders, + }, + { + name: "model with multiple slashes and first part is invalid provider", + modelStr: "moonshot/kimi-k2", + expectedFilter: "", + expectedModelID: "moonshot/kimi-k2", + setupProviders: setupMockProviders, + }, + { + name: "full path with valid provider and model with slashes", + modelStr: "synthetic/moonshot/kimi-k2", + expectedFilter: "synthetic", + expectedModelID: "moonshot/kimi-k2", + setupProviders: setupMockProvidersWithSlashes, + }, + { + name: "empty model string", + modelStr: "", + expectedFilter: "", + expectedModelID: "", + setupProviders: setupMockProviders, + }, + { + name: "model with trailing slash but valid provider", + modelStr: "openai/", + expectedFilter: "openai", + expectedModelID: "", + setupProviders: setupMockProviders, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + providers := tt.setupProviders() + filter, modelID := parseModelStr(providers, tt.modelStr) + + require.Equal(t, tt.expectedFilter, filter, "provider filter mismatch") + require.Equal(t, tt.expectedModelID, modelID, "model ID mismatch") + }) + } +} + +func setupMockProviders() map[string]config.ProviderConfig { + return map[string]config.ProviderConfig{ + "openai": { + ID: "openai", + Name: "OpenAI", + Models: []catwalk.Model{{ID: "gpt-4o"}, {ID: "gpt-4o-mini"}}, + }, + "anthropic": { + ID: "anthropic", + Name: "Anthropic", + Models: []catwalk.Model{{ID: "claude-3-sonnet"}, {ID: "claude-3-opus"}}, + }, + } +} + +func setupMockProvidersWithSlashes() map[string]config.ProviderConfig { + return map[string]config.ProviderConfig{ + "synthetic": { + ID: "synthetic", + Name: "Synthetic", + Models: []catwalk.Model{ + {ID: "moonshot/kimi-k2"}, + {ID: "deepseek/deepseek-chat"}, + }, + }, + "openai": { + ID: "openai", + Name: "OpenAI", + Models: []catwalk.Model{{ID: "gpt-4o"}}, + }, + } +} + +func TestFindModels(t *testing.T) { + tests := []struct { + name string + modelStr string + expectedProvider string + expectedModelID string + expectError bool + errorContains string + setupProviders func() map[string]config.ProviderConfig + }{ + { + name: "simple model found in one provider", + modelStr: "gpt-4o", + expectedProvider: "openai", + expectedModelID: "gpt-4o", + expectError: false, + setupProviders: setupMockProviders, + }, + { + name: "model with slashes in ID", + modelStr: "moonshot/kimi-k2", + expectedProvider: "synthetic", + expectedModelID: "moonshot/kimi-k2", + expectError: false, + setupProviders: setupMockProvidersWithSlashes, + }, + { + name: "provider and model with slashes in ID", + modelStr: "synthetic/moonshot/kimi-k2", + expectedProvider: "synthetic", + expectedModelID: "moonshot/kimi-k2", + expectError: false, + setupProviders: setupMockProvidersWithSlashes, + }, + { + name: "model not found", + modelStr: "nonexistent-model", + expectError: true, + errorContains: "not found", + setupProviders: setupMockProviders, + }, + { + name: "invalid provider specified", + modelStr: "nonexistent-provider/gpt-4o", + expectError: true, + errorContains: "provider", + setupProviders: setupMockProviders, + }, + { + name: "model found in multiple providers without provider filter", + modelStr: "shared-model", + expectError: true, + errorContains: "multiple providers", + setupProviders: func() map[string]config.ProviderConfig { + return map[string]config.ProviderConfig{ + "openai": { + ID: "openai", + Models: []catwalk.Model{{ID: "shared-model"}}, + }, + "anthropic": { + ID: "anthropic", + Models: []catwalk.Model{{ID: "shared-model"}}, + }, + } + }, + }, + { + name: "empty model string", + modelStr: "", + expectError: true, + errorContains: "not found", + setupProviders: setupMockProviders, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + providers := tt.setupProviders() + + // Use findModels with the model as "large" and empty "small". + matches, _, err := findModels(providers, tt.modelStr, "") + if err != nil { + if tt.expectError { + require.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + } + return + } + + // Validate the matches. + match, err := validateMatches(matches, tt.modelStr, "large") + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedProvider, match.provider) + require.Equal(t, tt.expectedModelID, match.modelID) + } + }) + } +} diff --git a/internal/cmd/models.go b/internal/cmd/models.go new file mode 100644 index 0000000000000000000000000000000000000000..3267469638ee83463e1785774d37c5d281d37de9 --- /dev/null +++ b/internal/cmd/models.go @@ -0,0 +1,110 @@ +package cmd + +import ( + "fmt" + "os" + "slices" + "sort" + "strings" + + "charm.land/lipgloss/v2/tree" + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/mattn/go-isatty" + "github.com/spf13/cobra" +) + +var modelsCmd = &cobra.Command{ + Use: "models", + Short: "List all available models from configured providers", + Long: `List all available models from configured providers. Shows provider name and model IDs.`, + Example: `# List all available models +crush models + +# Search models +crush models gpt5`, + Args: cobra.ArbitraryArgs, + RunE: func(cmd *cobra.Command, args []string) error { + cwd, err := ResolveCwd(cmd) + if err != nil { + return err + } + + dataDir, _ := cmd.Flags().GetString("data-dir") + debug, _ := cmd.Flags().GetBool("debug") + + cfg, err := config.Init(cwd, dataDir, debug) + if err != nil { + return err + } + + if !cfg.IsConfigured() { + return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") + } + + term := strings.ToLower(strings.Join(args, " ")) + filter := func(p config.ProviderConfig, m catwalk.Model) bool { + for _, s := range []string{p.ID, p.Name, m.ID, m.Name} { + if term == "" || strings.Contains(strings.ToLower(s), term) { + return true + } + } + return false + } + + var providerIDs []string + providerModels := make(map[string][]string) + + for providerID, provider := range cfg.Providers.Seq2() { + if provider.Disable { + continue + } + var found bool + for _, model := range provider.Models { + if !filter(provider, model) { + continue + } + providerModels[providerID] = append(providerModels[providerID], model.ID) + found = true + } + if !found { + continue + } + slices.Sort(providerModels[providerID]) + providerIDs = append(providerIDs, providerID) + } + sort.Strings(providerIDs) + + if len(providerIDs) == 0 && len(args) == 0 { + return fmt.Errorf("no enabled providers found") + } + if len(providerIDs) == 0 { + return fmt.Errorf("no enabled providers found matching %q", term) + } + + if !isatty.IsTerminal(os.Stdout.Fd()) { + for _, providerID := range providerIDs { + for _, modelID := range providerModels[providerID] { + fmt.Println(providerID + "/" + modelID) + } + } + return nil + } + + t := tree.New() + for _, providerID := range providerIDs { + providerNode := tree.Root(providerID) + for _, modelID := range providerModels[providerID] { + providerNode.Child(modelID) + } + t.Child(providerNode) + } + + cmd.Println(t) + return nil + }, +} + +func init() { + rootCmd.AddCommand(modelsCmd) +} diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 6ebe5d79593bab6170e958ebdf26240d34327445..e4d72b41be13684e28ca6c2b85b79bfdcea52fc7 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -32,6 +32,8 @@ crush run --quiet "Generate a README for this project" `, RunE: func(cmd *cobra.Command, args []string) error { quiet, _ := cmd.Flags().GetBool("quiet") + largeModel, _ := cmd.Flags().GetString("model") + smallModel, _ := cmd.Flags().GetString("small-model") // Cancel on SIGINT or SIGTERM. ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) @@ -62,7 +64,7 @@ crush run --quiet "Generate a README for this project" event.SetNonInteractive(true) event.AppInitialized() - return app.RunNonInteractive(ctx, os.Stdout, prompt, quiet) + return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet) }, PostRun: func(cmd *cobra.Command, args []string) { event.AppExited() @@ -71,4 +73,6 @@ crush run --quiet "Generate a README for this project" func init() { runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner") + runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers") + runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider") }