shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
// This package provides Shell instances that execute commands in their own
// working directory and environment. A Shell carries its working directory and
// exported variables over from one command to the next.
  5//
  6// WINDOWS COMPATIBILITY:
  7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
  8// Windows. Commands should use forward slashes (/) as path separators to work
  9// correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"io"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/x/exp/slice"
 24	"mvdan.cc/sh/moreinterp/coreutils"
 25	"mvdan.cc/sh/v3/expand"
 26	"mvdan.cc/sh/v3/interp"
 27	"mvdan.cc/sh/v3/syntax"
 28)
 29
 30// ShellType represents the type of shell to use
 31type ShellType int
 32
 33const (
 34	ShellTypePOSIX ShellType = iota
 35	ShellTypeCmd
 36	ShellTypePowerShell
 37)
 38
// Logger is the optional logging hook a Shell reports to. After every command
// the Shell calls InfoPersist with the message "command finished" and the
// key/value pairs "command" and "err".
type Logger interface {
	// InfoPersist receives a message followed by alternating key/value
	// arguments.
	InfoPersist(msg string, keysAndValues ...any)
}
 43
 44// noopLogger is a logger that does nothing
 45type noopLogger struct{}
 46
 47func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 48
// BlockFunc decides whether a command may run. It receives the command's
// argument list, with the command name in args[0], and returns true to refuse
// execution.
type BlockFunc func(args []string) bool
 51
// Shell executes commands through an embedded POSIX shell interpreter. The
// working directory and exported variables left behind by each command are
// carried over to the next one. Every exported method locks mu, so a Shell
// runs at most one command at a time.
type Shell struct {
	env        []string    // "KEY=value" pairs used to seed each interpreter
	cwd        string      // directory the next command starts in
	mu         sync.Mutex  // guards all fields and serializes execution
	logger     Logger      // notified after each command
	blockFuncs []BlockFunc // consulted before each command; any true refuses it
}
 60
