1package cmd
2
import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"charm.land/lipgloss/v2"
	"charm.land/log/v2"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/crush/internal/ui/anim"
	"github.com/charmbracelet/crush/internal/ui/styles"
	"github.com/charmbracelet/crush/internal/workspace"
	"github.com/charmbracelet/x/ansi"
	"github.com/charmbracelet/x/exp/charmtone"
	"github.com/charmbracelet/x/term"
	"github.com/google/uuid"
	"github.com/spf13/cobra"
)
31
32var runCmd = &cobra.Command{
33 Aliases: []string{"r"},
34 Use: "run [prompt...]",
35 Short: "Run a single non-interactive prompt",
36 Long: `Run a single prompt in non-interactive mode and exit.
37The prompt can be provided as arguments or piped from stdin.`,
38 Example: `
39# Run a simple prompt
40crush run "Guess my 5 favorite Pokรฉmon"
41
42# Pipe input from stdin
43curl https://charm.land | crush run "Summarize this website"
44
45# Read from a file
46crush run "What is this code doing?" <<< prrr.go
47
48# Redirect output to a file
49crush run "Generate a hot README for this project" > MY_HOT_README.md
50
51# Run in quiet mode (hide the spinner)
52crush run --quiet "Generate a README for this project"
53
54# Run in verbose mode (show logs)
55crush run --verbose "Generate a README for this project"
56
57# Continue a previous session
58crush run --session {session-id} "Follow up on your last response"
59
60# Continue the most recent session
61crush run --continue "Follow up on your last response"
62
63 `,
64 RunE: func(cmd *cobra.Command, args []string) error {
65 var (
66 quiet, _ = cmd.Flags().GetBool("quiet")
67 verbose, _ = cmd.Flags().GetBool("verbose")
68 largeModel, _ = cmd.Flags().GetString("model")
69 smallModel, _ = cmd.Flags().GetString("small-model")
70 sessionID, _ = cmd.Flags().GetString("session")
71 useLast, _ = cmd.Flags().GetBool("continue")
72 )
73
74 // Cancel on SIGINT or SIGTERM.
75 ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
76 defer cancel()
77
78 prompt := strings.Join(args, " ")
79
80 prompt, err := MaybePrependStdin(prompt)
81 if err != nil {
82 slog.Error("Failed to read from stdin", "error", err)
83 return err
84 }
85
86 if prompt == "" {
87 return fmt.Errorf("no prompt provided")
88 }
89
90 event.SetNonInteractive(true)
91
92 switch {
93 case sessionID != "":
94 event.SetContinueBySessionID(true)
95 case useLast:
96 event.SetContinueLastSession(true)
97 }
98
99 if useClientServer() {
100 c, ws, cleanup, err := connectToServer(cmd)
101 if err != nil {
102 return err
103 }
104 defer cleanup()
105
106 event.AppInitialized()
107
108 if sessionID != "" {
109 sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID)
110 if err != nil {
111 return err
112 }
113 sessionID = sess.ID
114 }
115
116 if !ws.Config.IsConfigured() {
117 return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
118 }
119
120 if verbose {
121 slog.SetDefault(slog.New(log.New(os.Stderr)))
122 }
123
124 return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
125 }
126
127 ws, cleanup, err := setupLocalWorkspace(cmd)
128 if err != nil {
129 return err
130 }
131 defer cleanup()
132
133 event.AppInitialized()
134
135 if !ws.Config().IsConfigured() {
136 return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
137 }
138
139 if verbose {
140 slog.SetDefault(slog.New(log.New(os.Stderr)))
141 }
142
143 appWs := ws.(*workspace.AppWorkspace)
144 return appWs.App().RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
145 },
146}
147
148func init() {
149 runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
150 runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
151 runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
152 runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
153 runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
154 runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
155 runCmd.MarkFlagsMutuallyExclusive("session", "continue")
156}
157
158// runNonInteractive executes the agent via the server and streams output
159// to stdout.
160func runNonInteractive(
161 ctx context.Context,
162 c *client.Client,
163 ws *proto.Workspace,
164 prompt, largeModel, smallModel string,
165 hideSpinner bool,
166 continueSessionID string,
167 useLast bool,
168) error {
169 slog.Info("Running in non-interactive mode")
170
171 ctx, cancel := context.WithCancel(ctx)
172 defer cancel()
173
174 if largeModel != "" || smallModel != "" {
175 if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
176 return fmt.Errorf("failed to override models: %w", err)
177 }
178 }
179
180 var (
181 spinner *format.Spinner
182 stdoutTTY bool
183 stderrTTY bool
184 stdinTTY bool
185 progress bool
186 )
187
188 stdoutTTY = term.IsTerminal(os.Stdout.Fd())
189 stderrTTY = term.IsTerminal(os.Stderr.Fd())
190 stdinTTY = term.IsTerminal(os.Stdin.Fd())
191 progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
192
193 if !hideSpinner && stderrTTY {
194 t := styles.ThemeForProvider(ws.Config.Models[config.SelectedModelTypeLarge].Provider)
195
196 hasDarkBG := true
197 if stdinTTY && stdoutTTY {
198 hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
199 }
200 defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.WorkingLabelColor)
201
202 spinner = format.NewSpinner(ctx, cancel, anim.Settings{
203 Size: 10,
204 Label: "Generating",
205 LabelColor: defaultFG,
206 GradColorA: t.WorkingGradFromColor,
207 GradColorB: t.WorkingGradToColor,
208 CycleColors: true,
209 })
210 spinner.Start()
211 }
212
213 stopSpinner := func() {
214 if !hideSpinner && spinner != nil {
215 spinner.Stop()
216 spinner = nil
217 }
218 }
219
220 // Wait for the agent to become ready (MCP init, etc).
221 if err := waitForAgent(ctx, c, ws.ID); err != nil {
222 stopSpinner()
223 return fmt.Errorf("agent not ready: %w", err)
224 }
225
226 // Force-update agent models so MCP tools are loaded.
227 if err := c.UpdateAgent(ctx, ws.ID); err != nil {
228 slog.Warn("Failed to update agent", "error", err)
229 }
230
231 defer stopSpinner()
232
233 sess, err := resolveSession(ctx, c, ws.ID, continueSessionID, useLast)
234 if err != nil {
235 return fmt.Errorf("failed to resolve session: %w", err)
236 }
237 if continueSessionID != "" || useLast {
238 slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
239 } else {
240 slog.Info("Created session for non-interactive run", "session_id", sess.ID)
241 }
242
243 events, err := c.SubscribeEvents(ctx, ws.ID)
244 if err != nil {
245 return fmt.Errorf("failed to subscribe to events: %w", err)
246 }
247
248 // Mint a per-call RunID so we can correlate the terminal
249 // RunComplete with *this* SendMessage even if the session was
250 // busy and another turn finished first. Without it the stream
251 // loop would exit on whichever RunComplete arrived first for
252 // the same session and drop the queued prompt's output.
253 runID := uuid.New().String()
254 if err := c.SendMessage(ctx, ws.ID, sess.ID, runID, prompt); err != nil {
255 return fmt.Errorf("failed to send message: %w", err)
256 }
257
258 stream := &runStream{
259 sessionID: sess.ID,
260 runID: runID,
261 out: os.Stdout,
262 read: make(map[string]int),
263 }
264
265 defer func() {
266 if progress && stderrTTY {
267 _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
268 }
269 _, _ = fmt.Fprintln(os.Stdout)
270 }()
271
272 for {
273 if progress && stderrTTY {
274 _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
275 }
276
277 select {
278 case ev, ok := <-events:
279 if !ok {
280 stopSpinner()
281 return nil
282 }
283
284 done, err := stream.handle(ev, stopSpinner)
285 if err != nil {
286 return err
287 }
288 if done {
289 return nil
290 }
291
292 case <-ctx.Done():
293 stopSpinner()
294 return ctx.Err()
295 }
296 }
297}
298
// runStream is the state machine behind [runNonInteractive]. It turns
// streamed SSE events into the final stdout of `crush run` and remembers
// how much of each assistant message has already been written. Keeping
// it separate lets unit tests drive it without a server/client pair.
//
// When runID is set, it is the sole correlator for the terminating
// RunComplete event. Live message events are then ignored and the text is
// taken only from the matching RunComplete. That way another turn
// finishing earlier on the same session (for instance when this prompt
// was queued behind a busy session) can neither leak into stdout nor end
// the loop early. When runID is empty (older servers, or tests that omit
// it), matching falls back to sessionID and messages are streamed live,
// which is correct for a single-turn run.
type runStream struct {
	// sessionID and runID identify the run whose output belongs on out.
	sessionID, runID string
	// out receives the assistant text.
	out io.Writer
	// read maps a message ID to how many bytes of its content have
	// already been consumed.
	read map[string]int
	// printed records whether visible text has been written; once true,
	// whitespace-only chunks are no longer dropped.
	printed bool
}
321
322// handle processes one SSE event. Returns done=true when the run
323// loop should exit (RunComplete observed); returns an error only
324// when the agent run failed (not on context cancel โ that path is
325// handled by the caller's select). stopSpinner is called on the
326// first observable assistant output and on completion; passing nil
327// is safe for tests.
328func (s *runStream) handle(ev any, stopSpinner func()) (done bool, err error) {
329 stop := func() {
330 if stopSpinner != nil {
331 stopSpinner()
332 }
333 }
334 switch e := ev.(type) {
335 case pubsub.Event[proto.Message]:
336 msg := e.Payload
337 if msg.SessionID != s.sessionID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
338 return false, nil
339 }
340 if s.runID != "" {
341 return false, nil
342 }
343 stop()
344
345 content := msg.Content().String()
346 readBytes := s.read[msg.ID]
347 if len(content) < readBytes {
348 slog.Error("Non-interactive: message content shorter than read bytes",
349 "message_length", len(content), "read_bytes", readBytes)
350 return false, fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
351 }
352
353 part := content[readBytes:]
354 if readBytes == 0 {
355 part = strings.TrimLeft(part, " \t")
356 }
357 if s.printed || strings.TrimSpace(part) != "" {
358 s.printed = true
359 fmt.Fprint(s.out, part)
360 }
361 s.read[msg.ID] = len(content)
362 return false, nil
363
364 case pubsub.Event[proto.RunComplete]:
365 // RunComplete is the authoritative end-of-run signal. We
366 // exit on it instead of guessing from message finish parts,
367 // which fire on every tool-call step too and were the
368 // source of the regression where `crush run` exited
369 // mid-turn on finish.reason == tool_use.
370 //
371 // Correlation:
372 // - if we minted a RunID for this SendMessage, only the
373 // event whose RunID matches is ours; any other turn
374 // finishing first on the same session (busy-session
375 // queue path) must be ignored.
376 // - if we have no RunID (older server, tests), fall back
377 // to SessionID matching.
378 if s.runID != "" {
379 if e.Payload.RunID != s.runID {
380 return false, nil
381 }
382 } else if e.Payload.SessionID != s.sessionID {
383 return false, nil
384 }
385 stop()
386 if e.Payload.Error != "" && !e.Payload.Cancelled {
387 return true, fmt.Errorf("agent run failed: %s", e.Payload.Error)
388 }
389 // Reconcile stdout against the authoritative final
390 // assistant text carried in the event. The pubsub fan-in
391 // does not serialize publishes across upstream brokers, so
392 // the final message event may not have reached this loop
393 // yet; the embedded Text field is the backstop that
394 // guarantees the full final text always appears on stdout.
395 if e.Payload.MessageID != "" {
396 full := e.Payload.Text
397 readBytes := s.read[e.Payload.MessageID]
398 if readBytes < len(full) {
399 tail := full[readBytes:]
400 if readBytes == 0 {
401 tail = strings.TrimLeft(tail, " \t")
402 }
403 if s.printed || strings.TrimSpace(tail) != "" {
404 s.printed = true
405 fmt.Fprint(s.out, tail)
406 }
407 }
408 }
409 return true, nil
410
411 case pubsub.Event[proto.AgentEvent]:
412 if e.Payload.Error == nil {
413 return false, nil
414 }
415 // Attribute the error to our run before treating it as
416 // fatal. Async errors from an unrelated workspace run share
417 // this channel, so a foreign failure must not abort us:
418 // - if the event carries a RunID, it is the authoritative
419 // correlator: it must match our run exactly, otherwise it
420 // belongs to a different request and we ignore it.
421 // - if the event carries no RunID (older server), fall back
422 // to SessionID: it must be present and match our session,
423 // otherwise we ignore it.
424 if e.Payload.RunID != "" {
425 if e.Payload.RunID != s.runID {
426 return false, nil
427 }
428 } else if e.Payload.SessionID == "" || e.Payload.SessionID != s.sessionID {
429 return false, nil
430 }
431 stop()
432 return true, fmt.Errorf("agent error: %w", e.Payload.Error)
433 }
434 return false, nil
435}
436
437// waitForAgent polls GetAgentInfo until the agent is ready, with a
438// timeout.
439func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
440 timeout := time.After(30 * time.Second)
441 for {
442 info, err := c.GetAgentInfo(ctx, wsID)
443 if err == nil && info.IsReady {
444 return nil
445 }
446 select {
447 case <-timeout:
448 if err != nil {
449 return fmt.Errorf("timeout waiting for agent: %w", err)
450 }
451 return fmt.Errorf("timeout waiting for agent readiness")
452 case <-ctx.Done():
453 return ctx.Err()
454 case <-time.After(200 * time.Millisecond):
455 }
456 }
457}
458
459// overrideModels resolves model strings and updates the workspace
460// configuration via the server.
461func overrideModels(
462 ctx context.Context,
463 c *client.Client,
464 ws *proto.Workspace,
465 largeModel, smallModel string,
466) error {
467 cfg, err := c.GetConfig(ctx, ws.ID)
468 if err != nil {
469 return fmt.Errorf("failed to get config: %w", err)
470 }
471
472 providers := cfg.Providers.Copy()
473
474 largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
475
476 var largeProviderID string
477
478 if largeModel != "" {
479 found, err := validateModelMatches(largeMatches, largeModel, "large")
480 if err != nil {
481 return err
482 }
483 largeProviderID = found.provider
484 slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
485 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
486 Provider: found.provider,
487 Model: found.modelID,
488 }); err != nil {
489 return fmt.Errorf("failed to set large model: %w", err)
490 }
491 }
492
493 switch {
494 case smallModel != "":
495 found, err := validateModelMatches(smallMatches, smallModel, "small")
496 if err != nil {
497 return err
498 }
499 slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
500 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
501 Provider: found.provider,
502 Model: found.modelID,
503 }); err != nil {
504 return fmt.Errorf("failed to set small model: %w", err)
505 }
506
507 case largeModel != "":
508 sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
509 if err != nil {
510 slog.Warn("Failed to get default small model", "error", err)
511 } else if sm != nil {
512 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
513 return fmt.Errorf("failed to set small model: %w", err)
514 }
515 }
516 }
517
518 return c.UpdateAgent(ctx, ws.ID)
519}
520
// modelMatch is one candidate resolution of a model string: the provider
// key it was found under and the model's ID within that provider.
type modelMatch struct {
	provider, modelID string
}
525
526// findModelMatches searches providers for matching large/small model
527// strings.
528func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
529 largeFilter, largeID := parseModelString(largeModel)
530 smallFilter, smallID := parseModelString(smallModel)
531
532 var largeMatches, smallMatches []modelMatch
533 for name, provider := range providers {
534 if provider.Disable {
535 continue
536 }
537 for _, m := range provider.Models {
538 if matchesModel(largeID, largeFilter, m.ID, name) {
539 largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
540 }
541 if matchesModel(smallID, smallFilter, m.ID, name) {
542 smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
543 }
544 }
545 }
546 return largeMatches, smallMatches
547}
548
// parseModelString splits s at its first "/" into (provider, model). A
// string without a slash (including "") yields ("", s).
func parseModelString(s string) (string, string) {
	provider, model, found := strings.Cut(s, "/")
	if !found {
		return "", s
	}
	return provider, model
}
560
// matchesModel reports whether the model modelID offered by providerName
// satisfies a request for wantID, optionally restricted to wantProvider.
// IDs compare case-insensitively. Provider names compare exactly. An
// empty wantID never matches.
func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
	providerOK := wantProvider == "" || wantProvider == providerName
	return wantID != "" && providerOK && strings.EqualFold(modelID, wantID)
}
572
573// validateModelMatches ensures exactly one match exists.
574func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
575 switch {
576 case len(matches) == 0:
577 return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
578 case len(matches) > 1:
579 names := make([]string, len(matches))
580 for i, m := range matches {
581 names[i] = m.provider
582 }
583 return modelMatch{}, fmt.Errorf(
584 "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
585 label, modelID, strings.Join(names, ", "),
586 )
587 }
588 return matches[0], nil
589}
590
591// resolveSession returns the session to use for a non-interactive run.
592// If continueSessionID is set it fetches that session; if useLast is set it
593// returns the most recently updated top-level session; otherwise it creates a
594// new one.
595func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
596 switch {
597 case continueSessionID != "":
598 sess, err := c.GetSession(ctx, wsID, continueSessionID)
599 if err != nil {
600 return nil, fmt.Errorf("session not found: %s", continueSessionID)
601 }
602 if sess.ParentSessionID != "" {
603 return nil, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
604 }
605 return sess, nil
606
607 case useLast:
608 sessions, err := c.ListSessions(ctx, wsID)
609 if err != nil || len(sessions) == 0 {
610 return nil, fmt.Errorf("no sessions found to continue")
611 }
612 last := sessions[0]
613 for _, s := range sessions[1:] {
614 if s.UpdatedAt > last.UpdatedAt && s.ParentSessionID == "" {
615 last = s
616 }
617 }
618 return &last, nil
619
620 default:
621 return c.CreateSession(ctx, wsID, "non-interactive")
622 }
623}
624
625// resolveSessionByID resolves a session ID that may be a full UUID or a hash
626// prefix returned by crush session list.
627func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
628 if sess, err := c.GetSession(ctx, wsID, id); err == nil {
629 return sess, nil
630 }
631
632 sessions, err := c.ListSessions(ctx, wsID)
633 if err != nil {
634 return nil, err
635 }
636
637 var matches []proto.Session
638 for _, s := range sessions {
639 hash := session.HashID(s.ID)
640 if hash == id || strings.HasPrefix(hash, id) {
641 matches = append(matches, s)
642 }
643 }
644
645 switch len(matches) {
646 case 0:
647 return nil, fmt.Errorf("session %q not found", id)
648 case 1:
649 return &matches[0], nil
650 default:
651 return nil, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches))
652 }
653}