1package cmd
2
import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"charm.land/lipgloss/v2"
	"charm.land/log/v2"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/crush/internal/ui/anim"
	"github.com/charmbracelet/crush/internal/ui/styles"
	"github.com/charmbracelet/crush/internal/workspace"
	"github.com/charmbracelet/x/ansi"
	"github.com/charmbracelet/x/exp/charmtone"
	"github.com/charmbracelet/x/term"
	"github.com/spf13/cobra"
)
29
30var runCmd = &cobra.Command{
31 Aliases: []string{"r"},
32 Use: "run [prompt...]",
33 Short: "Run a single non-interactive prompt",
34 Long: `Run a single prompt in non-interactive mode and exit.
35The prompt can be provided as arguments or piped from stdin.`,
36 Example: `
37# Run a simple prompt
38crush run "Guess my 5 favorite Pokรฉmon"
39
40# Pipe input from stdin
41curl https://charm.land | crush run "Summarize this website"
42
43# Read from a file
44crush run "What is this code doing?" <<< prrr.go
45
46# Redirect output to a file
47crush run "Generate a hot README for this project" > MY_HOT_README.md
48
49# Run in quiet mode (hide the spinner)
50crush run --quiet "Generate a README for this project"
51
52# Run in verbose mode (show logs)
53crush run --verbose "Generate a README for this project"
54
55# Continue a previous session
56crush run --session {session-id} "Follow up on your last response"
57
58# Continue the most recent session
59crush run --continue "Follow up on your last response"
60
61 `,
62 RunE: func(cmd *cobra.Command, args []string) error {
63 var (
64 quiet, _ = cmd.Flags().GetBool("quiet")
65 verbose, _ = cmd.Flags().GetBool("verbose")
66 largeModel, _ = cmd.Flags().GetString("model")
67 smallModel, _ = cmd.Flags().GetString("small-model")
68 sessionID, _ = cmd.Flags().GetString("session")
69 useLast, _ = cmd.Flags().GetBool("continue")
70 )
71
72 // Cancel on SIGINT or SIGTERM.
73 ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
74 defer cancel()
75
76 prompt := strings.Join(args, " ")
77
78 prompt, err := MaybePrependStdin(prompt)
79 if err != nil {
80 slog.Error("Failed to read from stdin", "error", err)
81 return err
82 }
83
84 if prompt == "" {
85 return fmt.Errorf("no prompt provided")
86 }
87
88 event.SetNonInteractive(true)
89
90 switch {
91 case sessionID != "":
92 event.SetContinueBySessionID(true)
93 case useLast:
94 event.SetContinueLastSession(true)
95 }
96
97 if useClientServer() {
98 c, ws, cleanup, err := connectToServer(cmd)
99 if err != nil {
100 return err
101 }
102 defer cleanup()
103
104 event.AppInitialized()
105
106 if sessionID != "" {
107 sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID)
108 if err != nil {
109 return err
110 }
111 sessionID = sess.ID
112 }
113
114 if !ws.Config.IsConfigured() {
115 return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
116 }
117
118 if verbose {
119 slog.SetDefault(slog.New(log.New(os.Stderr)))
120 }
121
122 return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
123 }
124
125 ws, cleanup, err := setupLocalWorkspace(cmd)
126 if err != nil {
127 return err
128 }
129 defer cleanup()
130
131 event.AppInitialized()
132
133 if !ws.Config().IsConfigured() {
134 return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
135 }
136
137 if verbose {
138 slog.SetDefault(slog.New(log.New(os.Stderr)))
139 }
140
141 appWs := ws.(*workspace.AppWorkspace)
142 return appWs.App().RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
143 },
144}
145
146func init() {
147 runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
148 runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
149 runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
150 runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
151 runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
152 runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
153 runCmd.MarkFlagsMutuallyExclusive("session", "continue")
154}
155
156// runNonInteractive executes the agent via the server and streams output
157// to stdout.
158func runNonInteractive(
159 ctx context.Context,
160 c *client.Client,
161 ws *proto.Workspace,
162 prompt, largeModel, smallModel string,
163 hideSpinner bool,
164 continueSessionID string,
165 useLast bool,
166) error {
167 slog.Info("Running in non-interactive mode")
168
169 ctx, cancel := context.WithCancel(ctx)
170 defer cancel()
171
172 if largeModel != "" || smallModel != "" {
173 if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
174 return fmt.Errorf("failed to override models: %w", err)
175 }
176 }
177
178 var (
179 spinner *format.Spinner
180 stdoutTTY bool
181 stderrTTY bool
182 stdinTTY bool
183 progress bool
184 )
185
186 stdoutTTY = term.IsTerminal(os.Stdout.Fd())
187 stderrTTY = term.IsTerminal(os.Stderr.Fd())
188 stdinTTY = term.IsTerminal(os.Stdin.Fd())
189 progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
190
191 if !hideSpinner && stderrTTY {
192 t := styles.DefaultStyles()
193
194 hasDarkBG := true
195 if stdinTTY && stdoutTTY {
196 hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
197 }
198 defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
199
200 spinner = format.NewSpinner(ctx, cancel, anim.Settings{
201 Size: 10,
202 Label: "Generating",
203 LabelColor: defaultFG,
204 GradColorA: t.Primary,
205 GradColorB: t.Secondary,
206 CycleColors: true,
207 })
208 spinner.Start()
209 }
210
211 stopSpinner := func() {
212 if !hideSpinner && spinner != nil {
213 spinner.Stop()
214 spinner = nil
215 }
216 }
217
218 // Wait for the agent to become ready (MCP init, etc).
219 if err := waitForAgent(ctx, c, ws.ID); err != nil {
220 stopSpinner()
221 return fmt.Errorf("agent not ready: %w", err)
222 }
223
224 // Force-update agent models so MCP tools are loaded.
225 if err := c.UpdateAgent(ctx, ws.ID); err != nil {
226 slog.Warn("Failed to update agent", "error", err)
227 }
228
229 defer stopSpinner()
230
231 sess, err := resolveSession(ctx, c, ws.ID, continueSessionID, useLast)
232 if err != nil {
233 return fmt.Errorf("failed to resolve session: %w", err)
234 }
235 if continueSessionID != "" || useLast {
236 slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
237 } else {
238 slog.Info("Created session for non-interactive run", "session_id", sess.ID)
239 }
240
241 events, err := c.SubscribeEvents(ctx, ws.ID)
242 if err != nil {
243 return fmt.Errorf("failed to subscribe to events: %w", err)
244 }
245
246 if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
247 return fmt.Errorf("failed to send message: %w", err)
248 }
249
250 messageReadBytes := make(map[string]int)
251 var printed bool
252
253 defer func() {
254 if progress && stderrTTY {
255 _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
256 }
257 _, _ = fmt.Fprintln(os.Stdout)
258 }()
259
260 for {
261 if progress && stderrTTY {
262 _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
263 }
264
265 select {
266 case ev, ok := <-events:
267 if !ok {
268 stopSpinner()
269 return nil
270 }
271
272 switch e := ev.(type) {
273 case pubsub.Event[proto.Message]:
274 msg := e.Payload
275 if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
276 continue
277 }
278 stopSpinner()
279
280 content := msg.Content().String()
281 readBytes := messageReadBytes[msg.ID]
282
283 if len(content) < readBytes {
284 slog.Error("Non-interactive: message content shorter than read bytes",
285 "message_length", len(content), "read_bytes", readBytes)
286 return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
287 }
288
289 part := content[readBytes:]
290 if readBytes == 0 {
291 part = strings.TrimLeft(part, " \t")
292 }
293 if printed || strings.TrimSpace(part) != "" {
294 printed = true
295 fmt.Fprint(os.Stdout, part)
296 }
297 messageReadBytes[msg.ID] = len(content)
298
299 if msg.IsFinished() {
300 return nil
301 }
302
303 case pubsub.Event[proto.AgentEvent]:
304 if e.Payload.Error != nil {
305 stopSpinner()
306 return fmt.Errorf("agent error: %w", e.Payload.Error)
307 }
308 }
309
310 case <-ctx.Done():
311 stopSpinner()
312 return ctx.Err()
313 }
314 }
315}
316
317// waitForAgent polls GetAgentInfo until the agent is ready, with a
318// timeout.
319func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
320 timeout := time.After(30 * time.Second)
321 for {
322 info, err := c.GetAgentInfo(ctx, wsID)
323 if err == nil && info.IsReady {
324 return nil
325 }
326 select {
327 case <-timeout:
328 if err != nil {
329 return fmt.Errorf("timeout waiting for agent: %w", err)
330 }
331 return fmt.Errorf("timeout waiting for agent readiness")
332 case <-ctx.Done():
333 return ctx.Err()
334 case <-time.After(200 * time.Millisecond):
335 }
336 }
337}
338
339// overrideModels resolves model strings and updates the workspace
340// configuration via the server.
341func overrideModels(
342 ctx context.Context,
343 c *client.Client,
344 ws *proto.Workspace,
345 largeModel, smallModel string,
346) error {
347 cfg, err := c.GetConfig(ctx, ws.ID)
348 if err != nil {
349 return fmt.Errorf("failed to get config: %w", err)
350 }
351
352 providers := cfg.Providers.Copy()
353
354 largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
355
356 var largeProviderID string
357
358 if largeModel != "" {
359 found, err := validateModelMatches(largeMatches, largeModel, "large")
360 if err != nil {
361 return err
362 }
363 largeProviderID = found.provider
364 slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
365 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
366 Provider: found.provider,
367 Model: found.modelID,
368 }); err != nil {
369 return fmt.Errorf("failed to set large model: %w", err)
370 }
371 }
372
373 switch {
374 case smallModel != "":
375 found, err := validateModelMatches(smallMatches, smallModel, "small")
376 if err != nil {
377 return err
378 }
379 slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
380 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
381 Provider: found.provider,
382 Model: found.modelID,
383 }); err != nil {
384 return fmt.Errorf("failed to set small model: %w", err)
385 }
386
387 case largeModel != "":
388 sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
389 if err != nil {
390 slog.Warn("Failed to get default small model", "error", err)
391 } else if sm != nil {
392 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
393 return fmt.Errorf("failed to set small model: %w", err)
394 }
395 }
396 }
397
398 return c.UpdateAgent(ctx, ws.ID)
399}
400
// modelMatch is one candidate found while resolving a model string: the
// provider it belongs to and the model ID as listed by that provider.
type modelMatch struct {
	provider string // provider name (key in the providers map)
	modelID  string // model ID as it appears in the provider's model list
}
405
406// findModelMatches searches providers for matching large/small model
407// strings.
408func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
409 largeFilter, largeID := parseModelString(largeModel)
410 smallFilter, smallID := parseModelString(smallModel)
411
412 var largeMatches, smallMatches []modelMatch
413 for name, provider := range providers {
414 if provider.Disable {
415 continue
416 }
417 for _, m := range provider.Models {
418 if matchesModel(largeID, largeFilter, m.ID, name) {
419 largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
420 }
421 if matchesModel(smallID, smallFilter, m.ID, name) {
422 smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
423 }
424 }
425 }
426 return largeMatches, smallMatches
427}
428
// parseModelString splits s at its first "/" into (provider, model). Without
// a "/" the whole string is the model and the provider is empty, so "" yields
// ("", ""). Everything after the first "/" belongs to the model, which lets
// model IDs contain slashes of their own.
func parseModelString(s string) (string, string) {
	provider, model, hasProvider := strings.Cut(s, "/")
	if !hasProvider {
		return "", s
	}
	return provider, model
}
440
// matchesModel reports whether the model modelID offered by providerName
// satisfies a request for wantID, optionally restricted to wantProvider.
// An empty wantID never matches. An empty wantProvider accepts any provider.
// Model IDs are compared case-insensitively; provider names must match exactly.
func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
	providerOK := wantProvider == "" || wantProvider == providerName
	return wantID != "" && providerOK && strings.EqualFold(modelID, wantID)
}
452
453// validateModelMatches ensures exactly one match exists.
454func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
455 switch {
456 case len(matches) == 0:
457 return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
458 case len(matches) > 1:
459 names := make([]string, len(matches))
460 for i, m := range matches {
461 names[i] = m.provider
462 }
463 return modelMatch{}, fmt.Errorf(
464 "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
465 label, modelID, strings.Join(names, ", "),
466 )
467 }
468 return matches[0], nil
469}
470
471// resolveSession returns the session to use for a non-interactive run.
472// If continueSessionID is set it fetches that session; if useLast is set it
473// returns the most recently updated top-level session; otherwise it creates a
474// new one.
475func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
476 switch {
477 case continueSessionID != "":
478 sess, err := c.GetSession(ctx, wsID, continueSessionID)
479 if err != nil {
480 return nil, fmt.Errorf("session not found: %s", continueSessionID)
481 }
482 if sess.ParentSessionID != "" {
483 return nil, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
484 }
485 return sess, nil
486
487 case useLast:
488 sessions, err := c.ListSessions(ctx, wsID)
489 if err != nil || len(sessions) == 0 {
490 return nil, fmt.Errorf("no sessions found to continue")
491 }
492 last := sessions[0]
493 for _, s := range sessions[1:] {
494 if s.UpdatedAt > last.UpdatedAt && s.ParentSessionID == "" {
495 last = s
496 }
497 }
498 return &last, nil
499
500 default:
501 return c.CreateSession(ctx, wsID, "non-interactive")
502 }
503}
504
505// resolveSessionByID resolves a session ID that may be a full UUID or a hash
506// prefix returned by crush session list.
507func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
508 if sess, err := c.GetSession(ctx, wsID, id); err == nil {
509 return sess, nil
510 }
511
512 sessions, err := c.ListSessions(ctx, wsID)
513 if err != nil {
514 return nil, err
515 }
516
517 var matches []proto.Session
518 for _, s := range sessions {
519 hash := session.HashID(s.ID)
520 if hash == id || strings.HasPrefix(hash, id) {
521 matches = append(matches, s)
522 }
523 }
524
525 switch len(matches) {
526 case 0:
527 return nil, fmt.Errorf("session %q not found", id)
528 case 1:
529 return &matches[0], nil
530 default:
531 return nil, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches))
532 }
533}