shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
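// This package provides Shell instances for executing commands with their own
// working directory and environment. Every call to Exec or ExecStream runs in a
// fresh interpreter, but the working directory and exported variables it ends
// with are carried over to the next call on the same Shell.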
  5//
  6// WINDOWS COMPATIBILITY:
  7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
  8// Windows. Commands should use forward slashes (/) as path separators to work
  9// correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"io"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/x/exp/slice"
 24	"mvdan.cc/sh/v3/interp"
 25	"mvdan.cc/sh/v3/syntax"
 26)
 27
 28// ShellType represents the type of shell to use
 29type ShellType int
 30
 31const (
 32	ShellTypePOSIX ShellType = iota
 33	ShellTypeCmd
 34	ShellTypePowerShell
 35)
 36
 37// CrushEnvMarkers returns a fresh slice of the environment variables that
 38// Crush unconditionally sets on every shell it spawns — both the interactive
 39// bash tool's [Shell] and the hook runner's [Run] calls. Tools that want to
 40// detect "am I being invoked by an AI agent?" can check any of these.
 41// Keeping them in one place guarantees the two shell surfaces cannot drift.
 42// A fresh slice is returned on every call so callers may append freely.
 43func CrushEnvMarkers() []string {
 44	return []string{
 45		"CRUSH=1",
 46		"AGENT=crush",
 47		"AI_AGENT=crush",
 48	}
 49}
 50
 51// Logger interface for optional logging
 52type Logger interface {
 53	InfoPersist(msg string, keysAndValues ...any)
 54}
 55
 56// noopLogger is a logger that does nothing
 57type noopLogger struct{}
 58
 59func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 60
 61// BlockFunc is a function that determines if a command should be blocked
 62type BlockFunc func(args []string) bool
 63
 64// Shell provides cross-platform shell execution with optional state persistence
 65type Shell struct {
 66	env        []string
 67	cwd        string
 68	mu         sync.Mutex
 69	logger     Logger
 70	blockFuncs []BlockFunc
 71}
 72
 73// Options for creating a new shell
 74type Options struct {
 75	WorkingDir string
 76	Env        []string
 77	Logger     Logger
 78	BlockFuncs []BlockFunc
 79}
 80
 81// NewShell creates a new shell instance with the given options
 82func NewShell(opts *Options) *Shell {
 83	if opts == nil {
 84		opts = &Options{}
 85	}
 86
 87	cwd := opts.WorkingDir
 88	if cwd == "" {
 89		cwd, _ = os.Getwd()
 90	}
 91
 92	env := opts.Env
 93	if env == nil {
 94		env = os.Environ()
 95	}
 96
 97	// Allow tools to detect execution by Crush.
 98	env = append(env, CrushEnvMarkers()...)
 99
100	logger := opts.Logger
101	if logger == nil {
102		logger = noopLogger{}
103	}
104
105	return &Shell{
106		cwd:        cwd,
107		env:        env,
108		logger:     logger,
109		blockFuncs: opts.BlockFuncs,
110	}
111}
112
113// Exec executes a command in the shell
114func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
115	s.mu.Lock()
116	defer s.mu.Unlock()
117
118	return s.exec(ctx, command)
119}
120
121// ExecStream executes a command in the shell with streaming output to provided writers
122func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
123	s.mu.Lock()
124	defer s.mu.Unlock()
125
126	return s.execStream(ctx, command, stdout, stderr)
127}
128
129// GetWorkingDir returns the current working directory
130func (s *Shell) GetWorkingDir() string {
131	s.mu.Lock()
132	defer s.mu.Unlock()
133	return s.cwd
134}
135
136// SetWorkingDir sets the working directory
137func (s *Shell) SetWorkingDir(dir string) error {
138	s.mu.Lock()
139	defer s.mu.Unlock()
140
141	// Verify the directory exists
142	if _, err := os.Stat(dir); err != nil {
143		return fmt.Errorf("directory does not exist: %w", err)
144	}
145
146	s.cwd = dir
147	return nil
148}
149
150// GetEnv returns a copy of the environment variables
151func (s *Shell) GetEnv() []string {
152	s.mu.Lock()
153	defer s.mu.Unlock()
154
155	env := make([]string, len(s.env))
156	copy(env, s.env)
157	return env
158}
159
160// SetEnv sets an environment variable
161func (s *Shell) SetEnv(key, value string) {
162	s.mu.Lock()
163	defer s.mu.Unlock()
164
165	// Update or add the environment variable
166	keyPrefix := key + "="
167	for i, env := range s.env {
168		if strings.HasPrefix(env, keyPrefix) {
169			s.env[i] = keyPrefix + value
170			return
171		}
172	}
173	s.env = append(s.env, keyPrefix+value)
174}
175
176// SetBlockFuncs sets the command block functions for the shell
177func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
178	s.mu.Lock()
179	defer s.mu.Unlock()
180	s.blockFuncs = blockFuncs
181}
182
183// CommandsBlocker creates a BlockFunc that blocks exact command matches
184func CommandsBlocker(cmds []string) BlockFunc {
185	bannedSet := make(map[string]struct{})
186	for _, cmd := range cmds {
187		bannedSet[cmd] = struct{}{}
188	}
189
190	return func(args []string) bool {
191		if len(args) == 0 {
192			return false
193		}
194		_, ok := bannedSet[args[0]]
195		return ok
196	}
197}
198
199// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
200func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
201	return func(parts []string) bool {
202		if len(parts) == 0 || parts[0] != cmd {
203			return false
204		}
205
206		argParts, flagParts := splitArgsFlags(parts[1:])
207		if len(argParts) < len(args) || len(flagParts) < len(flags) {
208			return false
209		}
210
211		argsMatch := slices.Equal(argParts[:len(args)], args)
212		flagsMatch := slice.IsSubset(flags, flagParts)
213
214		return argsMatch && flagsMatch
215	}
216}
217
// splitArgsFlags partitions command words into positional arguments and
// flags, keeping their relative order. A word is a flag when it starts with
// "-". For a flag containing "=", only the part before the first "=" is kept.
// Both results are non-nil, even for empty input.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, word := range parts {
		if !strings.HasPrefix(word, "-") {
			args = append(args, word)
			continue
		}
		// Cut returns the whole word as "before" when there is no "=", so
		// bare flags pass through unchanged.
		name, _, _ := strings.Cut(word, "=")
		flags = append(flags, name)
	}
	return args, flags
}
235
236// newInterp creates a new interpreter with the current shell state. A nil
237// stdin is equivalent to an empty input stream.
238func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
239	return newRunner(s.cwd, s.env, stdin, stdout, stderr, s.blockFuncs)
240}
241
242// updateShellFromRunner updates the shell from the interpreter after execution.
243func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
244	s.cwd = runner.Dir
245	s.env = s.env[:0]
246	for name, vr := range runner.Vars {
247		if vr.Exported {
248			s.env = append(s.env, name+"="+vr.Str)
249		}
250	}
251}
252
253// execCommon is the shared implementation for executing commands
254func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) (err error) {
255	var runner *interp.Runner
256	defer func() {
257		if r := recover(); r != nil {
258			err = fmt.Errorf("command execution panic: %v", r)
259		}
260		if runner != nil {
261			s.updateShellFromRunner(runner)
262		}
263		s.logger.InfoPersist("command finished", "command", command, "err", err)
264	}()
265
266	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
267	if err != nil {
268		return fmt.Errorf("could not parse command: %w", err)
269	}
270
271	runner, err = s.newInterp(nil, stdout, stderr)
272	if err != nil {
273		return fmt.Errorf("could not run command: %w", err)
274	}
275
276	err = runner.Run(ctx, line)
277	return err
278}
279
280// exec executes commands using a cross-platform shell interpreter.
281func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
282	var stdout, stderr bytes.Buffer
283	err := s.execCommon(ctx, command, &stdout, &stderr)
284	return stdout.String(), stderr.String(), err
285}
286
287// execStream executes commands using POSIX shell emulation with streaming output
288func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
289	return s.execCommon(ctx, command, stdout, stderr)
290}
291
// IsInterrupt reports whether err, or any error it wraps, is a context
// cancellation or an expired deadline. A nil err is not an interrupt.
func IsInterrupt(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
297
298// ExitCode extracts the exit code from an error
299func ExitCode(err error) int {
300	if err == nil {
301		return 0
302	}
303	if exitErr, ok := errors.AsType[interp.ExitStatus](err); ok {
304		return int(exitErr)
305	}
306	return 1
307}