1// Package shell provides cross-platform shell execution capabilities.
2//
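// This package provides Shell instances for executing commands with their own
// working directory and environment. Every execution runs in a fresh
// interpreter, but changes it makes to the working directory and to exported
// environment variables are carried over to the Shell for later executions.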
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/x/exp/slice"
24 "mvdan.cc/sh/v3/interp"
25 "mvdan.cc/sh/v3/syntax"
26)
27
// ShellType names a shell dialect. ShellTypePOSIX is the zero value.
type ShellType int

// Known shell dialects, numbered consecutively from zero.
const (
	ShellTypePOSIX ShellType = iota
	ShellTypeCmd
	ShellTypePowerShell
)
36
// Logger is an optional sink for execution records. A Shell calls
// InfoPersist once each time a command execution finishes.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...any)
}

// noopLogger discards every record; it stands in when no Logger is supplied.
type noopLogger struct{}

// Compile-time check that noopLogger satisfies Logger.
var _ Logger = noopLogger{}

// InfoPersist implements Logger by ignoring all of its arguments.
func (noopLogger) InfoPersist(string, ...any) {}
46
// BlockFunc reports whether a command must be refused. args is the command's
// argument vector, with the command name in args[0].
type BlockFunc func(args []string) bool
49
50// Shell provides cross-platform shell execution with optional state persistence
51type Shell struct {
52 env []string
53 cwd string
54 mu sync.Mutex
55 logger Logger
56 blockFuncs []BlockFunc
57}
58
59// Options for creating a new shell
60type Options struct {
61 WorkingDir string
62 Env []string
63 Logger Logger
64 BlockFuncs []BlockFunc
65}
66
67// NewShell creates a new shell instance with the given options
68func NewShell(opts *Options) *Shell {
69 if opts == nil {
70 opts = &Options{}
71 }
72
73 cwd := opts.WorkingDir
74 if cwd == "" {
75 cwd, _ = os.Getwd()
76 }
77
78 env := opts.Env
79 if env == nil {
80 env = os.Environ()
81 }
82
83 // Allow tools to detect execution by Crush.
84 env = append(
85 env,
86 "CRUSH=1",
87 "AGENT=crush",
88 "AI_AGENT=crush",
89 )
90
91 logger := opts.Logger
92 if logger == nil {
93 logger = noopLogger{}
94 }
95
96 return &Shell{
97 cwd: cwd,
98 env: env,
99 logger: logger,
100 blockFuncs: opts.BlockFuncs,
101 }
102}
103
104// Exec executes a command in the shell
105func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
106 s.mu.Lock()
107 defer s.mu.Unlock()
108
109 return s.exec(ctx, command)
110}
111
112// ExecStream executes a command in the shell with streaming output to provided writers
113func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
114 s.mu.Lock()
115 defer s.mu.Unlock()
116
117 return s.execStream(ctx, command, stdout, stderr)
118}
119
120// GetWorkingDir returns the current working directory
121func (s *Shell) GetWorkingDir() string {
122 s.mu.Lock()
123 defer s.mu.Unlock()
124 return s.cwd
125}
126
127// SetWorkingDir sets the working directory
128func (s *Shell) SetWorkingDir(dir string) error {
129 s.mu.Lock()
130 defer s.mu.Unlock()
131
132 // Verify the directory exists
133 if _, err := os.Stat(dir); err != nil {
134 return fmt.Errorf("directory does not exist: %w", err)
135 }
136
137 s.cwd = dir
138 return nil
139}
140
141// GetEnv returns a copy of the environment variables
142func (s *Shell) GetEnv() []string {
143 s.mu.Lock()
144 defer s.mu.Unlock()
145
146 env := make([]string, len(s.env))
147 copy(env, s.env)
148 return env
149}
150
151// SetEnv sets an environment variable
152func (s *Shell) SetEnv(key, value string) {
153 s.mu.Lock()
154 defer s.mu.Unlock()
155
156 // Update or add the environment variable
157 keyPrefix := key + "="
158 for i, env := range s.env {
159 if strings.HasPrefix(env, keyPrefix) {
160 s.env[i] = keyPrefix + value
161 return
162 }
163 }
164 s.env = append(s.env, keyPrefix+value)
165}
166
167// SetBlockFuncs sets the command block functions for the shell
168func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
169 s.mu.Lock()
170 defer s.mu.Unlock()
171 s.blockFuncs = blockFuncs
172}
173
174// CommandsBlocker creates a BlockFunc that blocks exact command matches
175func CommandsBlocker(cmds []string) BlockFunc {
176 bannedSet := make(map[string]struct{})
177 for _, cmd := range cmds {
178 bannedSet[cmd] = struct{}{}
179 }
180
181 return func(args []string) bool {
182 if len(args) == 0 {
183 return false
184 }
185 _, ok := bannedSet[args[0]]
186 return ok
187 }
188}
189
190// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
191func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
192 return func(parts []string) bool {
193 if len(parts) == 0 || parts[0] != cmd {
194 return false
195 }
196
197 argParts, flagParts := splitArgsFlags(parts[1:])
198 if len(argParts) < len(args) || len(flagParts) < len(flags) {
199 return false
200 }
201
202 argsMatch := slices.Equal(argParts[:len(args)], args)
203 flagsMatch := slice.IsSubset(flags, flagParts)
204
205 return argsMatch && flagsMatch
206 }
207}
208
// splitArgsFlags partitions parts into positional arguments and flags,
// preserving the relative order within each group. Any element beginning
// with "-" counts as a flag (including "-" and "--"). An inline value is
// dropped, so "--output=file" is recorded as "--output". Both returned slices
// are non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, part := range parts {
		if !strings.HasPrefix(part, "-") {
			args = append(args, part)
			continue
		}
		// strings.Cut yields the whole string as "before" when "=" is absent.
		name, _, _ := strings.Cut(part, "=")
		flags = append(flags, name)
	}
	return args, flags
}
226
227// newInterp creates a new interpreter with the current shell state. A nil
228// stdin is equivalent to an empty input stream.
229func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
230 return newRunner(s.cwd, s.env, stdin, stdout, stderr, s.blockFuncs)
231}
232
233// updateShellFromRunner updates the shell from the interpreter after execution.
234func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
235 s.cwd = runner.Dir
236 s.env = s.env[:0]
237 for name, vr := range runner.Vars {
238 if vr.Exported {
239 s.env = append(s.env, name+"="+vr.Str)
240 }
241 }
242}
243
244// execCommon is the shared implementation for executing commands
245func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) (err error) {
246 var runner *interp.Runner
247 defer func() {
248 if r := recover(); r != nil {
249 err = fmt.Errorf("command execution panic: %v", r)
250 }
251 if runner != nil {
252 s.updateShellFromRunner(runner)
253 }
254 s.logger.InfoPersist("command finished", "command", command, "err", err)
255 }()
256
257 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
258 if err != nil {
259 return fmt.Errorf("could not parse command: %w", err)
260 }
261
262 runner, err = s.newInterp(nil, stdout, stderr)
263 if err != nil {
264 return fmt.Errorf("could not run command: %w", err)
265 }
266
267 err = runner.Run(ctx, line)
268 return err
269}
270
271// exec executes commands using a cross-platform shell interpreter.
272func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
273 var stdout, stderr bytes.Buffer
274 err := s.execCommon(ctx, command, &stdout, &stderr)
275 return stdout.String(), stderr.String(), err
276}
277
278// execStream executes commands using POSIX shell emulation with streaming output
279func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
280 return s.execCommon(ctx, command, stdout, stderr)
281}
282
// IsInterrupt reports whether err, or any error it wraps, is
// context.Canceled or context.DeadlineExceeded, meaning the execution was
// cancelled or timed out rather than failing on its own.
func IsInterrupt(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
288
289// ExitCode extracts the exit code from an error
290func ExitCode(err error) int {
291 if err == nil {
292 return 0
293 }
294 var exitErr interp.ExitStatus
295 if errors.As(err, &exitErr) {
296 return int(exitErr)
297 }
298 return 1
299}