1package cmd
2
import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"charm.land/lipgloss/v2"
	"charm.land/log/v2"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/crush/internal/ui/anim"
	"github.com/charmbracelet/crush/internal/ui/styles"
	"github.com/charmbracelet/crush/internal/workspace"
	"github.com/charmbracelet/x/ansi"
	"github.com/charmbracelet/x/exp/charmtone"
	"github.com/charmbracelet/x/term"
	"github.com/spf13/cobra"
)
29
30var runCmd = &cobra.Command{
31 Aliases: []string{"r"},
32 Use: "run [prompt...]",
33 Short: "Run a single non-interactive prompt",
34 Long: `Run a single prompt in non-interactive mode and exit.
35The prompt can be provided as arguments or piped from stdin.`,
36 Example: `
37# Run a simple prompt
38crush run "Guess my 5 favorite Pokรฉmon"
39
40# Pipe input from stdin
41curl https://charm.land | crush run "Summarize this website"
42
43# Read from a file
44crush run "What is this code doing?" <<< prrr.go
45
46# Redirect output to a file
47crush run "Generate a hot README for this project" > MY_HOT_README.md
48
49# Run in quiet mode (hide the spinner)
50crush run --quiet "Generate a README for this project"
51
52# Run in verbose mode (show logs)
53crush run --verbose "Generate a README for this project"
54
55# Continue a previous session
56crush run --session {session-id} "Follow up on your last response"
57
58# Continue the most recent session
59crush run --continue "Follow up on your last response"
60
61 `,
62 RunE: func(cmd *cobra.Command, args []string) error {
63 var (
64 quiet, _ = cmd.Flags().GetBool("quiet")
65 verbose, _ = cmd.Flags().GetBool("verbose")
66 largeModel, _ = cmd.Flags().GetString("model")
67 smallModel, _ = cmd.Flags().GetString("small-model")
68 sessionID, _ = cmd.Flags().GetString("session")
69 useLast, _ = cmd.Flags().GetBool("continue")
70 )
71
72 // Cancel on SIGINT or SIGTERM.
73 ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
74 defer cancel()
75
76 prompt := strings.Join(args, " ")
77
78 prompt, err := MaybePrependStdin(prompt)
79 if err != nil {
80 slog.Error("Failed to read from stdin", "error", err)
81 return err
82 }
83
84 if prompt == "" {
85 return fmt.Errorf("no prompt provided")
86 }
87
88 event.SetNonInteractive(true)
89 event.AppInitialized()
90
91 switch {
92 case sessionID != "":
93 event.SetContinueBySessionID(true)
94 case useLast:
95 event.SetContinueLastSession(true)
96 }
97
98 if useClientServer() {
99 c, ws, cleanup, err := connectToServer(cmd)
100 if err != nil {
101 return err
102 }
103 defer cleanup()
104
105 if sessionID != "" {
106 sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID)
107 if err != nil {
108 return err
109 }
110 sessionID = sess.ID
111 }
112
113 if !ws.Config.IsConfigured() {
114 return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
115 }
116
117 if verbose {
118 slog.SetDefault(slog.New(log.New(os.Stderr)))
119 }
120
121 return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
122 }
123
124 ws, cleanup, err := setupLocalWorkspace(cmd)
125 if err != nil {
126 return err
127 }
128 defer cleanup()
129
130 if !ws.Config().IsConfigured() {
131 return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
132 }
133
134 if verbose {
135 slog.SetDefault(slog.New(log.New(os.Stderr)))
136 }
137
138 appWs := ws.(*workspace.AppWorkspace)
139 return appWs.App().RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast)
140 },
141}
142
143func init() {
144 runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
145 runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
146 runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
147 runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
148 runCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
149 runCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
150 runCmd.MarkFlagsMutuallyExclusive("session", "continue")
151}
152
153// runNonInteractive executes the agent via the server and streams output
154// to stdout.
155func runNonInteractive(
156 ctx context.Context,
157 c *client.Client,
158 ws *proto.Workspace,
159 prompt, largeModel, smallModel string,
160 hideSpinner bool,
161 continueSessionID string,
162 useLast bool,
163) error {
164 slog.Info("Running in non-interactive mode")
165
166 ctx, cancel := context.WithCancel(ctx)
167 defer cancel()
168
169 if largeModel != "" || smallModel != "" {
170 if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
171 return fmt.Errorf("failed to override models: %w", err)
172 }
173 }
174
175 var (
176 spinner *format.Spinner
177 stdoutTTY bool
178 stderrTTY bool
179 stdinTTY bool
180 progress bool
181 )
182
183 stdoutTTY = term.IsTerminal(os.Stdout.Fd())
184 stderrTTY = term.IsTerminal(os.Stderr.Fd())
185 stdinTTY = term.IsTerminal(os.Stdin.Fd())
186 progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
187
188 if !hideSpinner && stderrTTY {
189 t := styles.DefaultStyles()
190
191 hasDarkBG := true
192 if stdinTTY && stdoutTTY {
193 hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
194 }
195 defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
196
197 spinner = format.NewSpinner(ctx, cancel, anim.Settings{
198 Size: 10,
199 Label: "Generating",
200 LabelColor: defaultFG,
201 GradColorA: t.Primary,
202 GradColorB: t.Secondary,
203 CycleColors: true,
204 })
205 spinner.Start()
206 }
207
208 stopSpinner := func() {
209 if !hideSpinner && spinner != nil {
210 spinner.Stop()
211 spinner = nil
212 }
213 }
214
215 // Wait for the agent to become ready (MCP init, etc).
216 if err := waitForAgent(ctx, c, ws.ID); err != nil {
217 stopSpinner()
218 return fmt.Errorf("agent not ready: %w", err)
219 }
220
221 // Force-update agent models so MCP tools are loaded.
222 if err := c.UpdateAgent(ctx, ws.ID); err != nil {
223 slog.Warn("Failed to update agent", "error", err)
224 }
225
226 defer stopSpinner()
227
228 sess, err := resolveSession(ctx, c, ws.ID, continueSessionID, useLast)
229 if err != nil {
230 return fmt.Errorf("failed to resolve session: %w", err)
231 }
232 if continueSessionID != "" || useLast {
233 slog.Info("Continuing session for non-interactive run", "session_id", sess.ID)
234 } else {
235 slog.Info("Created session for non-interactive run", "session_id", sess.ID)
236 }
237
238 events, err := c.SubscribeEvents(ctx, ws.ID)
239 if err != nil {
240 return fmt.Errorf("failed to subscribe to events: %w", err)
241 }
242
243 if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
244 return fmt.Errorf("failed to send message: %w", err)
245 }
246
247 messageReadBytes := make(map[string]int)
248 var printed bool
249
250 defer func() {
251 if progress && stderrTTY {
252 _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
253 }
254 _, _ = fmt.Fprintln(os.Stdout)
255 }()
256
257 for {
258 if progress && stderrTTY {
259 _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
260 }
261
262 select {
263 case ev, ok := <-events:
264 if !ok {
265 stopSpinner()
266 return nil
267 }
268
269 switch e := ev.(type) {
270 case pubsub.Event[proto.Message]:
271 msg := e.Payload
272 if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
273 continue
274 }
275 stopSpinner()
276
277 content := msg.Content().String()
278 readBytes := messageReadBytes[msg.ID]
279
280 if len(content) < readBytes {
281 slog.Error("Non-interactive: message content shorter than read bytes",
282 "message_length", len(content), "read_bytes", readBytes)
283 return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
284 }
285
286 part := content[readBytes:]
287 if readBytes == 0 {
288 part = strings.TrimLeft(part, " \t")
289 }
290 if printed || strings.TrimSpace(part) != "" {
291 printed = true
292 fmt.Fprint(os.Stdout, part)
293 }
294 messageReadBytes[msg.ID] = len(content)
295
296 if msg.IsFinished() {
297 return nil
298 }
299
300 case pubsub.Event[proto.AgentEvent]:
301 if e.Payload.Error != nil {
302 stopSpinner()
303 return fmt.Errorf("agent error: %w", e.Payload.Error)
304 }
305 }
306
307 case <-ctx.Done():
308 stopSpinner()
309 return ctx.Err()
310 }
311 }
312}
313
314// waitForAgent polls GetAgentInfo until the agent is ready, with a
315// timeout.
316func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
317 timeout := time.After(30 * time.Second)
318 for {
319 info, err := c.GetAgentInfo(ctx, wsID)
320 if err == nil && info.IsReady {
321 return nil
322 }
323 select {
324 case <-timeout:
325 if err != nil {
326 return fmt.Errorf("timeout waiting for agent: %w", err)
327 }
328 return fmt.Errorf("timeout waiting for agent readiness")
329 case <-ctx.Done():
330 return ctx.Err()
331 case <-time.After(200 * time.Millisecond):
332 }
333 }
334}
335
336// overrideModels resolves model strings and updates the workspace
337// configuration via the server.
338func overrideModels(
339 ctx context.Context,
340 c *client.Client,
341 ws *proto.Workspace,
342 largeModel, smallModel string,
343) error {
344 cfg, err := c.GetConfig(ctx, ws.ID)
345 if err != nil {
346 return fmt.Errorf("failed to get config: %w", err)
347 }
348
349 providers := cfg.Providers.Copy()
350
351 largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
352
353 var largeProviderID string
354
355 if largeModel != "" {
356 found, err := validateModelMatches(largeMatches, largeModel, "large")
357 if err != nil {
358 return err
359 }
360 largeProviderID = found.provider
361 slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
362 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
363 Provider: found.provider,
364 Model: found.modelID,
365 }); err != nil {
366 return fmt.Errorf("failed to set large model: %w", err)
367 }
368 }
369
370 switch {
371 case smallModel != "":
372 found, err := validateModelMatches(smallMatches, smallModel, "small")
373 if err != nil {
374 return err
375 }
376 slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
377 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
378 Provider: found.provider,
379 Model: found.modelID,
380 }); err != nil {
381 return fmt.Errorf("failed to set small model: %w", err)
382 }
383
384 case largeModel != "":
385 sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
386 if err != nil {
387 slog.Warn("Failed to get default small model", "error", err)
388 } else if sm != nil {
389 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
390 return fmt.Errorf("failed to set small model: %w", err)
391 }
392 }
393 }
394
395 return c.UpdateAgent(ctx, ws.ID)
396}
397
// modelMatch pairs a provider name (a key of the providers map) with the ID
// of one of that provider's models.
type modelMatch struct {
	provider, modelID string
}
402
403// findModelMatches searches providers for matching large/small model
404// strings.
405func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
406 largeFilter, largeID := parseModelString(largeModel)
407 smallFilter, smallID := parseModelString(smallModel)
408
409 var largeMatches, smallMatches []modelMatch
410 for name, provider := range providers {
411 if provider.Disable {
412 continue
413 }
414 for _, m := range provider.Models {
415 if matchesModel(largeID, largeFilter, m.ID, name) {
416 largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
417 }
418 if matchesModel(smallID, smallFilter, m.ID, name) {
419 smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
420 }
421 }
422 }
423 return largeMatches, smallMatches
424}
425
// parseModelString splits s at its first "/" into (provider, model).
// A string with no "/" yields ("", s); in particular "" yields ("", "").
func parseModelString(s string) (string, string) {
	provider, model, found := strings.Cut(s, "/")
	if !found {
		return "", s
	}
	return provider, model
}
437
// matchesModel reports whether the model modelID offered by providerName
// satisfies a request for wantID, optionally restricted to wantProvider.
// An empty wantID never matches and an empty wantProvider accepts any
// provider. Model IDs are compared case-insensitively; provider names must
// match exactly.
func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
	switch {
	case wantID == "":
		return false
	case wantProvider == "" || wantProvider == providerName:
		return strings.EqualFold(modelID, wantID)
	default:
		return false
	}
}
449
450// validateModelMatches ensures exactly one match exists.
451func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
452 switch {
453 case len(matches) == 0:
454 return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
455 case len(matches) > 1:
456 names := make([]string, len(matches))
457 for i, m := range matches {
458 names[i] = m.provider
459 }
460 return modelMatch{}, fmt.Errorf(
461 "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
462 label, modelID, strings.Join(names, ", "),
463 )
464 }
465 return matches[0], nil
466}
467
468// resolveSession returns the session to use for a non-interactive run.
469// If continueSessionID is set it fetches that session; if useLast is set it
470// returns the most recently updated top-level session; otherwise it creates a
471// new one.
472func resolveSession(ctx context.Context, c *client.Client, wsID, continueSessionID string, useLast bool) (*proto.Session, error) {
473 switch {
474 case continueSessionID != "":
475 sess, err := c.GetSession(ctx, wsID, continueSessionID)
476 if err != nil {
477 return nil, fmt.Errorf("session not found: %s", continueSessionID)
478 }
479 if sess.ParentSessionID != "" {
480 return nil, fmt.Errorf("cannot continue a child session: %s", continueSessionID)
481 }
482 return sess, nil
483
484 case useLast:
485 sessions, err := c.ListSessions(ctx, wsID)
486 if err != nil || len(sessions) == 0 {
487 return nil, fmt.Errorf("no sessions found to continue")
488 }
489 last := sessions[0]
490 for _, s := range sessions[1:] {
491 if s.UpdatedAt > last.UpdatedAt && s.ParentSessionID == "" {
492 last = s
493 }
494 }
495 return &last, nil
496
497 default:
498 return c.CreateSession(ctx, wsID, "non-interactive")
499 }
500}
501
502// resolveSessionByID resolves a session ID that may be a full UUID or a hash
503// prefix returned by crush session list.
504func resolveSessionByID(ctx context.Context, c *client.Client, wsID, id string) (*proto.Session, error) {
505 if sess, err := c.GetSession(ctx, wsID, id); err == nil {
506 return sess, nil
507 }
508
509 sessions, err := c.ListSessions(ctx, wsID)
510 if err != nil {
511 return nil, err
512 }
513
514 var matches []proto.Session
515 for _, s := range sessions {
516 hash := session.HashID(s.ID)
517 if hash == id || strings.HasPrefix(hash, id) {
518 matches = append(matches, s)
519 }
520 }
521
522 switch len(matches) {
523 case 0:
524 return nil, fmt.Errorf("session %q not found", id)
525 case 1:
526 return &matches[0], nil
527 default:
528 return nil, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches))
529 }
530}