shell.go

// Package shell provides cross-platform shell execution capabilities.
//
// A Shell runs commands through an embedded POSIX shell interpreter. Every
// call gets a fresh interpreter, but the Shell carries its working directory
// and variables from one call to the next: a `cd` or a variable assignment in
// one command is visible to the next command run on the same Shell.
//
// WINDOWS COMPATIBILITY:
// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
// Windows. Commands should use forward slashes (/) as path separators to work
// correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"io"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/x/exp/slice"
 24	"mvdan.cc/sh/moreinterp/coreutils"
 25	"mvdan.cc/sh/v3/expand"
 26	"mvdan.cc/sh/v3/interp"
 27	"mvdan.cc/sh/v3/syntax"
 28)
 29
// ShellType identifies a shell dialect.
//
// NOTE(review): nothing in this part of the file reads ShellType; execution
// always goes through the POSIX interpreter. Confirm how other files use it
// before relying on the Cmd/PowerShell values.
type ShellType int

const (
	// ShellTypePOSIX is a POSIX sh-compatible shell.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd presumably denotes Windows cmd.exe (inferred from the name).
	ShellTypeCmd
	// ShellTypePowerShell presumably denotes PowerShell (inferred from the name).
	ShellTypePowerShell
)
 38
// Logger is the optional logging sink used by Shell. After every command the
// shell calls InfoPersist with the message "command finished" and the
// key/value pairs "command" (the command text) and "err" (the run error).
type Logger interface {
	// InfoPersist logs msg followed by alternating key/value pairs.
	InfoPersist(msg string, keysAndValues ...any)
}
 43
 44// noopLogger is a logger that does nothing
 45type noopLogger struct{}
 46
 47func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 48
// BlockFunc reports whether a command must be refused. It receives the
// argument list handed to the interpreter's exec handler (args[0] is the
// command name). If any registered BlockFunc returns true, the shell fails the
// command with a "command is not allowed for security reasons" error instead
// of running it.
type BlockFunc func(args []string) bool
 51
// Shell runs commands through an embedded POSIX shell interpreter
// (mvdan.cc/sh/v3). State always persists between calls: after every command
// the interpreter's final directory and variables replace cwd and env, so the
// next command starts where the previous one left off. Every exported method
// locks mu, so commands on the same Shell run one at a time.
type Shell struct {
	// env holds "NAME=value" entries used to seed each new interpreter.
	env        []string
	// cwd is the directory each command starts in.
	cwd        string
	// mu guards all fields; it is held for the full duration of a command.
	mu         sync.Mutex
	// logger receives one "command finished" entry per command.
	logger     Logger
	// blockFuncs are consulted before each command; any true result blocks it.
	blockFuncs []BlockFunc
}
 60
// Options configures NewShell. A nil *Options, or a zero-valued field, selects
// the default noted on that field.
type Options struct {
	// WorkingDir is the initial directory; "" means the process's current
	// directory (os.Getwd).
	WorkingDir string
	// Env holds "NAME=value" entries; nil means os.Environ(). A non-nil empty
	// slice gives an empty environment.
	Env        []string
	// Logger receives per-command log entries; nil discards them.
	Logger     Logger
	// BlockFuncs are checked before every command; nil blocks nothing.
	BlockFuncs []BlockFunc
}
 68
 69// NewShell creates a new shell instance with the given options
 70func NewShell(opts *Options) *Shell {
 71	if opts == nil {
 72		opts = &Options{}
 73	}
 74
 75	cwd := opts.WorkingDir
 76	if cwd == "" {
 77		cwd, _ = os.Getwd()
 78	}
 79
 80	env := opts.Env
 81	if env == nil {
 82		env = os.Environ()
 83	}
 84
 85	logger := opts.Logger
 86	if logger == nil {
 87		logger = noopLogger{}
 88	}
 89
 90	return &Shell{
 91		cwd:        cwd,
 92		env:        env,
 93		logger:     logger,
 94		blockFuncs: opts.BlockFuncs,
 95	}
 96}
 97
 98// Exec executes a command in the shell
 99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100	s.mu.Lock()
