1// Package shell provides cross-platform shell execution capabilities.
2//
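// This package provides Shell instances for executing commands with their own
// working directory and environment. Shells are independent of one another,
// but within a single Shell the working directory and exported environment
// left behind by one command carry over to the next.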
5//
6// WINDOWS COMPATIBILITY:
7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
8// Windows. Commands should use forward slashes (/) as path separators to work
9// correctly on all platforms.
10package shell
11
12import (
13 "bytes"
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "os"
19 "slices"
20 "strings"
21 "sync"
22
23 "github.com/charmbracelet/x/exp/slice"
24 "mvdan.cc/sh/v3/interp"
25 "mvdan.cc/sh/v3/syntax"
26)
27
// ShellType enumerates the shell dialects a command can be written for.
type ShellType int

// Known shell dialects. The zero value of ShellType is ShellTypePOSIX.
const (
	// ShellTypePOSIX is the POSIX sh dialect.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd is the cmd dialect.
	ShellTypeCmd
	// ShellTypePowerShell is the PowerShell dialect.
	ShellTypePowerShell
)
36
// CrushEnvMarkers lists the "KEY=value" environment entries that Crush adds
// to every shell it starts: the interactive bash tool's [Shell] as well as
// the hook runner's [Run] calls. A program can look for any of them to find
// out that an AI agent is driving it. Defining them once here keeps both
// shell surfaces in sync.
//
// Each call builds a new slice, so callers are free to append to or modify
// the result.
func CrushEnvMarkers() []string {
	markers := make([]string, 0, 3)
	markers = append(markers, "CRUSH=1", "AGENT=crush", "AI_AGENT=crush")
	return markers
}
50
// Logger is the optional sink a Shell reports finished commands to.
type Logger interface {
	// InfoPersist records message together with alternating key/value pairs.
	InfoPersist(message string, kv ...any)
}
55
// noopLogger discards everything it is given. NewShell uses it when the
// caller does not supply a Logger.
type noopLogger struct{}

// InfoPersist ignores its arguments.
func (noopLogger) InfoPersist(string, ...any) {}
60
// BlockFunc reports whether the command described by argv should be blocked.
// argv[0] is the command name, as the blockers built in this package assume.
type BlockFunc func(argv []string) bool
63
64// Shell provides cross-platform shell execution with optional state persistence
65type Shell struct {
66 env []string
67 cwd string
68 mu sync.Mutex
69 logger Logger
70 blockFuncs []BlockFunc
71}
72
73// Options for creating a new shell
74type Options struct {
75 WorkingDir string
76 Env []string
77 Logger Logger
78 BlockFuncs []BlockFunc
79}
80
81// NewShell creates a new shell instance with the given options
82func NewShell(opts *Options) *Shell {
83 if opts == nil {
84 opts = &Options{}
85 }
86
87 cwd := opts.WorkingDir
88 if cwd == "" {
89 cwd, _ = os.Getwd()
90 }
91
92 env := opts.Env
93 if env == nil {
94 env = os.Environ()
95 }
96
97 // Allow tools to detect execution by Crush.
98 env = append(env, CrushEnvMarkers()...)
99
100 logger := opts.Logger
101 if logger == nil {
102 logger = noopLogger{}
103 }
104
105 return &Shell{
106 cwd: cwd,
107 env: env,
108 logger: logger,
109 blockFuncs: opts.BlockFuncs,
110 }
111}
112
113// Exec executes a command in the shell
114func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
115 s.mu.Lock()
116 defer s.mu.Unlock()
117
118 return s.exec(ctx, command)
119}
120
121// ExecStream executes a command in the shell with streaming output to provided writers
122func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
123 s.mu.Lock()
124 defer s.mu.Unlock()
125
126 return s.execStream(ctx, command, stdout, stderr)
127}
128
129// GetWorkingDir returns the current working directory
130func (s *Shell) GetWorkingDir() string {
131 s.mu.Lock()
132 defer s.mu.Unlock()
133 return s.cwd
134}
135
136// SetWorkingDir sets the working directory
137func (s *Shell) SetWorkingDir(dir string) error {
138 s.mu.Lock()
139 defer s.mu.Unlock()
140
141 // Verify the directory exists
142 if _, err := os.Stat(dir); err != nil {
143 return fmt.Errorf("directory does not exist: %w", err)
144 }
145
146 s.cwd = dir
147 return nil
148}
149
150// GetEnv returns a copy of the environment variables
151func (s *Shell) GetEnv() []string {
152 s.mu.Lock()
153 defer s.mu.Unlock()
154
155 env := make([]string, len(s.env))
156 copy(env, s.env)
157 return env
158}
159
160// SetEnv sets an environment variable
161func (s *Shell) SetEnv(key, value string) {
162 s.mu.Lock()
163 defer s.mu.Unlock()
164
165 // Update or add the environment variable
166 keyPrefix := key + "="
167 for i, env := range s.env {
168 if strings.HasPrefix(env, keyPrefix) {
169 s.env[i] = keyPrefix + value
170 return
171 }
172 }
173 s.env = append(s.env, keyPrefix+value)
174}
175
176// SetBlockFuncs sets the command block functions for the shell
177func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
178 s.mu.Lock()
179 defer s.mu.Unlock()
180 s.blockFuncs = blockFuncs
181}
182
183// CommandsBlocker creates a BlockFunc that blocks exact command matches
184func CommandsBlocker(cmds []string) BlockFunc {
185 bannedSet := make(map[string]struct{})
186 for _, cmd := range cmds {
187 bannedSet[cmd] = struct{}{}
188 }
189
190 return func(args []string) bool {
191 if len(args) == 0 {
192 return false
193 }
194 _, ok := bannedSet[args[0]]
195 return ok
196 }
197}
198
199// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
200func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
201 return func(parts []string) bool {
202 if len(parts) == 0 || parts[0] != cmd {
203 return false
204 }
205
206 argParts, flagParts := splitArgsFlags(parts[1:])
207 if len(argParts) < len(args) || len(flagParts) < len(flags) {
208 return false
209 }
210
211 argsMatch := slices.Equal(argParts[:len(args)], args)
212 flagsMatch := slice.IsSubset(flags, flagParts)
213
214 return argsMatch && flagsMatch
215 }
216}
217
// splitArgsFlags partitions parts into positional arguments and flags,
// keeping input order. Every element starting with "-" is a flag, recorded
// without any "=value" suffix. Everything else is an argument. Both results
// are non-nil even for empty input.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		// Cut returns the whole string when there is no "=".
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
235
236// newInterp creates a new interpreter with the current shell state. A nil
237// stdin is equivalent to an empty input stream.
238func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
239 return newRunner(s.cwd, s.env, stdin, stdout, stderr, s.blockFuncs)
240}
241
242// updateShellFromRunner updates the shell from the interpreter after execution.
243func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
244 s.cwd = runner.Dir
245 s.env = s.env[:0]
246 for name, vr := range runner.Vars {
247 if vr.Exported {
248 s.env = append(s.env, name+"="+vr.Str)
249 }
250 }
251}
252
253// execCommon is the shared implementation for executing commands
254func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) (err error) {
255 var runner *interp.Runner
256 defer func() {
257 if r := recover(); r != nil {
258 err = fmt.Errorf("command execution panic: %v", r)
259 }
260 if runner != nil {
261 s.updateShellFromRunner(runner)
262 }
263 s.logger.InfoPersist("command finished", "command", command, "err", err)
264 }()
265
266 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
267 if err != nil {
268 return fmt.Errorf("could not parse command: %w", err)
269 }
270
271 runner, err = s.newInterp(nil, stdout, stderr)
272 if err != nil {
273 return fmt.Errorf("could not run command: %w", err)
274 }
275
276 err = runner.Run(ctx, line)
277 return err
278}
279
280// exec executes commands using a cross-platform shell interpreter.
281func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
282 var stdout, stderr bytes.Buffer
283 err := s.execCommon(ctx, command, &stdout, &stderr)
284 return stdout.String(), stderr.String(), err
285}
286
287// execStream executes commands using POSIX shell emulation with streaming output
288func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
289 return s.execCommon(ctx, command, stdout, stderr)
290}
291
// IsInterrupt reports whether err, or any error it wraps, is a context
// cancellation or deadline expiry.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
297
298// ExitCode extracts the exit code from an error
299func ExitCode(err error) int {
300 if err == nil {
301 return 0
302 }
303 if exitErr, ok := errors.AsType[interp.ExitStatus](err); ok {
304 return int(exitErr)
305 }
306 return 1
307}