1package cmd
2
import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"sort"
	"strings"
	"syscall"
	"time"

	"charm.land/lipgloss/v2"
	"charm.land/log/v2"
	"github.com/charmbracelet/crush/internal/client"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/event"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/proto"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/ui/anim"
	"github.com/charmbracelet/crush/internal/ui/styles"
	"github.com/charmbracelet/x/ansi"
	"github.com/charmbracelet/x/exp/charmtone"
	"github.com/charmbracelet/x/term"
	"github.com/spf13/cobra"
)
27
28var runCmd = &cobra.Command{
29 Use: "run [prompt...]",
30 Short: "Run a single non-interactive prompt",
31 Long: `Run a single prompt in non-interactive mode and exit.
32The prompt can be provided as arguments or piped from stdin.`,
33 Example: `
34# Run a simple prompt
35crush run "Guess my 5 favorite Pokรฉmon"
36
37# Pipe input from stdin
38curl https://charm.land | crush run "Summarize this website"
39
40# Read from a file
41crush run "What is this code doing?" <<< prrr.go
42
43# Redirect output to a file
44crush run "Generate a hot README for this project" > MY_HOT_README.md
45
46# Run in quiet mode (hide the spinner)
47crush run --quiet "Generate a README for this project"
48
49# Run in verbose mode (show logs)
50crush run --verbose "Generate a README for this project"
51 `,
52 RunE: func(cmd *cobra.Command, args []string) error {
53 quiet, _ := cmd.Flags().GetBool("quiet")
54 verbose, _ := cmd.Flags().GetBool("verbose")
55 largeModel, _ := cmd.Flags().GetString("model")
56 smallModel, _ := cmd.Flags().GetString("small-model")
57
58 // Cancel on SIGINT or SIGTERM.
59 ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
60 defer cancel()
61
62 c, ws, cleanup, err := connectToServer(cmd)
63 if err != nil {
64 return err
65 }
66 defer cleanup()
67
68 if !ws.Config.IsConfigured() {
69 return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
70 }
71
72 if verbose {
73 slog.SetDefault(slog.New(log.New(os.Stderr)))
74 }
75
76 prompt := strings.Join(args, " ")
77
78 prompt, err = MaybePrependStdin(prompt)
79 if err != nil {
80 slog.Error("Failed to read from stdin", "error", err)
81 return err
82 }
83
84 if prompt == "" {
85 return fmt.Errorf("no prompt provided")
86 }
87
88 event.SetNonInteractive(true)
89 event.AppInitialized()
90
91 return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose)
92 },
93}
94
95func init() {
96 runCmd.Flags().BoolP("quiet", "q", false, "Hide spinner")
97 runCmd.Flags().BoolP("verbose", "v", false, "Show logs")
98 runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
99 runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
100}
101
102// runNonInteractive executes the agent via the server and streams output
103// to stdout.
104func runNonInteractive(
105 ctx context.Context,
106 c *client.Client,
107 ws *proto.Workspace,
108 prompt, largeModel, smallModel string,
109 hideSpinner bool,
110) error {
111 slog.Info("Running in non-interactive mode")
112
113 ctx, cancel := context.WithCancel(ctx)
114 defer cancel()
115
116 if largeModel != "" || smallModel != "" {
117 if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
118 return fmt.Errorf("failed to override models: %w", err)
119 }
120 }
121
122 var (
123 spinner *format.Spinner
124 stdoutTTY bool
125 stderrTTY bool
126 stdinTTY bool
127 progress bool
128 )
129
130 stdoutTTY = term.IsTerminal(os.Stdout.Fd())
131 stderrTTY = term.IsTerminal(os.Stderr.Fd())
132 stdinTTY = term.IsTerminal(os.Stdin.Fd())
133 progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
134
135 if !hideSpinner && stderrTTY {
136 t := styles.DefaultStyles()
137
138 hasDarkBG := true
139 if stdinTTY && stdoutTTY {
140 hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
141 }
142 defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
143
144 spinner = format.NewSpinner(ctx, cancel, anim.Settings{
145 Size: 10,
146 Label: "Generating",
147 LabelColor: defaultFG,
148 GradColorA: t.Primary,
149 GradColorB: t.Secondary,
150 CycleColors: true,
151 })
152 spinner.Start()
153 }
154
155 stopSpinner := func() {
156 if !hideSpinner && spinner != nil {
157 spinner.Stop()
158 spinner = nil
159 }
160 }
161
162 // Wait for the agent to become ready (MCP init, etc).
163 if err := waitForAgent(ctx, c, ws.ID); err != nil {
164 stopSpinner()
165 return fmt.Errorf("agent not ready: %w", err)
166 }
167
168 // Force-update agent models so MCP tools are loaded.
169 if err := c.UpdateAgent(ctx, ws.ID); err != nil {
170 slog.Warn("Failed to update agent", "error", err)
171 }
172
173 defer stopSpinner()
174
175 sess, err := c.CreateSession(ctx, ws.ID, "non-interactive")
176 if err != nil {
177 return fmt.Errorf("failed to create session: %w", err)
178 }
179 slog.Info("Created session for non-interactive run", "session_id", sess.ID)
180
181 events, err := c.SubscribeEvents(ctx, ws.ID)
182 if err != nil {
183 return fmt.Errorf("failed to subscribe to events: %w", err)
184 }
185
186 if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
187 return fmt.Errorf("failed to send message: %w", err)
188 }
189
190 messageReadBytes := make(map[string]int)
191 var printed bool
192
193 defer func() {
194 if progress && stderrTTY {
195 _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
196 }
197 _, _ = fmt.Fprintln(os.Stdout)
198 }()
199
200 for {
201 if progress && stderrTTY {
202 _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
203 }
204
205 select {
206 case ev, ok := <-events:
207 if !ok {
208 stopSpinner()
209 return nil
210 }
211
212 switch e := ev.(type) {
213 case pubsub.Event[proto.Message]:
214 msg := e.Payload
215 if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
216 continue
217 }
218 stopSpinner()
219
220 content := msg.Content().String()
221 readBytes := messageReadBytes[msg.ID]
222
223 if len(content) < readBytes {
224 slog.Error("Non-interactive: message content shorter than read bytes",
225 "message_length", len(content), "read_bytes", readBytes)
226 return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
227 }
228
229 part := content[readBytes:]
230 if readBytes == 0 {
231 part = strings.TrimLeft(part, " \t")
232 }
233 if printed || strings.TrimSpace(part) != "" {
234 printed = true
235 fmt.Fprint(os.Stdout, part)
236 }
237 messageReadBytes[msg.ID] = len(content)
238
239 if msg.IsFinished() {
240 return nil
241 }
242
243 case pubsub.Event[proto.AgentEvent]:
244 if e.Payload.Error != nil {
245 stopSpinner()
246 return fmt.Errorf("agent error: %w", e.Payload.Error)
247 }
248 }
249
250 case <-ctx.Done():
251 stopSpinner()
252 return ctx.Err()
253 }
254 }
255}
256
257// waitForAgent polls GetAgentInfo until the agent is ready, with a
258// timeout.
259func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
260 timeout := time.After(30 * time.Second)
261 for {
262 info, err := c.GetAgentInfo(ctx, wsID)
263 if err == nil && info.IsReady {
264 return nil
265 }
266 select {
267 case <-timeout:
268 if err != nil {
269 return fmt.Errorf("timeout waiting for agent: %w", err)
270 }
271 return fmt.Errorf("timeout waiting for agent readiness")
272 case <-ctx.Done():
273 return ctx.Err()
274 case <-time.After(200 * time.Millisecond):
275 }
276 }
277}
278
279// overrideModels resolves model strings and updates the workspace
280// configuration via the server.
281func overrideModels(
282 ctx context.Context,
283 c *client.Client,
284 ws *proto.Workspace,
285 largeModel, smallModel string,
286) error {
287 cfg, err := c.GetConfig(ctx, ws.ID)
288 if err != nil {
289 return fmt.Errorf("failed to get config: %w", err)
290 }
291
292 providers := cfg.Providers.Copy()
293
294 largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
295
296 var largeProviderID string
297
298 if largeModel != "" {
299 found, err := validateModelMatches(largeMatches, largeModel, "large")
300 if err != nil {
301 return err
302 }
303 largeProviderID = found.provider
304 slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
305 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
306 Provider: found.provider,
307 Model: found.modelID,
308 }); err != nil {
309 return fmt.Errorf("failed to set large model: %w", err)
310 }
311 }
312
313 switch {
314 case smallModel != "":
315 found, err := validateModelMatches(smallMatches, smallModel, "small")
316 if err != nil {
317 return err
318 }
319 slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
320 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
321 Provider: found.provider,
322 Model: found.modelID,
323 }); err != nil {
324 return fmt.Errorf("failed to set small model: %w", err)
325 }
326
327 case largeModel != "":
328 sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
329 if err != nil {
330 slog.Warn("Failed to get default small model", "error", err)
331 } else if sm != nil {
332 if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
333 return fmt.Errorf("failed to set small model: %w", err)
334 }
335 }
336 }
337
338 return c.UpdateAgent(ctx, ws.ID)
339}
340
// modelMatch is one candidate for a requested model. provider is the name
// (map key) of the provider that lists the model, and modelID is the model's
// ID exactly as that provider lists it.
type modelMatch struct {
	provider, modelID string
}
345
346// findModelMatches searches providers for matching large/small model
347// strings.
348func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
349 largeFilter, largeID := parseModelString(largeModel)
350 smallFilter, smallID := parseModelString(smallModel)
351
352 var largeMatches, smallMatches []modelMatch
353 for name, provider := range providers {
354 if provider.Disable {
355 continue
356 }
357 for _, m := range provider.Models {
358 if matchesModel(largeID, largeFilter, m.ID, name) {
359 largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
360 }
361 if matchesModel(smallID, smallFilter, m.ID, name) {
362 smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
363 }
364 }
365 }
366 return largeMatches, smallMatches
367}
368
// parseModelString splits a model flag value at its first "/" and returns
// (provider, model). Any further slashes stay in the model part. A value
// with no slash yields ("", value), and the empty string yields ("", "").
func parseModelString(s string) (string, string) {
	provider, model, hasProvider := strings.Cut(s, "/")
	if !hasProvider {
		return "", s
	}
	return provider, model
}
380
// matchesModel reports whether a provider's model satisfies a request.
// wantID must be non-empty and equal to modelID ignoring case. If
// wantProvider is set, it must equal providerName exactly.
func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
	providerOK := wantProvider == "" || wantProvider == providerName
	return wantID != "" && providerOK && strings.EqualFold(modelID, wantID)
}
392
393// validateModelMatches ensures exactly one match exists.
394func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
395 switch {
396 case len(matches) == 0:
397 return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
398 case len(matches) > 1:
399 names := make([]string, len(matches))
400 for i, m := range matches {
401 names[i] = m.provider
402 }
403 return modelMatch{}, fmt.Errorf(
404 "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
405 label, modelID, strings.Join(names, ", "),
406 )
407 }
408 return matches[0], nil
409}