101	defer s.mu.Unlock()
102
103	return s.exec(ctx, command)
104}
105
106// ExecStream executes a command in the shell with streaming output to provided writers
107func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
108	s.mu.Lock()
109	defer s.mu.Unlock()
110
111	return s.execStream(ctx, command, stdout, stderr)
112}
113
114func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin string) (string, string, error) {
115	s.mu.Lock()
116	defer s.mu.Unlock()
117
118	var stdout, stderr bytes.Buffer
119	err := s.execCommonWithStdin(ctx, command, strings.NewReader(stdin), &stdout, &stderr)
120	return stdout.String(), stderr.String(), err
121}
122
123// GetWorkingDir returns the current working directory
124func (s *Shell) GetWorkingDir() string {
125	s.mu.Lock()
126	defer s.mu.Unlock()
127	return s.cwd
128}
129
130// SetWorkingDir sets the working directory
131func (s *Shell) SetWorkingDir(dir string) error {
132	s.mu.Lock()
133	defer s.mu.Unlock()
134
135	// Verify the directory exists
136	if _, err := os.Stat(dir); err != nil {
137		return fmt.Errorf("directory does not exist: %w", err)
138	}
139
140	s.cwd = dir
141	return nil
142}
143
144// GetEnv returns a copy of the environment variables
145func (s *Shell) GetEnv() []string {
146	s.mu.Lock()
147	defer s.mu.Unlock()
148
149	env := make([]string, len(s.env))
150	copy(env, s.env)
151	return env
152}
153
154// SetEnv sets an environment variable
155func (s *Shell) SetEnv(key, value string) {
156	s.mu.Lock()
157	defer s.mu.Unlock()
158
159	// Update or add the environment variable
160	keyPrefix := key + "="
161	for i, env := range s.env {
162		if strings.HasPrefix(env, keyPrefix) {
163			s.env[i] = keyPrefix + value
164			return
165		}
166	}
167	s.env = append(s.env, keyPrefix+value)
168}
169
170// SetBlockFuncs sets the command block functions for the shell
171func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
172	s.mu.Lock()
173	defer s.mu.Unlock()
174	s.blockFuncs = blockFuncs
175}
176
177// CommandsBlocker creates a BlockFunc that blocks exact command matches
178func CommandsBlocker(cmds []string) BlockFunc {
179	bannedSet := make(map[string]struct{})
180	for _, cmd := range cmds {
181		bannedSet[cmd] = struct{}{}
182	}
183
184	return func(args []string) bool {
185		if len(args) == 0 {
186			return false
187		}
188		_, ok := bannedSet[args[0]]
189		return ok
190	}
191}
192
193// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
194func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
195	return func(parts []string) bool {
196		if len(parts) == 0 || parts[0] != cmd {
197			return false
198		}
199
200		argParts, flagParts := splitArgsFlags(parts[1:])
201		if len(argParts) < len(args) || len(flagParts) < len(flags) {
202			return false
203		}
204
205		argsMatch := slices.Equal(argParts[:len(args)], args)
206		flagsMatch := slice.IsSubset(flags, flagParts)
207
208		return argsMatch && flagsMatch
209	}
210}
211
// splitArgsFlags partitions parts, in order, into positional arguments and
// flags. Any token starting with '-' (including "-" and "--") is a flag, and
// only its name is kept: everything from the first '=' onward is dropped.
// Both results are non-nil, even for empty input.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
229
230func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
231	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
232		return func(ctx context.Context, args []string) error {
233			if len(args) == 0 {
234				return next(ctx, args)
235			}
236
237			for _, blockFunc := range s.blockFuncs {
238				if blockFunc(args) {
239					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
240				}
241			}
242
243			return next(ctx, args)
244		}
245	}
246}
247
248// newInterp creates a new interpreter with the current shell state
249func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
250	return interp.New(
251		interp.StdIO(stdin, stdout, stderr),
252		interp.Interactive(false),
253		interp.Env(expand.ListEnviron(s.env...)),
254		interp.Dir(s.cwd),
255		interp.ExecHandlers(s.execHandlers()...),
256	)
257}
258
259// updateShellFromRunner updates the shell from the interpreter after execution
260func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
261	s.cwd = runner.Dir
262	s.env = nil
263	for name, vr := range runner.Vars {
264		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
265	}
266}
267
268// execCommon is the shared implementation for executing commands
269func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) error {
270	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
271	if err != nil {
272		return fmt.Errorf("could not parse command: %w", err)
273	}
274
275	runner, err := s.newInterp(nil, stdout, stderr)
276	if err != nil {
277		return fmt.Errorf("could not run command: %w", err)
278	}
279
280	err = runner.Run(ctx, line)
281	s.updateShellFromRunner(runner)
282	s.logger.InfoPersist("command finished", "command", command, "err", err)
283	return err
284}
285
286// exec executes commands using a cross-platform shell interpreter.
287func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
288	var stdout, stderr bytes.Buffer
289	err := s.execCommon(ctx, command, &stdout, &stderr)
290	return stdout.String(), stderr.String(), err
291}
292
293// execStream executes commands using POSIX shell emulation with streaming output
294func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
295	return s.execCommon(ctx, command, stdout, stderr)
296}
297
298// execCommonWithStdin is like execCommon but with stdin support
299func (s *Shell) execCommonWithStdin(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
300	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
301	if err != nil {
302		return fmt.Errorf("could not parse command: %w", err)
303	}
304
305	runner, err := s.newInterp(stdin, stdout, stderr)
306	if err != nil {
307		return fmt.Errorf("could not run command: %w", err)
308	}
309
310	err = runner.Run(ctx, line)
311	s.updateShellFromRunner(runner)
312	s.logger.InfoPersist("command finished", "command", command, "err", err)
313	return err
314}
315
316func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
317	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
318		s.blockHandler(),
319	}
320	if useGoCoreUtils {
321		handlers = append(handlers, coreutils.ExecHandler)
322	}
323	return handlers
324}
325
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded, i.e. whether the command stopped because its
// context ended.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
331
332// ExitCode extracts the exit code from an error
333func ExitCode(err error) int {
334	if err == nil {
335		return 0
336	}
337	var exitErr interp.ExitStatus
338	if errors.As(err, &exitErr) {
339		return int(exitErr)
340	}
341	return 1
342}