run.go

  1package cmd
  2
import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"os"
	"os/signal"
	"sort"
	"strings"
	"syscall"
	"time"

	"charm.land/lipgloss/v2"
	"charm.land/log/v2"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/crush/internal/ui/anim"
	"github.com/charmbracelet/crush/internal/ui/styles"
	"github.com/charmbracelet/crush/internal/workspace"
	"github.com/charmbracelet/x/ansi"
	"github.com/charmbracelet/x/exp/charmtone"
	"github.com/charmbracelet/x/term"
	"github.com/google/uuid"
	"github.com/spf13/cobra"
)
 31
 32var runCmd = &cobra.Command{
 33	Aliases: []string{"r"},
 34	Use:     "run [prompt...]",
 35	Short:   "Run a single non-interactive prompt",
 36	Long: `Run a single prompt in non-interactive mode and exit.
 37The prompt can be provided as arguments or piped from stdin.`,
 38	Example: `
 39# Run a simple prompt
 40crush run "Guess my 5 favorite Pokรฉmon"
 41
 42# Pipe input from stdin
 43curl https://charm.land | crush run "Summarize this website"
 44
 45# Read from a file
 46crush run "What is this code doing?" <<< prrr.go
 47
 48# Redirect output to a file
 49crush run "Generate a hot README for this project" > MY_HOT_README.md
 50
 51# Run in quiet mode (hide the spinner)
 52crush run --quiet "Generate a README for this project"
 53
 54# Run in verbose mode (show logs)
 55crush run --verbose "Generate a README for this project"
 56
 57# Continue a previous session
 58crush run --session {session-id} "Follow up on your last response"
 59
 60# Continue the most recent session
 61crush run --continue "Follow up on your last response"
 62
 63  `,
 64	RunE: func(cmd *cobra.Command, args []string) error {
 65		var (
 66			quiet, _      = cmd.Flags().GetBool("quiet")
 67			verbose, _    = cmd.Flags().GetBool("verbose")
 68			largeModel, _ = cmd.Flags().GetString("model")
 69			smallModel, _ = cmd.Flags().GetString("small-model")
 70			sessionID, _  = cmd.Flags().GetString("session")
 71			useLast, _    = cmd.Flags().GetBool("continue")
 72		)
 73
 74		// Cancel on SIGINT or SIGTERM.
 75		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
 76		defer cancel()
 77
 78		prompt := strings.Join(args, " ")
 79
 80		prompt, err := MaybePrependStdin(prompt)
 81		if err != nil {
 82			slog.Error("Failed to read from stdin", "error", err)
 83			return err
 84		}
 85
 86		if prompt == "" {
 87			return fmt.Errorf("no prompt provided")
 88		}
 89
 90		event.SetNonInteractive(true)
 91
 92		switch {
 93		case sessionID != "":
 94			event.SetContinueBySessionID(true)
 95		case useLast:
 96			event.SetContinueLastSession(true)
 97		}
 98
 99		if useClientServer() {
100			c, ws, cleanup, err := connectToServer(cmd)
101			if err != nil {
102				return err
103			}
104			defer cleanup()
105
106			event.AppInitialized()
107
108			if sessionID != "" {
109				sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID)
110				if err != nil {
111					return err
112				}
113				sessionID = sess.ID
114			}
115
116			if !ws.Config.IsConfigured() {
117				return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
118			}
119
120			if verbose {
121				slog.SetDefault(slog.New(log.New(os.Stderr)))
122			}
123
124			return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
125		}
126
127		ws, cleanup, err := setupLocalWorkspace(cmd)
128		if err != nil {
129			return err
130		}
131		defer cleanup()
132
133		event.AppInitialized()
134
135		if !ws.Config().IsConfigured() {
136			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
137		}
138
139		if verbose {
140			slog.SetDefault(slog.New(log.New(os.Stderr)))
141		}
142
143		appWs := ws.(*workspace.AppWorkspace)
144		return appWs.App().RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
145	},
146}
147
148func init() {
149	runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
150	runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
151	runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
152	runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
153	runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
154	runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
155	runCmd.MarkFlagsMutuallyExclusive("session", "continue")
156}
157
158// runNonInteractive executes the agent via the server and streams output
159// to stdout.
160func runNonInteractive(
161	ctx context.Context,
162	c *client.Client,
163	ws *proto.Workspace,
164	prompt, largeModel, smallModel string,
165	hideSpinner bool,
166	continueSessionID string,
167	useLast bool,
168) error {
169	slog.Info("Running in non-interactive mode")
170
171	ctx, cancel := context.WithCancel(ctx)
172	defer cancel()
173
174	if largeModel != "" || smallModel != "" {
175		if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
176			return fmt.Errorf("failed to override models: %w", err)
177		}
178	}
179
180	var (
181		spinner   *format.Spinner
182		stdoutTTY bool
183		stderrTTY bool
184		stdinTTY  bool
185		progress  bool
186	)
187
188	stdoutTTY = term.IsTerminal(os.Stdout.Fd())
189	stderrTTY = term.IsTerminal(os.Stderr.Fd())
190	stdinTTY = term.IsTerminal(os.Stdin.Fd())
191	progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
192
193	if !hideSpinner && stderrTTY {
194		t := styles.ThemeForProvider(ws.Config.Models[config.SelectedModelTypeLarge].Provider)
195
196		hasDarkBG := true
197		if stdinTTY && stdoutTTY {
198			hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
199		}
200		defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.WorkingLabelColor)
201
202		spinner = format.NewSpinner(ctx, cancel, anim.Settings{
203			Size:        10,
204			Label:       "Generating",
205			LabelColor:  defaultFG,
206			GradColorA:  t.WorkingGradFromColor,
207			GradColorB:  t.WorkingGradToColor,
208			CycleColors: true,
209		})
210		spinner.Start()
211	}
212
213	stopSpinner := func() {
214		if !hideSpinner && spinner != nil {
215			spinner.Stop()
216			spinner = nil
217		}
218	}
219
220	// Wait for the agent to become ready (MCP init, etc).
221	if err := waitForAgent(ctx, c, ws.ID); err != nil {
222		stopSpinner()
223		return fmt.Errorf("agent not ready: %w", err)
224	}
225
226	// Force-update agent models so MCP tools are loaded.
227	if err := c.UpdateAgent(ctx, ws.ID); err != nil {
228		slog.Warn("Failed to update agent", "error", err)
229	}
230
231	defer stopSpinner()
232
233	sess, err := resolveSession(ctx, c, ws.ID, continueSessionID, useLast)
234	if err != nil {
235		return fmt.Errorf("failed to resolve session: %w", err)
236	}
237	if continueSessionID != "" || useLast {
238		slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
239	} else {
240		slog.Info("Created session for non-interactive run", "session_id", sess.ID)
241	}
242
243	events, err := c.SubscribeEvents(ctx, ws.ID)
244	if err != nil {
245		return fmt.Errorf("failed to subscribe to events: %w", err)
246	}
247
248	// Mint a per-call RunID so we can correlate the terminal
249	// RunComplete with *this* SendMessage even if the session was
250	// busy and another turn finished first. Without it the stream
251	// loop would exit on whichever RunComplete arrived first for
252	// the same session and drop the queued prompt's output.
253	runID := uuid.New().String()
254	if err := c.SendMessage(ctx, ws.ID, sess.ID, runID, prompt); err != nil {
255		return fmt.Errorf("failed to send message: %w", err)
256	}
257
258	stream := &runStream{
259		sessionID: sess.ID,
260		runID:     runID,
261		out:       os.Stdout,
262		read:      make(map[string]int),
263	}
264
265	defer func() {
266		if progress && stderrTTY {
267			_, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
268		}
269		_, _ = fmt.Fprintln(os.Stdout)
270	}()
271
272	for {
273		if progress && stderrTTY {
274			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
275		}
276
277		select {
278		case ev, ok := <-events:
279			if !ok {
280				stopSpinner()
281				return nil
282			}
283
284			done, err := stream.handle(ev, stopSpinner)
285			if err != nil {
286				return err
287			}
288			if done {
289				return nil
290			}
291
292		case <-ctx.Done():
293			stopSpinner()
294			return ctx.Err()
295		}
296	}
297}
298
// runStream holds the state [runNonInteractive] uses to turn the
// server's streamed events for one `crush run` invocation into the
// final stdout text. Keeping it separate from the event loop lets the
// state machine be unit tested without a server/client harness.
//
// A non-empty runID is the only correlator trusted for the terminal
// RunComplete. Live message events are not echoed, and the stream
// finishes only on a RunComplete carrying the same RunID. This keeps an
// earlier turn on the same session (for instance one our queued prompt
// was waiting behind) from writing to stdout or ending the run early.
//
// An empty runID (older servers, or tests that omit it) falls back to
// matching on the session ID with live message streaming. That is still
// correct when only a single turn is in flight.
type runStream struct {
	// sessionID and runID identify the run whose output is printed.
	sessionID, runID string
	// out receives the assistant text (os.Stdout in runNonInteractive).
	out io.Writer
	// read maps a message ID to how many bytes of that message's
	// content have already been handled.
	read map[string]int
	// printed becomes true once a part containing non-blank text has
	// been written; from then on whitespace-only parts are written too.
	printed bool
}
321
// handle processes one SSE event and reports whether the run loop
// should exit.
//
// It returns done=true when a RunComplete attributed to this run is
// observed, or when an AgentEvent carrying an error is attributed to
// this run.
//
// It returns a non-nil error in these cases:
//   - the agent run failed: a RunComplete with a non-cancelled Error,
//     or an attributed AgentEvent error;
//   - a streamed message's content is shorter than the bytes already
//     handled for it.
//
// Context cancellation is not handled here; that path belongs to the
// caller's select.
//
// stopSpinner is called on the first observable assistant output and
// on completion; passing nil is safe for tests.
func (s *runStream) handle(ev any, stopSpinner func()) (done bool, err error) {
	// stop wraps stopSpinner so a nil callback is tolerated.
	stop := func() {
		if stopSpinner != nil {
			stopSpinner()
		}
	}
	switch e := ev.(type) {
	case pubsub.Event[proto.Message]:
		msg := e.Payload
		// Only assistant messages in our session that have at least
		// one part can produce output.
		if msg.SessionID != s.sessionID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
			return false, nil
		}
		// With a RunID, live messages are not echoed because another
		// turn on the same session could be the one streaming. Our
		// final text comes from the matching RunComplete instead.
		if s.runID != "" {
			return false, nil
		}
		stop()

		// Only the suffix beyond what was already handled for this
		// message ID is written.
		content := msg.Content().String()
		readBytes := s.read[msg.ID]
		if len(content) < readBytes {
			slog.Error("Non-interactive: message content shorter than read bytes",
				"message_length", len(content), "read_bytes", readBytes)
			return false, fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
		}

		part := content[readBytes:]
		// Strip leading spaces/tabs at the very start of a message.
		if readBytes == 0 {
			part = strings.TrimLeft(part, " \t")
		}
		// Hold back whitespace-only output until some visible text has
		// been printed; after that, every part is written as-is.
		if s.printed || strings.TrimSpace(part) != "" {
			s.printed = true
			fmt.Fprint(s.out, part)
		}
		s.read[msg.ID] = len(content)
		return false, nil

	case pubsub.Event[proto.RunComplete]:
		// RunComplete is the authoritative end-of-run signal. We
		// exit on it instead of guessing from message finish parts,
		// which fire on every tool-call step too and were the
		// source of the regression where `crush run` exited
		// mid-turn on finish.reason == tool_use.
		//
		// Correlation:
		//   - if we minted a RunID for this SendMessage, only the
		//     event whose RunID matches is ours; any other turn
		//     finishing first on the same session (busy-session
		//     queue path) must be ignored.
		//   - if we have no RunID (older server, tests), fall back
		//     to SessionID matching.
		if s.runID != "" {
			if e.Payload.RunID != s.runID {
				return false, nil
			}
		} else if e.Payload.SessionID != s.sessionID {
			return false, nil
		}
		stop()
		// A cancelled run is not reported as a failure even if it
		// carries an Error string.
		if e.Payload.Error != "" && !e.Payload.Cancelled {
			return true, fmt.Errorf("agent run failed: %s", e.Payload.Error)
		}
		// Reconcile stdout against the authoritative final
		// assistant text carried in the event. The pubsub fan-in
		// does not serialize publishes across upstream brokers, so
		// the final message event may not have reached this loop
		// yet; the embedded Text field is the backstop that
		// guarantees the full final text always appears on stdout.
		// The same leading-whitespace and blank-output rules as the
		// Message case apply to the tail.
		if e.Payload.MessageID != "" {
			full := e.Payload.Text
			readBytes := s.read[e.Payload.MessageID]
			if readBytes < len(full) {
				tail := full[readBytes:]
				if readBytes == 0 {
					tail = strings.TrimLeft(tail, " \t")
				}
				if s.printed || strings.TrimSpace(tail) != "" {
					s.printed = true
					fmt.Fprint(s.out, tail)
				}
			}
		}
		return true, nil

	case pubsub.Event[proto.AgentEvent]:
		// Agent events without an error carry nothing we act on here.
		if e.Payload.Error == nil {
			return false, nil
		}
		// Attribute the error to our run before treating it as
		// fatal. Async errors from an unrelated workspace run share
		// this channel, so a foreign failure must not abort us:
		//   - if the event carries a RunID, it is the authoritative
		//     correlator: it must match our run exactly, otherwise it
		//     belongs to a different request and we ignore it.
		//   - if the event carries no RunID (older server), fall back
		//     to SessionID: it must be present and match our session,
		//     otherwise we ignore it.
		if e.Payload.RunID != "" {
			if e.Payload.RunID != s.runID {
				return false, nil
			}
		} else if e.Payload.SessionID == "" || e.Payload.SessionID != s.sessionID {
			return false, nil
		}
		stop()
		return true, fmt.Errorf("agent error: %w", e.Payload.Error)
	}
	// Any other event type is ignored.
	return false, nil
}
436
437// waitForAgent polls GetAgentInfo until the agent is ready, with a
438// timeout.
439func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
440	timeout := time.After(30 * time.Second)
441	for {
442		info, err := c.GetAgentInfo(ctx, wsID)
443		if err == nil && info.IsReady {
444			return nil
445		}
446		select {
447		case <-timeout:
448			if err != nil {
449				return fmt.Errorf("timeout waiting for agent: %w", err)
450			}
451			return fmt.Errorf("timeout waiting for agent readiness")
452		case <-ctx.Done():
453			return ctx.Err()
454		case <-time.After(200 * time.Millisecond):
455		}
456	}
457}
458
459// overrideModels resolves model strings and updates the workspace
460// configuration via the server.
461func overrideModels(
462	ctx context.Context,
463	c *client.Client,
464	ws *proto.Workspace,
465	largeModel, smallModel string,
466) error {
467	cfg, err := c.GetConfig(ctx, ws.ID)
468	if err != nil {
469		return fmt.Errorf("failed to get config: %w", err)
470	}
471
472	providers := cfg.Providers.Copy()
473
474	largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
475
476	var largeProviderID string
477
478	if largeModel != "" {
479		found, err := validateModelMatches(largeMatches, largeModel, "large")
480		if err != nil {
481			return err
482		}
483		largeProviderID = found.provider
484		slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
485		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
486			Provider: found.provider,
487			Model:    found.modelID,
488		}); err != nil {
489			return fmt.Errorf("failed to set large model: %w", err)
490		}
491	}
492
493	switch {
494	case smallModel != "":
495		found, err := validateModelMatches(smallMatches, smallModel, "small")
496		if err != nil {
497			return err
498		}
499		slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
500		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
501			Provider: found.provider,
502			Model:    found.modelID,
503		}); err != nil {
504			return fmt.Errorf("failed to set small model: %w", err)
505		}
506
507	case largeModel != "":
508		sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
509		if err != nil {
510			slog.Warn("Failed to get default small model", "error", err)
511		} else if sm != nil {
512			if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
513				return fmt.Errorf("failed to set small model: %w", err)
514			}
515		}
516	}
517
518	return c.UpdateAgent(ctx, ws.ID)
519}
520
// modelMatch is one model that satisfied a --model / --small-model
// filter: the key of the provider it was found under, and the model ID
// as listed by that provider.
type modelMatch struct {
	provider, modelID string
}
525
526// findModelMatches searches providers for matching large/small model
527// strings.
528func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
529	largeFilter, largeID := parseModelString(largeModel)
530	smallFilter, smallID := parseModelString(smallModel)
531
532	var largeMatches, smallMatches []modelMatch
533	for name, provider := range providers {
534		if provider.Disable {
535			continue
536		}
537		for _, m := range provider.Models {
538			if matchesModel(largeID, largeFilter, m.ID, name) {
539				largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
540			}
541			if matchesModel(smallID, smallFilter, m.ID, name) {
542				smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
543			}
544		}
545	}
546	return largeMatches, smallMatches
547}
548
// parseModelString splits a model reference at its first "/" into
// (provider, model). A reference without a slash yields ("", s), so the
// empty string yields ("", ""). Any later slashes stay in the model
// part, e.g. "openrouter/a/b" -> ("openrouter", "a/b").
func parseModelString(s string) (string, string) {
	provider, model, hasProvider := strings.Cut(s, "/")
	if !hasProvider {
		return "", s
	}
	return provider, model
}
560
// matchesModel reports whether model modelID offered by providerName
// satisfies the filter (wantID, wantProvider).
//
// An empty wantID never matches. A non-empty wantProvider must equal
// providerName exactly. Model IDs are compared case-insensitively.
func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
	switch {
	case wantID == "":
		return false
	case wantProvider != "" && wantProvider != providerName:
		return false
	default:
		return strings.EqualFold(modelID, wantID)
	}
}
572
573// validateModelMatches ensures exactly one match exists.
574func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
575	switch {
576	case len(matches) == 0:
577		return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
578	case len(matches) > 1:
579		names := make([]string, len(matches))
580		for i, m := range matches {
581			names[i] = m.provider
582		}
583		return modelMatch{}, fmt.Errorf(
584			"%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
585			label, modelID, strings.Join(names, ", "),
586		)
587	}
588	return matches[0], nil
589}
590
591// resolveSession returns the session to use for a non-interactive run.
592// If continueSessionID is set it fetches that session; if useLast is set it
593// returns the most recently updated top-level session; otherwise it creates a
594// new one.
595func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
596	switch {
597	case continueSessionID != "":
598		sess, err := c.GetSession(ctx, wsID, continueSessionID)
599		if err != nil {
600			return nil, fmt.Errorf("session not found: %s", continueSessionID)
601		}
602		if sess.ParentSessionID != "" {
603			return nil, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
604		}
605		return sess, nil
606
607	case useLast:
608		sessions, err := c.ListSessions(ctx, wsID)
609		if err != nil || len(sessions) == 0 {
610			return nil, fmt.Errorf("no sessions found to continue")
611		}
612		last := sessions[0]
613		for _, s := range sessions[1:] {
614			if s.UpdatedAt > last.UpdatedAt && s.ParentSessionID == "" {
615				last = s
616			}
617		}
618		return &last, nil
619
620	default:
621		return c.CreateSession(ctx, wsID, "non-interactive")
622	}
623}
624
625// resolveSessionByID resolves a session ID that may be a full UUID or a hash
626// prefix returned by crush session list.
627func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
628	if sess, err := c.GetSession(ctx, wsID, id); err == nil {
629		return sess, nil
630	}
631
632	sessions, err := c.ListSessions(ctx, wsID)
633	if err != nil {
634		return nil, err
635	}
636
637	var matches []proto.Session
638	for _, s := range sessions {
639		hash := session.HashID(s.ID)
640		if hash == id || strings.HasPrefix(hash, id) {
641			matches = append(matches, s)
642		}
643	}
644
645	switch len(matches) {
646	case 0:
647		return nil, fmt.Errorf("session %q not found", id)
648	case 1:
649		return &matches[0], nil
650	default:
651		return nil, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches))
652	}
653}