run.go

  1package cmd
  2
import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"charm.land/lipgloss/v2"
	"charm.land/log/v2"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/ui/anim"
	"github.com/charmbracelet/crush/internal/ui/styles"
	"github.com/charmbracelet/x/ansi"
	"github.com/charmbracelet/x/exp/charmtone"
	"github.com/charmbracelet/x/term"
	"github.com/spf13/cobra"
)
 27
 28var runCmd = &cobra.Command{
 29	Use:   "run [prompt...]",
 30	Short: "Run a single non-interactive prompt",
 31	Long: `Run a single prompt in non-interactive mode and exit.
 32The prompt can be provided as arguments or piped from stdin.`,
 33	Example: `
 34# Run a simple prompt
 35crush run "Guess my 5 favorite Pokรฉmon"
 36
 37# Pipe input from stdin
 38curl https://charm.land | crush run "Summarize this website"
 39
 40# Read from a file
 41crush run "What is this code doing?" <<< prrr.go
 42
 43# Redirect output to a file
 44crush run "Generate a hot README for this project" > MY_HOT_README.md
 45
 46# Run in quiet mode (hide the spinner)
 47crush run --quiet "Generate a README for this project"
 48
 49# Run in verbose mode (show logs)
 50crush run --verbose "Generate a README for this project"
 51  `,
 52	RunE: func(cmd *cobra.Command, args []string) error {
 53		quiet, _ := cmd.Flags().GetBool("quiet")
 54		verbose, _ := cmd.Flags().GetBool("verbose")
 55		largeModel, _ := cmd.Flags().GetString("model")
 56		smallModel, _ := cmd.Flags().GetString("small-model")
 57
 58		// Cancel on SIGINT or SIGTERM.
 59		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
 60		defer cancel()
 61
 62		c, ws, cleanup, err := connectToServer(cmd)
 63		if err != nil {
 64			return err
 65		}
 66		defer cleanup()
 67
 68		if !ws.Config.IsConfigured() {
 69			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
 70		}
 71
 72		if verbose {
 73			slog.SetDefault(slog.New(log.New(os.Stderr)))
 74		}
 75
 76		prompt := strings.Join(args, " ")
 77
 78		prompt, err = MaybePrependStdin(prompt)
 79		if err != nil {
 80			slog.Error("Failed to read from stdin", "error", err)
 81			return err
 82		}
 83
 84		if prompt == "" {
 85			return fmt.Errorf("no prompt provided")
 86		}
 87
 88		event.SetNonInteractive(true)
 89		event.AppInitialized()
 90
 91		return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose)
 92	},
 93}
 94
 95func init() {
 96	runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
 97	runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
 98	runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
 99	runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
100}
101
102// runNonInteractive executes the agent via the server and streams output
103// to stdout.
104func runNonInteractive(
105	ctx context.Context,
106	c *client.Client,
107	ws *proto.Workspace,
108	prompt, largeModel, smallModel string,
109	hideSpinner bool,
110) error {
111	slog.Info("Running in non-interactive mode")
112
113	ctx, cancel := context.WithCancel(ctx)
114	defer cancel()
115
116	if largeModel != "" || smallModel != "" {
117		if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
118			return fmt.Errorf("failed to override models: %w", err)
119		}
120	}
121
122	var (
123		spinner   *format.Spinner
124		stdoutTTY bool
125		stderrTTY bool
126		stdinTTY  bool
127		progress  bool
128	)
129
130	stdoutTTY = term.IsTerminal(os.Stdout.Fd())
131	stderrTTY = term.IsTerminal(os.Stderr.Fd())
132	stdinTTY = term.IsTerminal(os.Stdin.Fd())
133	progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
134
135	if !hideSpinner && stderrTTY {
136		t := styles.DefaultStyles()
137
138		hasDarkBG := true
139		if stdinTTY && stdoutTTY {
140			hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
141		}
142		defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
143
144		spinner = format.NewSpinner(ctx, cancel, anim.Settings{
145			Size:        10,
146			Label:       "Generating",
147			LabelColor:  defaultFG,
148			GradColorA:  t.Primary,
149			GradColorB:  t.Secondary,
150			CycleColors: true,
151		})
152		spinner.Start()
153	}
154
155	stopSpinner := func() {
156		if !hideSpinner && spinner != nil {
157			spinner.Stop()
158			spinner = nil
159		}
160	}
161
162	// Wait for the agent to become ready (MCP init, etc).
163	if err := waitForAgent(ctx, c, ws.ID); err != nil {
164		stopSpinner()
165		return fmt.Errorf("agent not ready: %w", err)
166	}
167
168	// Force-update agent models so MCP tools are loaded.
169	if err := c.UpdateAgent(ctx, ws.ID); err != nil {
170		slog.Warn("Failed to update agent", "error", err)
171	}
172
173	defer stopSpinner()
174
175	sess, err := c.CreateSession(ctx, ws.ID, "non-interactive")
176	if err != nil {
177		return fmt.Errorf("failed to create session: %w", err)
178	}
179	slog.Info("Created session for non-interactive run", "session_id", sess.ID)
180
181	events, err := c.SubscribeEvents(ctx, ws.ID)
182	if err != nil {
183		return fmt.Errorf("failed to subscribe to events: %w", err)
184	}
185
186	if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
187		return fmt.Errorf("failed to send message: %w", err)
188	}
189
190	messageReadBytes := make(map[string]int)
191	var printed bool
192
193	defer func() {
194		if progress && stderrTTY {
195			_, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
196		}
197		_, _ = fmt.Fprintln(os.Stdout)
198	}()
199
200	for {
201		if progress && stderrTTY {
202			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
203		}
204
205		select {
206		case ev, ok := <-events:
207			if !ok {
208				stopSpinner()
209				return nil
210			}
211
212			switch e := ev.(type) {
213			case pubsub.Event[proto.Message]:
214				msg := e.Payload
215				if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
216					continue
217				}
218				stopSpinner()
219
220				content := msg.Content().String()
221				readBytes := messageReadBytes[msg.ID]
222
223				if len(content) < readBytes {
224					slog.Error("Non-interactive: message content shorter than read bytes",
225						"message_length", len(content), "read_bytes", readBytes)
226					return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
227				}
228
229				part := content[readBytes:]
230				if readBytes == 0 {
231					part = strings.TrimLeft(part, " \t")
232				}
233				if printed || strings.TrimSpace(part) != "" {
234					printed = true
235					fmt.Fprint(os.Stdout, part)
236				}
237				messageReadBytes[msg.ID] = len(content)
238
239				if msg.IsFinished() {
240					return nil
241				}
242
243			case pubsub.Event[proto.AgentEvent]:
244				if e.Payload.Error != nil {
245					stopSpinner()
246					return fmt.Errorf("agent error: %w", e.Payload.Error)
247				}
248			}
249
250		case <-ctx.Done():
251			stopSpinner()
252			return ctx.Err()
253		}
254	}
255}
256
257// waitForAgent polls GetAgentInfo until the agent is ready, with a
258// timeout.
259func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
260	timeout := time.After(30 * time.Second)
261	for {
262		info, err := c.GetAgentInfo(ctx, wsID)
263		if err == nil && info.IsReady {
264			return nil
265		}
266		select {
267		case <-timeout:
268			if err != nil {
269				return fmt.Errorf("timeout waiting for agent: %w", err)
270			}
271			return fmt.Errorf("timeout waiting for agent readiness")
272		case <-ctx.Done():
273			return ctx.Err()
274		case <-time.After(200 * time.Millisecond):
275		}
276	}
277}
278
279// overrideModels resolves model strings and updates the workspace
280// configuration via the server.
281func overrideModels(
282	ctx context.Context,
283	c *client.Client,
284	ws *proto.Workspace,
285	largeModel, smallModel string,
286) error {
287	cfg, err := c.GetConfig(ctx, ws.ID)
288	if err != nil {
289		return fmt.Errorf("failed to get config: %w", err)
290	}
291
292	providers := cfg.Providers.Copy()
293
294	largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
295
296	var largeProviderID string
297
298	if largeModel != "" {
299		found, err := validateModelMatches(largeMatches, largeModel, "large")
300		if err != nil {
301			return err
302		}
303		largeProviderID = found.provider
304		slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
305		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
306			Provider: found.provider,
307			Model:    found.modelID,
308		}); err != nil {
309			return fmt.Errorf("failed to set large model: %w", err)
310		}
311	}
312
313	switch {
314	case smallModel != "":
315		found, err := validateModelMatches(smallMatches, smallModel, "small")
316		if err != nil {
317			return err
318		}
319		slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
320		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
321			Provider: found.provider,
322			Model:    found.modelID,
323		}); err != nil {
324			return fmt.Errorf("failed to set small model: %w", err)
325		}
326
327	case largeModel != "":
328		sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
329		if err != nil {
330			slog.Warn("Failed to get default small model", "error", err)
331		} else if sm != nil {
332			if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
333				return fmt.Errorf("failed to set small model: %w", err)
334			}
335		}
336	}
337
338	return c.UpdateAgent(ctx, ws.ID)
339}
340
// modelMatch is one candidate found while resolving a model flag: the key of
// the provider it belongs to and the model ID as listed in that provider's
// configuration.
type modelMatch struct {
	provider, modelID string
}
345
346// findModelMatches searches providers for matching large/small model
347// strings.
348func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
349	largeFilter, largeID := parseModelString(largeModel)
350	smallFilter, smallID := parseModelString(smallModel)
351
352	var largeMatches, smallMatches []modelMatch
353	for name, provider := range providers {
354		if provider.Disable {
355			continue
356		}
357		for _, m := range provider.Models {
358			if matchesModel(largeID, largeFilter, m.ID, name) {
359				largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
360			}
361			if matchesModel(smallID, smallFilter, m.ID, name) {
362				smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
363			}
364		}
365	}
366	return largeMatches, smallMatches
367}
368
// parseModelString splits s at its first "/" into a provider filter and a
// model ID. Without a "/", the provider is empty and the whole of s is the
// model ID. An empty s yields two empty strings.
func parseModelString(s string) (string, string) {
	provider, model, found := strings.Cut(s, "/")
	if !found {
		return "", s
	}
	return provider, model
}
380
// matchesModel reports whether the model modelID offered by providerName
// satisfies a request for wantID.
//
// An empty wantID never matches. When wantProvider is non-empty it must
// equal providerName exactly (case-sensitive). Model IDs are compared
// case-insensitively.
func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
	providerOK := wantProvider == "" || wantProvider == providerName
	return wantID != "" && providerOK && strings.EqualFold(modelID, wantID)
}
392
393// validateModelMatches ensures exactly one match exists.
394func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
395	switch {
396	case len(matches) == 0:
397		return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
398	case len(matches) > 1:
399		names := make([]string, len(matches))
400		for i, m := range matches {
401			names[i] = m.provider
402		}
403		return modelMatch{}, fmt.Errorf(
404			"%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
405			label, modelID, strings.Join(names, ", "),
406		)
407	}
408	return matches[0], nil
409}