run.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"os"
  8	"os/signal"
  9	"strings"
 10	"time"
 11
 12	"charm.land/lipgloss/v2"
 13	"charm.land/log/v2"
 14	"github.com/charmbracelet/crush/internal/client"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/event"
 17	"github.com/charmbracelet/crush/internal/format"
 18	"github.com/charmbracelet/crush/internal/proto"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/session"
 21	"github.com/charmbracelet/crush/internal/ui/anim"
 22	"github.com/charmbracelet/crush/internal/ui/styles"
 23	"github.com/charmbracelet/crush/internal/workspace"
 24	"github.com/charmbracelet/x/ansi"
 25	"github.com/charmbracelet/x/exp/charmtone"
 26	"github.com/charmbracelet/x/term"
 27	"github.com/spf13/cobra"
 28)
 29
 30var runCmd = &cobra.Command{
 31	Aliases: []string{"r"},
 32	Use:     "run [prompt...]",
 33	Short:   "Run a single non-interactive prompt",
 34	Long: `Run a single prompt in non-interactive mode and exit.
 35The prompt can be provided as arguments or piped from stdin.`,
 36	Example: `
 37# Run a simple prompt
 38crush run "Guess my 5 favorite Pokรฉmon"
 39
 40# Pipe input from stdin
 41curl https://charm.land | crush run "Summarize this website"
 42
 43# Read from a file
 44crush run "What is this code doing?" <<< prrr.go
 45
 46# Redirect output to a file
 47crush run "Generate a hot README for this project" > MY_HOT_README.md
 48
 49# Run in quiet mode (hide the spinner)
 50crush run --quiet "Generate a README for this project"
 51
 52# Run in verbose mode (show logs)
 53crush run --verbose "Generate a README for this project"
 54
 55# Continue a previous session
 56crush run --session {session-id} "Follow up on your last response"
 57
 58# Continue the most recent session
 59crush run --continue "Follow up on your last response"
 60
 61  `,
 62	RunE: func(cmd *cobra.Command, args []string) error {
 63		var (
 64			quiet, _      = cmd.Flags().GetBool("quiet")
 65			verbose, _    = cmd.Flags().GetBool("verbose")
 66			largeModel, _ = cmd.Flags().GetString("model")
 67			smallModel, _ = cmd.Flags().GetString("small-model")
 68			sessionID, _  = cmd.Flags().GetString("session")
 69			useLast, _    = cmd.Flags().GetBool("continue")
 70		)
 71
 72		// Cancel on SIGINT or SIGTERM.
 73		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
 74		defer cancel()
 75
 76		prompt := strings.Join(args, " ")
 77
 78		prompt, err := MaybePrependStdin(prompt)
 79		if err != nil {
 80			slog.Error("Failed to read from stdin", "error", err)
 81			return err
 82		}
 83
 84		if prompt == "" {
 85			return fmt.Errorf("no prompt provided")
 86		}
 87
 88		event.SetNonInteractive(true)
 89
 90		switch {
 91		case sessionID != "":
 92			event.SetContinueBySessionID(true)
 93		case useLast:
 94			event.SetContinueLastSession(true)
 95		}
 96
 97		if useClientServer() {
 98			c, ws, cleanup, err := connectToServer(cmd)
 99			if err != nil {
100				return err
101			}
102			defer cleanup()
103
104			event.AppInitialized()
105
106			if sessionID != "" {
107				sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID)
108				if err != nil {
109					return err
110				}
111				sessionID = sess.ID
112			}
113
114			if !ws.Config.IsConfigured() {
115				return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
116			}
117
118			if verbose {
119				slog.SetDefault(slog.New(log.New(os.Stderr)))
120			}
121
122			return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
123		}
124
125		ws, cleanup, err := setupLocalWorkspace(cmd)
126		if err != nil {
127			return err
128		}
129		defer cleanup()
130
131		event.AppInitialized()
132
133		if !ws.Config().IsConfigured() {
134			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
135		}
136
137		if verbose {
138			slog.SetDefault(slog.New(log.New(os.Stderr)))
139		}
140
141		appWs := ws.(*workspace.AppWorkspace)
142		return appWs.App().RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
143	},
144}
145
146func init() {
147	runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
148	runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
149	runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
150	runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
151	runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
152	runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
153	runCmd.MarkFlagsMutuallyExclusive("session", "continue")
154}
155
156// runNonInteractive executes the agent via the server and streams output
157// to stdout.
158func runNonInteractive(
159	ctx context.Context,
160	c *client.Client,
161	ws *proto.Workspace,
162	prompt, largeModel, smallModel string,
163	hideSpinner bool,
164	continueSessionID string,
165	useLast bool,
166) error {
167	slog.Info("Running in non-interactive mode")
168
169	ctx, cancel := context.WithCancel(ctx)
170	defer cancel()
171
172	if largeModel != "" || smallModel != "" {
173		if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
174			return fmt.Errorf("failed to override models: %w", err)
175		}
176	}
177
178	var (
179		spinner   *format.Spinner
180		stdoutTTY bool
181		stderrTTY bool
182		stdinTTY  bool
183		progress  bool
184	)
185
186	stdoutTTY = term.IsTerminal(os.Stdout.Fd())
187	stderrTTY = term.IsTerminal(os.Stderr.Fd())
188	stdinTTY = term.IsTerminal(os.Stdin.Fd())
189	progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
190
191	if !hideSpinner && stderrTTY {
192		t := styles.DefaultStyles()
193
194		hasDarkBG := true
195		if stdinTTY && stdoutTTY {
196			hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
197		}
198		defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
199
200		spinner = format.NewSpinner(ctx, cancel, anim.Settings{
201			Size:        10,
202			Label:       "Generating",
203			LabelColor:  defaultFG,
204			GradColorA:  t.Primary,
205			GradColorB:  t.Secondary,
206			CycleColors: true,
207		})
208		spinner.Start()
209	}
210
211	stopSpinner := func() {
212		if !hideSpinner && spinner != nil {
213			spinner.Stop()
214			spinner = nil
215		}
216	}
217
218	// Wait for the agent to become ready (MCP init, etc).
219	if err := waitForAgent(ctx, c, ws.ID); err != nil {
220		stopSpinner()
221		return fmt.Errorf("agent not ready: %w", err)
222	}
223
224	// Force-update agent models so MCP tools are loaded.
225	if err := c.UpdateAgent(ctx, ws.ID); err != nil {
226		slog.Warn("Failed to update agent", "error", err)
227	}
228
229	defer stopSpinner()
230
231	sess, err := resolveSession(ctx, c, ws.ID, continueSessionID, useLast)
232	if err != nil {
233		return fmt.Errorf("failed to resolve session: %w", err)
234	}
235	if continueSessionID != "" || useLast {
236		slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
237	} else {
238		slog.Info("Created session for non-interactive run", "session_id", sess.ID)
239	}
240
241	events, err := c.SubscribeEvents(ctx, ws.ID)
242	if err != nil {
243		return fmt.Errorf("failed to subscribe to events: %w", err)
244	}
245
246	if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
247		return fmt.Errorf("failed to send message: %w", err)
248	}
249
250	messageReadBytes := make(map[string]int)
251	var printed bool
252
253	defer func() {
254		if progress && stderrTTY {
255			_, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
256		}
257		_, _ = fmt.Fprintln(os.Stdout)
258	}()
259
260	for {
261		if progress && stderrTTY {
262			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
263		}
264
265		select {
266		case ev, ok := <-events:
267			if !ok {
268				stopSpinner()
269				return nil
270			}
271
272			switch e := ev.(type) {
273			case pubsub.Event[proto.Message]:
274				msg := e.Payload
275				if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
276					continue
277				}
278				stopSpinner()
279
280				content := msg.Content().String()
281				readBytes := messageReadBytes[msg.ID]
282
283				if len(content) < readBytes {
284					slog.Error("Non-interactive: message content shorter than read bytes",
285						"message_length", len(content), "read_bytes", readBytes)
286					return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
287				}
288
289				part := content[readBytes:]
290				if readBytes == 0 {
291					part = strings.TrimLeft(part, " \t")
292				}
293				if printed || strings.TrimSpace(part) != "" {
294					printed = true
295					fmt.Fprint(os.Stdout, part)
296				}
297				messageReadBytes[msg.ID] = len(content)
298
299				if msg.IsFinished() {
300					return nil
301				}
302
303			case pubsub.Event[proto.AgentEvent]:
304				if e.Payload.Error != nil {
305					stopSpinner()
306					return fmt.Errorf("agent error: %w", e.Payload.Error)
307				}
308			}
309
310		case <-ctx.Done():
311			stopSpinner()
312			return ctx.Err()
313		}
314	}
315}
316
317// waitForAgent polls GetAgentInfo until the agent is ready, with a
318// timeout.
319func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
320	timeout := time.After(30 * time.Second)
321	for {
322		info, err := c.GetAgentInfo(ctx, wsID)
323		if err == nil && info.IsReady {
324			return nil
325		}
326		select {
327		case <-timeout:
328			if err != nil {
329				return fmt.Errorf("timeout waiting for agent: %w", err)
330			}
331			return fmt.Errorf("timeout waiting for agent readiness")
332		case <-ctx.Done():
333			return ctx.Err()
334		case <-time.After(200 * time.Millisecond):
335		}
336	}
337}
338
339// overrideModels resolves model strings and updates the workspace
340// configuration via the server.
341func overrideModels(
342	ctx context.Context,
343	c *client.Client,
344	ws *proto.Workspace,
345	largeModel, smallModel string,
346) error {
347	cfg, err := c.GetConfig(ctx, ws.ID)
348	if err != nil {
349		return fmt.Errorf("failed to get config: %w", err)
350	}
351
352	providers := cfg.Providers.Copy()
353
354	largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
355
356	var largeProviderID string
357
358	if largeModel != "" {
359		found, err := validateModelMatches(largeMatches, largeModel, "large")
360		if err != nil {
361			return err
362		}
363		largeProviderID = found.provider
364		slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
365		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
366			Provider: found.provider,
367			Model:    found.modelID,
368		}); err != nil {
369			return fmt.Errorf("failed to set large model: %w", err)
370		}
371	}
372
373	switch {
374	case smallModel != "":
375		found, err := validateModelMatches(smallMatches, smallModel, "small")
376		if err != nil {
377			return err
378		}
379		slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
380		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
381			Provider: found.provider,
382			Model:    found.modelID,
383		}); err != nil {
384			return fmt.Errorf("failed to set small model: %w", err)
385		}
386
387	case largeModel != "":
388		sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
389		if err != nil {
390			slog.Warn("Failed to get default small model", "error", err)
391		} else if sm != nil {
392			if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
393				return fmt.Errorf("failed to set small model: %w", err)
394			}
395		}
396	}
397
398	return c.UpdateAgent(ctx, ws.ID)
399}
400
401type modelMatch struct {
402	provider string
403	modelID  string
404}
405
406// findModelMatches searches providers for matching large/small model
407// strings.
408func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
409	largeFilter, largeID := parseModelString(largeModel)
410	smallFilter, smallID := parseModelString(smallModel)
411
412	var largeMatches, smallMatches []modelMatch
413	for name, provider := range providers {
414		if provider.Disable {
415			continue
416		}
417		for _, m := range provider.Models {
418			if matchesModel(largeID, largeFilter, m.ID, name) {
419				largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
420			}
421			if matchesModel(smallID, smallFilter, m.ID, name) {
422				smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
423			}
424		}
425	}
426	return largeMatches, smallMatches
427}
428
429// parseModelString splits "provider/model" into (provider, model) or
430// ("", model).
431func parseModelString(s string) (string, string) {
432	if s == "" {
433		return "", ""
434	}
435	if idx := strings.Index(s, "/"); idx >= 0 {
436		return s[:idx], s[idx+1:]
437	}
438	return "", s
439}
440
441// matchesModel returns true if the model ID matches the filter
442// criteria.
443func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
444	if wantID == "" {
445		return false
446	}
447	if wantProvider != "" && wantProvider != providerName {
448		return false
449	}
450	return strings.EqualFold(modelID, wantID)
451}
452
453// validateModelMatches ensures exactly one match exists.
454func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
455	switch {
456	case len(matches) == 0:
457		return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
458	case len(matches) > 1:
459		names := make([]string, len(matches))
460		for i, m := range matches {
461			names[i] = m.provider
462		}
463		return modelMatch{}, fmt.Errorf(
464			"%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
465			label, modelID, strings.Join(names, ", "),
466		)
467	}
468	return matches[0], nil
469}
470
471// resolveSession returns the session to use for a non-interactive run.
472// If continueSessionID is set it fetches that session; if useLast is set it
473// returns the most recently updated top-level session; otherwise it creates a
474// new one.
475func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
476	switch {
477	case continueSessionID != "":
478		sess, err := c.GetSession(ctx, wsID, continueSessionID)
479		if err != nil {
480			return nil, fmt.Errorf("session not found: %s", continueSessionID)
481		}
482		if sess.ParentSessionID != "" {
483			return nil, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
484		}
485		return sess, nil
486
487	case useLast:
488		sessions, err := c.ListSessions(ctx, wsID)
489		if err != nil || len(sessions) == 0 {
490			return nil, fmt.Errorf("no sessions found to continue")
491		}
492		last := sessions[0]
493		for _, s := range sessions[1:] {
494			if s.UpdatedAt > last.UpdatedAt && s.ParentSessionID == "" {
495				last = s
496			}
497		}
498		return &last, nil
499
500	default:
501		return c.CreateSession(ctx, wsID, "non-interactive")
502	}
503}
504
505// resolveSessionByID resolves a session ID that may be a full UUID or a hash
506// prefix returned by crush session list.
507func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
508	if sess, err := c.GetSession(ctx, wsID, id); err == nil {
509		return sess, nil
510	}
511
512	sessions, err := c.ListSessions(ctx, wsID)
513	if err != nil {
514		return nil, err
515	}
516
517	var matches []proto.Session
518	for _, s := range sessions {
519		hash := session.HashID(s.ID)
520		if hash == id || strings.HasPrefix(hash, id) {
521			matches = append(matches, s)
522		}
523	}
524
525	switch len(matches) {
526	case 0:
527		return nil, fmt.Errorf("session %q not found", id)
528	case 1:
529		return &matches[0], nil
530	default:
531		return nil, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches))
532	}
533}