run.go

  1package cmd
  2
import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"sort"
	"strings"
	"syscall"
	"time"

	"charm.land/lipgloss/v2"
	"charm.land/log/v2"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/crush/internal/ui/anim"
	"github.com/charmbracelet/crush/internal/ui/styles"
	"github.com/charmbracelet/crush/internal/workspace"
	"github.com/charmbracelet/x/ansi"
	"github.com/charmbracelet/x/exp/charmtone"
	"github.com/charmbracelet/x/term"
	"github.com/spf13/cobra"
)
 29
 30var runCmd = &cobra.Command{
 31	Aliases: []string{"r"},
 32	Use:     "run [prompt...]",
 33	Short:   "Run a single non-interactive prompt",
 34	Long: `Run a single prompt in non-interactive mode and exit.
 35The prompt can be provided as arguments or piped from stdin.`,
 36	Example: `
 37# Run a simple prompt
 38crush run "Guess my 5 favorite Pokรฉmon"
 39
 40# Pipe input from stdin
 41curl https://charm.land | crush run "Summarize this website"
 42
 43# Read from a file
 44crush run "What is this code doing?" <<< prrr.go
 45
 46# Redirect output to a file
 47crush run "Generate a hot README for this project" > MY_HOT_README.md
 48
 49# Run in quiet mode (hide the spinner)
 50crush run --quiet "Generate a README for this project"
 51
 52# Run in verbose mode (show logs)
 53crush run --verbose "Generate a README for this project"
 54
 55# Continue a previous session
 56crush run --session {session-id} "Follow up on your last response"
 57
 58# Continue the most recent session
 59crush run --continue "Follow up on your last response"
 60
 61  `,
 62	RunE: func(cmd *cobra.Command, args []string) error {
 63		var (
 64			quiet, _      = cmd.Flags().GetBool("quiet")
 65			verbose, _    = cmd.Flags().GetBool("verbose")
 66			largeModel, _ = cmd.Flags().GetString("model")
 67			smallModel, _ = cmd.Flags().GetString("small-model")
 68			sessionID, _  = cmd.Flags().GetString("session")
 69			useLast, _    = cmd.Flags().GetBool("continue")
 70		)
 71
 72		// Cancel on SIGINT or SIGTERM.
 73		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
 74		defer cancel()
 75
 76		prompt := strings.Join(args, " ")
 77
 78		prompt, err := MaybePrependStdin(prompt)
 79		if err != nil {
 80			slog.Error("Failed to read from stdin", "error", err)
 81			return err
 82		}
 83
 84		if prompt == "" {
 85			return fmt.Errorf("no prompt provided")
 86		}
 87
 88		event.SetNonInteractive(true)
 89		event.AppInitialized()
 90
 91		switch {
 92		case sessionID != "":
 93			event.SetContinueBySessionID(true)
 94		case useLast:
 95			event.SetContinueLastSession(true)
 96		}
 97
 98		if useClientServer() {
 99			c, ws, cleanup, err := connectToServer(cmd)
100			if err != nil {
101				return err
102			}
103			defer cleanup()
104
105			if sessionID != "" {
106				sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID)
107				if err != nil {
108					return err
109				}
110				sessionID = sess.ID
111			}
112
113			if !ws.Config.IsConfigured() {
114				return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
115			}
116
117			if verbose {
118				slog.SetDefault(slog.New(log.New(os.Stderr)))
119			}
120
121			return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
122		}
123
124		ws, cleanup, err := setupLocalWorkspace(cmd)
125		if err != nil {
126			return err
127		}
128		defer cleanup()
129
130		if !ws.Config().IsConfigured() {
131			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
132		}
133
134		if verbose {
135			slog.SetDefault(slog.New(log.New(os.Stderr)))
136		}
137
138		appWs := ws.(*workspace.AppWorkspace)
139		return appWs.App().RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
140	},
141}
142
143func init() {
144	runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
145	runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
146	runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
147	runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
148	runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
149	runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
150	runCmd.MarkFlagsMutuallyExclusive("session", "continue")
151}
152
153// runNonInteractive executes the agent via the server and streams output
154// to stdout.
155func runNonInteractive(
156	ctx context.Context,
157	c *client.Client,
158	ws *proto.Workspace,
159	prompt, largeModel, smallModel string,
160	hideSpinner bool,
161	continueSessionID string,
162	useLast bool,
163) error {
164	slog.Info("Running in non-interactive mode")
165
166	ctx, cancel := context.WithCancel(ctx)
167	defer cancel()
168
169	if largeModel != "" || smallModel != "" {
170		if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
171			return fmt.Errorf("failed to override models: %w", err)
172		}
173	}
174
175	var (
176		spinner   *format.Spinner
177		stdoutTTY bool
178		stderrTTY bool
179		stdinTTY  bool
180		progress  bool
181	)
182
183	stdoutTTY = term.IsTerminal(os.Stdout.Fd())
184	stderrTTY = term.IsTerminal(os.Stderr.Fd())
185	stdinTTY = term.IsTerminal(os.Stdin.Fd())
186	progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
187
188	if !hideSpinner && stderrTTY {
189		t := styles.DefaultStyles()
190
191		hasDarkBG := true
192		if stdinTTY && stdoutTTY {
193			hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
194		}
195		defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
196
197		spinner = format.NewSpinner(ctx, cancel, anim.Settings{
198			Size:        10,
199			Label:       "Generating",
200			LabelColor:  defaultFG,
201			GradColorA:  t.Primary,
202			GradColorB:  t.Secondary,
203			CycleColors: true,
204		})
205		spinner.Start()
206	}
207
208	stopSpinner := func() {
209		if !hideSpinner && spinner != nil {
210			spinner.Stop()
211			spinner = nil
212		}
213	}
214
215	// Wait for the agent to become ready (MCP init, etc).
216	if err := waitForAgent(ctx, c, ws.ID); err != nil {
217		stopSpinner()
218		return fmt.Errorf("agent not ready: %w", err)
219	}
220
221	// Force-update agent models so MCP tools are loaded.
222	if err := c.UpdateAgent(ctx, ws.ID); err != nil {
223		slog.Warn("Failed to update agent", "error", err)
224	}
225
226	defer stopSpinner()
227
228	sess, err := resolveSession(ctx, c, ws.ID, continueSessionID, useLast)
229	if err != nil {
230		return fmt.Errorf("failed to resolve session: %w", err)
231	}
232	if continueSessionID != "" || useLast {
233		slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
234	} else {
235		slog.Info("Created session for non-interactive run", "session_id", sess.ID)
236	}
237
238	events, err := c.SubscribeEvents(ctx, ws.ID)
239	if err != nil {
240		return fmt.Errorf("failed to subscribe to events: %w", err)
241	}
242
243	if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
244		return fmt.Errorf("failed to send message: %w", err)
245	}
246
247	messageReadBytes := make(map[string]int)
248	var printed bool
249
250	defer func() {
251		if progress && stderrTTY {
252			_, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
253		}
254		_, _ = fmt.Fprintln(os.Stdout)
255	}()
256
257	for {
258		if progress && stderrTTY {
259			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
260		}
261
262		select {
263		case ev, ok := <-events:
264			if !ok {
265				stopSpinner()
266				return nil
267			}
268
269			switch e := ev.(type) {
270			case pubsub.Event[proto.Message]:
271				msg := e.Payload
272				if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
273					continue
274				}
275				stopSpinner()
276
277				content := msg.Content().String()
278				readBytes := messageReadBytes[msg.ID]
279
280				if len(content) < readBytes {
281					slog.Error("Non-interactive: message content shorter than read bytes",
282						"message_length", len(content), "read_bytes", readBytes)
283					return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
284				}
285
286				part := content[readBytes:]
287				if readBytes == 0 {
288					part = strings.TrimLeft(part, " \t")
289				}
290				if printed || strings.TrimSpace(part) != "" {
291					printed = true
292					fmt.Fprint(os.Stdout, part)
293				}
294				messageReadBytes[msg.ID] = len(content)
295
296				if msg.IsFinished() {
297					return nil
298				}
299
300			case pubsub.Event[proto.AgentEvent]:
301				if e.Payload.Error != nil {
302					stopSpinner()
303					return fmt.Errorf("agent error: %w", e.Payload.Error)
304				}
305			}
306
307		case <-ctx.Done():
308			stopSpinner()
309			return ctx.Err()
310		}
311	}
312}
313
314// waitForAgent polls GetAgentInfo until the agent is ready, with a
315// timeout.
316func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
317	timeout := time.After(30 * time.Second)
318	for {
319		info, err := c.GetAgentInfo(ctx, wsID)
320		if err == nil && info.IsReady {
321			return nil
322		}
323		select {
324		case <-timeout:
325			if err != nil {
326				return fmt.Errorf("timeout waiting for agent: %w", err)
327			}
328			return fmt.Errorf("timeout waiting for agent readiness")
329		case <-ctx.Done():
330			return ctx.Err()
331		case <-time.After(200 * time.Millisecond):
332		}
333	}
334}
335
336// overrideModels resolves model strings and updates the workspace
337// configuration via the server.
338func overrideModels(
339	ctx context.Context,
340	c *client.Client,
341	ws *proto.Workspace,
342	largeModel, smallModel string,
343) error {
344	cfg, err := c.GetConfig(ctx, ws.ID)
345	if err != nil {
346		return fmt.Errorf("failed to get config: %w", err)
347	}
348
349	providers := cfg.Providers.Copy()
350
351	largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
352
353	var largeProviderID string
354
355	if largeModel != "" {
356		found, err := validateModelMatches(largeMatches, largeModel, "large")
357		if err != nil {
358			return err
359		}
360		largeProviderID = found.provider
361		slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
362		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
363			Provider: found.provider,
364			Model:    found.modelID,
365		}); err != nil {
366			return fmt.Errorf("failed to set large model: %w", err)
367		}
368	}
369
370	switch {
371	case smallModel != "":
372		found, err := validateModelMatches(smallMatches, smallModel, "small")
373		if err != nil {
374			return err
375		}
376		slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
377		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
378			Provider: found.provider,
379			Model:    found.modelID,
380		}); err != nil {
381			return fmt.Errorf("failed to set small model: %w", err)
382		}
383
384	case largeModel != "":
385		sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
386		if err != nil {
387			slog.Warn("Failed to get default small model", "error", err)
388		} else if sm != nil {
389			if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
390				return fmt.Errorf("failed to set small model: %w", err)
391			}
392		}
393	}
394
395	return c.UpdateAgent(ctx, ws.ID)
396}
397
// modelMatch is one provider/model pair that satisfied a --model or
// --small-model request.
type modelMatch struct {
	provider string // key of the provider in the providers map
	modelID  string // model ID exactly as the provider lists it
}
402
403// findModelMatches searches providers for matching large/small model
404// strings.
405func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
406	largeFilter, largeID := parseModelString(largeModel)
407	smallFilter, smallID := parseModelString(smallModel)
408
409	var largeMatches, smallMatches []modelMatch
410	for name, provider := range providers {
411		if provider.Disable {
412			continue
413		}
414		for _, m := range provider.Models {
415			if matchesModel(largeID, largeFilter, m.ID, name) {
416				largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
417			}
418			if matchesModel(smallID, smallFilter, m.ID, name) {
419				smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
420			}
421		}
422	}
423	return largeMatches, smallMatches
424}
425
// parseModelString splits s at its first "/" into a provider filter and a
// model ID. Without a "/", the provider filter is empty and all of s is the
// model ID; an empty s yields two empty strings.
func parseModelString(s string) (string, string) {
	provider, model, hasProvider := strings.Cut(s, "/")
	if !hasProvider {
		return "", s
	}
	return provider, model
}
437
// matchesModel reports whether model modelID, offered by providerName,
// satisfies a request for wantID that is optionally restricted to
// wantProvider. An empty wantID never matches and an empty wantProvider
// accepts any provider. Model IDs compare case-insensitively; provider names
// must match exactly.
func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
	providerOK := wantProvider == "" || wantProvider == providerName
	return wantID != "" && providerOK && strings.EqualFold(modelID, wantID)
}
449
450// validateModelMatches ensures exactly one match exists.
451func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
452	switch {
453	case len(matches) == 0:
454		return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
455	case len(matches) > 1:
456		names := make([]string, len(matches))
457		for i, m := range matches {
458			names[i] = m.provider
459		}
460		return modelMatch{}, fmt.Errorf(
461			"%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
462			label, modelID, strings.Join(names, ", "),
463		)
464	}
465	return matches[0], nil
466}
467
468// resolveSession returns the session to use for a non-interactive run.
469// If continueSessionID is set it fetches that session; if useLast is set it
470// returns the most recently updated top-level session; otherwise it creates a
471// new one.
472func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
473	switch {
474	case continueSessionID != "":
475		sess, err := c.GetSession(ctx, wsID, continueSessionID)
476		if err != nil {
477			return nil, fmt.Errorf("session not found: %s", continueSessionID)
478		}
479		if sess.ParentSessionID != "" {
480			return nil, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
481		}
482		return sess, nil
483
484	case useLast:
485		sessions, err := c.ListSessions(ctx, wsID)
486		if err != nil || len(sessions) == 0 {
487			return nil, fmt.Errorf("no sessions found to continue")
488		}
489		last := sessions[0]
490		for _, s := range sessions[1:] {
491			if s.UpdatedAt > last.UpdatedAt && s.ParentSessionID == "" {
492				last = s
493			}
494		}
495		return &last, nil
496
497	default:
498		return c.CreateSession(ctx, wsID, "non-interactive")
499	}
500}
501
502// resolveSessionByID resolves a session ID that may be a full UUID or a hash
503// prefix returned by crush session list.
504func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
505	if sess, err := c.GetSession(ctx, wsID, id); err == nil {
506		return sess, nil
507	}
508
509	sessions, err := c.ListSessions(ctx, wsID)
510	if err != nil {
511		return nil, err
512	}
513
514	var matches []proto.Session
515	for _, s := range sessions {
516		hash := session.HashID(s.ID)
517		if hash == id || strings.HasPrefix(hash, id) {
518			matches = append(matches, s)
519		}
520	}
521
522	switch len(matches) {
523	case 0:
524		return nil, fmt.Errorf("session %q not found", id)
525	case 1:
526		return &matches[0], nil
527	default:
528		return nil, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches))
529	}
530}