// Options configures NewShell. Every field is optional.
type Options struct {
	WorkingDir string      // initial working directory; "" means os.Getwd()
	Env        []string    // initial "KEY=value" list; nil (not merely empty) means os.Environ()
	Logger     Logger      // notified after each command; nil means discard
	BlockFuncs []BlockFunc // predicates consulted before each command; true refuses it
}
 68
 69// NewShell creates a new shell instance with the given options
 70func NewShell(opts *Options) *Shell {
 71	if opts == nil {
 72		opts = &Options{}
 73	}
 74
 75	cwd := opts.WorkingDir
 76	if cwd == "" {
 77		cwd, _ = os.Getwd()
 78	}
 79
 80	env := opts.Env
 81	if env == nil {
 82		env = os.Environ()
 83	}
 84
 85	// Allow tools to detect execution by Crush.
 86	env = append(
 87		env,
 88		"CRUSH=1",
 89		"AGENT=crush",
 90		"AI_AGENT=crush",
 91	)
 92
 93	logger := opts.Logger
 94	if logger == nil {
 95		logger = noopLogger{}
 96	}
 97
 98	return &Shell{
 99		cwd:        cwd,
100		env:        env,
101		logger:     logger,
102		blockFuncs: opts.BlockFuncs,
103	}
104}
105
106// Exec executes a command in the shell
107func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
108	s.mu.Lock()
109	defer s.mu.Unlock()
110
111	return s.exec(ctx, command)
112}
113
114// ExecStream executes a command in the shell with streaming output to provided writers
115func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
116	s.mu.Lock()
117	defer s.mu.Unlock()
118
119	return s.execStream(ctx, command, stdout, stderr)
120}
121
122// GetWorkingDir returns the current working directory
123func (s *Shell) GetWorkingDir() string {
124	s.mu.Lock()
125	defer s.mu.Unlock()
126	return s.cwd
127}
128
129// SetWorkingDir sets the working directory
130func (s *Shell) SetWorkingDir(dir string) error {
131	s.mu.Lock()
132	defer s.mu.Unlock()
133
134	// Verify the directory exists
135	if _, err := os.Stat(dir); err != nil {
136		return fmt.Errorf("directory does not exist: %w", err)
137	}
138
139	s.cwd = dir
140	return nil
141}
142
143// GetEnv returns a copy of the environment variables
144func (s *Shell) GetEnv() []string {
145	s.mu.Lock()
146	defer s.mu.Unlock()
147
148	env := make([]string, len(s.env))
149	copy(env, s.env)
150	return env
151}
152
153// SetEnv sets an environment variable
154func (s *Shell) SetEnv(key, value string) {
155	s.mu.Lock()
156	defer s.mu.Unlock()
157
158	// Update or add the environment variable
159	keyPrefix := key + "="
160	for i, env := range s.env {
161		if strings.HasPrefix(env, keyPrefix) {
162			s.env[i] = keyPrefix + value
163			return
164		}
165	}
166	s.env = append(s.env, keyPrefix+value)
167}
168
169// SetBlockFuncs sets the command block functions for the shell
170func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
171	s.mu.Lock()
172	defer s.mu.Unlock()
173	s.blockFuncs = blockFuncs
174}
175
176// CommandsBlocker creates a BlockFunc that blocks exact command matches
177func CommandsBlocker(cmds []string) BlockFunc {
178	bannedSet := make(map[string]struct{})
179	for _, cmd := range cmds {
180		bannedSet[cmd] = struct{}{}
181	}
182
183	return func(args []string) bool {
184		if len(args) == 0 {
185			return false
186		}
187		_, ok := bannedSet[args[0]]
188		return ok
189	}
190}
191
192// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
193func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
194	return func(parts []string) bool {
195		if len(parts) == 0 || parts[0] != cmd {
196			return false
197		}
198
199		argParts, flagParts := splitArgsFlags(parts[1:])
200		if len(argParts) < len(args) || len(flagParts) < len(flags) {
201			return false
202		}
203
204		argsMatch := slices.Equal(argParts[:len(args)], args)
205		flagsMatch := slice.IsSubset(flags, flagParts)
206
207		return argsMatch && flagsMatch
208	}
209}
210
// splitArgsFlags partitions parts into positional arguments and flags. Any
// element starting with "-" (including a bare "-" or "--") counts as a flag and
// is recorded by name only, so "--output=file" becomes "--output". Relative
// order is preserved within each result, and both results are non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		// Cut yields p unchanged as the name when it contains no '='.
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
228
229func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
230	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
231		return func(ctx context.Context, args []string) error {
232			if len(args) == 0 {
233				return next(ctx, args)
234			}
235
236			for _, blockFunc := range s.blockFuncs {
237				if blockFunc(args) {
238					return fmt.Errorf("command is not allowed for security reasons: %q", args[0])
239				}
240			}
241
242			return next(ctx, args)
243		}
244	}
245}
246
247func (s *Shell) builtinHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
248	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
249		return func(ctx context.Context, args []string) error {
250			if len(args) == 0 {
251				return next(ctx, args)
252			}
253
254			// Builtins.
255			switch args[0] {
256			case "jq":
257				hc := interp.HandlerCtx(ctx)
258				return handleJQ(args, hc.Stdin, hc.Stdout, hc.Stderr)
259			default:
260				return next(ctx, args)
261			}
262		}
263	}
264}
265
266// newInterp creates a new interpreter with the current shell state
267func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
268	return interp.New(
269		interp.StdIO(nil, stdout, stderr),
270		interp.Interactive(false),
271		interp.Env(expand.ListEnviron(s.env...)),
272		interp.Dir(s.cwd),
273		interp.ExecHandlers(s.execHandlers()...),
274	)
275}
276
277// updateShellFromRunner updates the shell from the interpreter after execution.
278func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
279	s.cwd = runner.Dir
280	s.env = s.env[:0]
281	for name, vr := range runner.Vars {
282		if vr.Exported {
283			s.env = append(s.env, name+"="+vr.Str)
284		}
285	}
286}
287
// execCommon parses command, runs it in a fresh interpreter that writes to
// stdout and stderr, and then copies the interpreter's working directory and
// exported variables back into s. In this file it is reached only through
// Exec and ExecStream, which hold s.mu for the duration.
//
// Errors: a parse failure is returned as "could not parse command: ...", a
// failure to build the interpreter as "could not run command: ...", and
// otherwise whatever runner.Run returns (see ExitCode and IsInterrupt). A
// panic during execution is recovered and returned as an error. Every outcome
// is reported to s.logger.
func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) (err error) {
	// Declared before the defer so the deferred func can see whether an
	// interpreter was ever created.
	var runner *interp.Runner
	// The named result err lets this deferred func replace the return value
	// when a panic is recovered. State is synced back whenever an interpreter
	// exists, even if Run returned an error or panicked.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("command execution panic: %v", r)
		}
		if runner != nil {
			s.updateShellFromRunner(runner)
		}
		s.logger.InfoPersist("command finished", "command", command, "err", err)
	}()

	// Note: this `:=` reuses the named result err, since both live in the
	// function's outermost scope.
	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
	if err != nil {
		return fmt.Errorf("could not parse command: %w", err)
	}

	runner, err = s.newInterp(stdout, stderr)
	if err != nil {
		return fmt.Errorf("could not run command: %w", err)
	}

	err = runner.Run(ctx, line)
	return err
}
314
315// exec executes commands using a cross-platform shell interpreter.
316func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
317	var stdout, stderr bytes.Buffer
318	err := s.execCommon(ctx, command, &stdout, &stderr)
319	return stdout.String(), stderr.String(), err
320}
321
322// execStream executes commands using POSIX shell emulation with streaming output
323func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
324	return s.execCommon(ctx, command, stdout, stderr)
325}
326
327func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
328	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
329		s.builtinHandler(),
330		s.blockHandler(),
331	}
332	if useGoCoreUtils {
333		handlers = append(handlers, coreutils.ExecHandler)
334	}
335	return handlers
336}
337
// IsInterrupt reports whether err, or any error it wraps, is a context
// cancellation or an expired context deadline.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
343
344// ExitCode extracts the exit code from an error
345func ExitCode(err error) int {
346	if err == nil {
347		return 0
348	}
349	var exitErr interp.ExitStatus
350	if errors.As(err, &exitErr) {
351		return int(exitErr)
352	}
353	return 1
354}