shell.go

// Package shell provides cross-platform shell execution capabilities.
//
// A Shell runs commands with its own working directory and environment. The
// working directory and environment a command leaves behind are saved and used
// as the starting state of the next command run on the same Shell.
//
// WINDOWS COMPATIBILITY:
// Commands are interpreted by a POSIX shell emulator (mvdan.cc/sh/v3) on every
// platform, including Windows. Because backslash is an escape character in
// POSIX shell syntax, commands should use forward slashes (/) as path
// separators to work correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"io"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/x/exp/slice"
 24	"mvdan.cc/sh/moreinterp/coreutils"
 25	"mvdan.cc/sh/v3/expand"
 26	"mvdan.cc/sh/v3/interp"
 27	"mvdan.cc/sh/v3/syntax"
 28)
 29
 30// ShellType represents the type of shell to use
 31type ShellType int
 32
 33const (
 34	ShellTypePOSIX ShellType = iota
 35	ShellTypeCmd
 36	ShellTypePowerShell
 37)
 38
 39// Logger interface for optional logging
 40type Logger interface {
 41	InfoPersist(msg string, keysAndValues ...any)
 42}
 43
 44// noopLogger is a logger that does nothing
 45type noopLogger struct{}
 46
 47func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 48
 49// BlockFunc is a function that determines if a command should be blocked
 50type BlockFunc func(args []string) bool
 51
 52// Shell provides cross-platform shell execution with optional state persistence
 53type Shell struct {
 54	env                []string
 55	cwd                string
 56	mu                 sync.Mutex
 57	logger             Logger
 58	blockFuncs         []BlockFunc
 59	customExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
 60}
 61
 62// Options for creating a new shell
 63type Options struct {
 64	WorkingDir   string
 65	Env          []string
 66	Logger       Logger
 67	BlockFuncs   []BlockFunc
 68	ExecHandlers []func(interp.ExecHandlerFunc) interp.ExecHandlerFunc
 69}
 70
 71// NewShell creates a new shell instance with the given options
 72func NewShell(opts *Options) *Shell {
 73	if opts == nil {
 74		opts = &Options{}
 75	}
 76
 77	cwd := opts.WorkingDir
 78	if cwd == "" {
 79		cwd, _ = os.Getwd()
 80	}
 81
 82	env := opts.Env
 83	if env == nil {
 84		env = os.Environ()
 85	}
 86
 87	logger := opts.Logger
 88	if logger == nil {
 89		logger = noopLogger{}
 90	}
 91
 92	return &Shell{
 93		cwd:                cwd,
 94		env:                env,
 95		logger:             logger,
 96		blockFuncs:         opts.BlockFuncs,
 97		customExecHandlers: opts.ExecHandlers,
 98	}
 99}
100
101// Exec executes a command in the shell
102func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
103	s.mu.Lock()
104	defer s.mu.Unlock()
105
106	return s.exec(ctx, command, nil)
107}
108
109// ExecWithStdin executes a command in the shell with provided stdin
110func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
111	s.mu.Lock()
112	defer s.mu.Unlock()
113
114	return s.exec(ctx, command, stdin)
115}
116
117// ExecStream executes a command in the shell with streaming output to provided writers
118func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
119	s.mu.Lock()
120	defer s.mu.Unlock()
121
122	return s.execStream(ctx, command, stdout, stderr)
123}
124
125// GetWorkingDir returns the current working directory
126func (s *Shell) GetWorkingDir() string {
127	s.mu.Lock()
128	defer s.mu.Unlock()
129	return s.cwd
130}
131
132// SetWorkingDir sets the working directory
133func (s *Shell) SetWorkingDir(dir string) error {
134	s.mu.Lock()
135	defer s.mu.Unlock()
136
137	// Verify the directory exists
138	if _, err := os.Stat(dir); err != nil {
139		return fmt.Errorf("directory does not exist: %w", err)
140	}
141
142	s.cwd = dir
143	return nil
144}
145
146// GetEnv returns a copy of the environment variables
147func (s *Shell) GetEnv() []string {
148	s.mu.Lock()
149	defer s.mu.Unlock()
150
151	env := make([]string, len(s.env))
152	copy(env, s.env)
153	return env
154}
155
156// SetEnv sets an environment variable
157func (s *Shell) SetEnv(key, value string) {
158	s.mu.Lock()
159	defer s.mu.Unlock()
160
161	// Update or add the environment variable
162	keyPrefix := key + "="
163	for i, env := range s.env {
164		if strings.HasPrefix(env, keyPrefix) {
165			s.env[i] = keyPrefix + value
166			return
167		}
168	}
169	s.env = append(s.env, keyPrefix+value)
170}
171
172// SetBlockFuncs sets the command block functions for the shell
173func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
174	s.mu.Lock()
175	defer s.mu.Unlock()
176	s.blockFuncs = blockFuncs
177}
178
179// CommandsBlocker creates a BlockFunc that blocks exact command matches
180func CommandsBlocker(cmds []string) BlockFunc {
181	bannedSet := make(map[string]struct{})
182	for _, cmd := range cmds {
183		bannedSet[cmd] = struct{}{}
184	}
185
186	return func(args []string) bool {
187		if len(args) == 0 {
188			return false
189		}
190		_, ok := bannedSet[args[0]]
191		return ok
192	}
193}
194
195// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
196func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
197	return func(parts []string) bool {
198		if len(parts) == 0 || parts[0] != cmd {
199			return false
200		}
201
202		argParts, flagParts := splitArgsFlags(parts[1:])
203		if len(argParts) < len(args) || len(flagParts) < len(flags) {
204			return false
205		}
206
207		argsMatch := slices.Equal(argParts[:len(args)], args)
208		flagsMatch := slice.IsSubset(flags, flagParts)
209
210		return argsMatch && flagsMatch
211	}
212}
213
// splitArgsFlags partitions parts into positional arguments and flags,
// preserving their relative order. A part starting with "-" is a flag and is
// recorded without any "=value" suffix; every other part is positional. A
// lone "-" or "--" is recorded as a flag, and parts following "--" get no
// special treatment. Both returned slices are non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
231
232func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
233	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
234		return func(ctx context.Context, args []string) error {
235			if len(args) == 0 {
236				return next(ctx, args)
237			}
238
239			for _, blockFunc := range s.blockFuncs {
240				if blockFunc(args) {
241					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
242				}
243			}
244
245			return next(ctx, args)
246		}
247	}
248}
249
250// newInterp creates a new interpreter with the current shell state
251func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
252	opts := []interp.RunnerOption{
253		interp.StdIO(stdin, stdout, stderr),
254		interp.Interactive(false),
255		interp.Env(expand.ListEnviron(s.env...)),
256		interp.Dir(s.cwd),
257		interp.ExecHandlers(s.execHandlers()...),
258	}
259
260	return interp.New(opts...)
261}
262
263// updateShellFromRunner updates the shell from the interpreter after execution
264func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
265	s.cwd = runner.Dir
266	s.env = nil
267	for name, vr := range runner.Vars {
268		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
269	}
270}
271
272// execCommon is the shared implementation for executing commands
273func (s *Shell) execCommon(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
274	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
275	if err != nil {
276		return fmt.Errorf("could not parse command: %w", err)
277	}
278
279	runner, err := s.newInterp(stdin, stdout, stderr)
280	if err != nil {
281		return fmt.Errorf("could not run command: %w", err)
282	}
283
284	err = runner.Run(ctx, line)
285	s.updateShellFromRunner(runner)
286	s.logger.InfoPersist("command finished", "command", command, "err", err)
287	return err
288}
289
290// exec executes commands using a cross-platform shell interpreter.
291func (s *Shell) exec(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
292	var stdout, stderr bytes.Buffer
293	err := s.execCommon(ctx, command, stdin, &stdout, &stderr)
294	return stdout.String(), stderr.String(), err
295}
296
297// execStream executes commands using POSIX shell emulation with streaming output
298func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
299	return s.execCommon(ctx, command, nil, stdout, stderr)
300}
301
302func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
303	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
304		s.blockHandler(),
305	}
306	// Add custom exec handlers first (they get priority)
307	handlers = append(handlers, s.customExecHandlers...)
308	if useGoCoreUtils {
309		handlers = append(handlers, coreutils.ExecHandler)
310	}
311	return handlers
312}
313
// IsInterrupt reports whether err, or any error it wraps, is context.Canceled
// or context.DeadlineExceeded. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
319
320// ExitCode extracts the exit code from an error
321func ExitCode(err error) int {
322	if err == nil {
323		return 0
324	}
325	var exitErr interp.ExitStatus
326	if errors.As(err, &exitErr) {
327		return int(exitErr)
328	}
329	return 1
330}